diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb
index 9dc49a74ec36..a8cd5219d4b5 100644
--- a/docs/ffi.ipynb
+++ b/docs/ffi.ipynb
@@ -4,7 +4,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "# JAX's foreign function interface\n",
+ "(ffi-tutorial)=\n",
+ "\n",
+ "# Foreign function interface (FFI)\n",
"\n",
"_This tutorial requires JAX v0.4.31 or newer._\n",
"\n",
@@ -362,7 +364,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
+ "We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
]
},
{
@@ -404,7 +406,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)."
+ "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
]
},
{
@@ -490,7 +492,7 @@
"source": [
"At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n",
"One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n",
- "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n",
+ "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n",
"\n",
"One other JAX feature that this example doesn't support is higher-order AD.\n",
"It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n",
diff --git a/docs/ffi.md b/docs/ffi.md
index aa861d9a094f..cc3863ed99b2 100644
--- a/docs/ffi.md
+++ b/docs/ffi.md
@@ -5,14 +5,16 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
name: python3
---
-# JAX's foreign function interface
+(ffi-tutorial)=
+
+# Foreign function interface (FFI)
_This tutorial requires JAX v0.4.31 or newer._
@@ -309,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
```
-We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
+We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:
```{code-cell} ipython3
jax.make_jaxpr(jax.vmap(rms_norm))(x)
@@ -331,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5):
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
```
-If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues).
+If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
+++
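A rough sketch of the experimental `custom_vmap` interface mentioned above, assuming the `jax.custom_batching.custom_vmap` decorator and a rule of the form `(axis_size, in_batched, *args) -> (outputs, out_batched)`; the toy `double` function below stands in for an FFI-backed call and is purely illustrative:

```python
import jax
import jax.numpy as jnp
from jax import custom_batching

# Toy stand-in for a foreign call; a real rule would invoke a kernel that
# already handles a leading batch dimension.
@custom_batching.custom_vmap
def double(x):
  return 2.0 * x

@double.def_vmap
def double_vmap_rule(axis_size, in_batched, x):
  # Under vmap, the mapped axis of `x` is moved to the front before this
  # rule is called, so a batch-aware kernel could be invoked directly here
  # instead of being rewritten as a scan.
  x_batched, = in_batched
  out = 2.0 * x
  return out, x_batched

x = jnp.arange(6.0).reshape(3, 2)
print(jax.vmap(double)(x))  # dispatches to double_vmap_rule
```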
@@ -404,7 +406,7 @@ np.testing.assert_allclose(
At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.
One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.
-JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.
+JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing your use case if you hit this limitation in practice.
One other JAX feature that this example doesn't support is higher-order AD.
It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.
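As an aside on the forward-mode limitation discussed above, a minimal `jax.custom_jvp` sketch (using plain `jnp.sin` rather than the FFI-backed `rms_norm` from the tutorial) shows that a single JVP rule written with differentiable operations supports both `jax.jvp` and `jax.grad`:

```python
import jax
import jax.numpy as jnp

# Hypothetical example, not the tutorial's rms_norm: one custom_jvp rule
# gives forward-mode AD directly, and reverse mode is derived from it by
# JAX's usual linearize-and-transpose machinery.
@jax.custom_jvp
def f(x):
  return jnp.sin(x)

@f.defjvp
def f_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  return jnp.sin(x), jnp.cos(x) * x_dot

print(jax.jvp(f, (1.0,), (1.0,)))  # forward mode
print(jax.grad(f)(1.0))            # reverse mode also works
```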
diff --git a/docs/glossary.rst b/docs/glossary.rst
index 78b7fcd246f3..286b07e21a66 100644
--- a/docs/glossary.rst
+++ b/docs/glossary.rst
@@ -1,5 +1,5 @@
-JAX Glossary of Terms
-=====================
+Glossary of terms
+=================
.. glossary::
@@ -28,9 +28,9 @@ JAX Glossary of Terms
able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`).
jaxpr
- Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that
+ Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that
is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution.
- See :ref:`understanding-jaxprs` for more discussion and examples.
+ See :ref:`jax-internals-jaxpr` for more discussion and examples.
JIT
Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of
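To illustrate the jaxpr entry above (an assumed example, not taken from the glossary itself), `jax.make_jaxpr` traces a Python function and returns its jaxpr:

```python
import jax
import jax.numpy as jnp

# Trace a small function and print the jaxpr JAX generates for it.
def f(x):
  return jnp.sum(jnp.sin(x) * 3.0)

print(jax.make_jaxpr(f)(jnp.ones(4)))
```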
diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/gradient-checkpointing.md
similarity index 99%
rename from docs/_tutorials/gradient-checkpointing.md
rename to docs/gradient-checkpointing.md
index b768514e4bb0..14a532b54dd1 100644
--- a/docs/_tutorials/gradient-checkpointing.md
+++ b/docs/gradient-checkpointing.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
diff --git a/docs/index.rst b/docs/index.rst
index 2e13c109dbbe..2dd856ab88ef 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -1,13 +1,9 @@
-JAX: High-Performance Array Computing
+JAX: High performance array computing
=====================================
JAX is a Python library for accelerator-oriented array computation and program transformation,
designed for high-performance numerical computing and large-scale machine learning.
-If you're looking to train neural networks, use Flax_ and start with its documentation.
-Some associated tools are Optax_ and Orbax_.
-For an end-to-end transformer library built on JAX, see MaxText_.
-
.. grid:: 3
:margin: 0
:padding: 0
@@ -27,7 +23,7 @@ For an end-to-end transformer library built on JAX, see MaxText_.
JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization.
- .. grid-item-card:: Run Anywhere
+ .. grid-item-card:: Run anywhere
:columns: 12 6 6 4
:class-card: sd-border-0
:shadow: None
@@ -36,34 +32,95 @@ For an end-to-end transformer library built on JAX, see MaxText_.
.. grid:: 3
- .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting Started
+ .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started
:columns: 12 6 6 4
:link: beginner-guide
:link-type: ref
:class-card: getting-started
- .. grid-item-card:: :material-regular:`library_books;2em` User Guides
+ .. grid-item-card:: :material-regular:`library_books;2em` User guides
:columns: 12 6 6 4
:link: user-guides
:link-type: ref
:class-card: user-guides
- .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer Docs
+ .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes
:columns: 12 6 6 4
:link: contributor-guide
:link-type: ref
:class-card: developer-docs
+If you're looking to train neural networks, use Flax_ and start with its tutorials.
+For an end-to-end transformer library built on JAX, see MaxText_.
+
+Ecosystem
+---------
+JAX itself is narrowly scoped and focuses on efficient array operations & program
+transformations. Built around JAX is an evolving ecosystem of machine learning and
+numerical computing tools; the following is just a small sample of what is out there:
+
+.. grid:: 4
+ :class-container: ecosystem-grid
+
+ .. grid-item:: :material-outlined:`hub;2em` **Neural networks**
+
+ - Flax_
+ - NNX_
+ - Equinox_
+ - Keras_
+
+ .. grid-item:: :material-regular:`show_chart;2em` **Optimizers & solvers**
+
+ - Optax_
+ - Optimistix_
+ - Lineax_
+ - Diffrax_
+
+ .. grid-item:: :material-outlined:`storage;2em` **Data loading**
+
+ - Grain_
+ - `TensorFlow datasets`_
+ - `Hugging Face datasets`_
+
+ .. grid-item:: :material-regular:`construction;2em` **Miscellaneous tools**
+
+ - Orbax_
+ - Chex_
+
+ .. grid-item:: :material-regular:`lan;2em` **Probabilistic programming**
+
+ - Blackjax_
+ - Numpyro_
+ - PyMC_
+
+ .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling**
+
+ - `TensorFlow probability`_
+ - Distrax_
+
+ .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation**
+
+ - `JAX MD`_
+ - Brax_
+
+ .. grid-item:: :material-regular:`language;2em` **LLMs**
+
+ - MaxText_
+ - AXLearn_
+ - Levanter_
+ - EasyLM_
+
+
+Many more JAX-based libraries have been developed; the community-run `Awesome JAX`_ page
+maintains an up-to-date list.
.. toctree::
:hidden:
:maxdepth: 1
- :caption: Getting Started
+ :caption: Getting started
installation
quickstart
- notebooks/Common_Gotchas_in_JAX
- faq
.. toctree::
:hidden:
@@ -71,16 +128,19 @@ For an end-to-end transformer library built on JAX, see MaxText_.
tutorials
+ notebooks/Common_Gotchas_in_JAX
+
+ faq
.. toctree::
:hidden:
:maxdepth: 2
- :caption: Further Resources
+ :caption: More guides/resources
user_guides
advanced_guide
contributor_guide
- building_on_jax
+ extensions
notes
jax
@@ -93,7 +153,28 @@ For an end-to-end transformer library built on JAX, see MaxText_.
glossary
+.. _Awesome JAX: https://github.com/n2cholas/awesome-jax
+.. _AXLearn: https://github.com/apple/axlearn
+.. _Blackjax: https://blackjax-devs.github.io/blackjax/
+.. _Brax: https://github.com/google/brax/
+.. _Chex: https://chex.readthedocs.io/
+.. _Diffrax: https://docs.kidger.site/diffrax/
+.. _Distrax: https://github.com/google-deepmind/distrax
+.. _EasyLM: https://github.com/young-geng/EasyLM
+.. _Equinox: https://docs.kidger.site/equinox/
.. _Flax: https://flax.readthedocs.io/
-.. _Orbax: https://orbax.readthedocs.io/
-.. _Optax: https://optax.readthedocs.io/
+.. _Grain: https://github.com/google/grain
+.. _Hugging Face datasets: https://huggingface.co/docs/datasets/
+.. _JAX MD: https://jax-md.readthedocs.io/
+.. _Keras: https://keras.io/
+.. _Levanter: https://github.com/stanford-crfm/levanter
+.. _Lineax: https://github.com/patrick-kidger/lineax
.. _MaxText: https://github.com/google/maxtext/
+.. _NNX: https://flax.readthedocs.io/en/latest/nnx/
+.. _Numpyro: https://num.pyro.ai/en/latest/index.html
+.. _Optax: https://optax.readthedocs.io/
+.. _Optimistix: https://github.com/patrick-kidger/optimistix
+.. _Orbax: https://orbax.readthedocs.io/
+.. _PyMC: https://www.pymc.io/
+.. _TensorFlow datasets: https://www.tensorflow.org/datasets
+.. _TensorFlow probability: https://www.tensorflow.org/probability
diff --git a/docs/installation.md b/docs/installation.md
index 20ffe436ff8a..acb802ea939c 100644
--- a/docs/installation.md
+++ b/docs/installation.md
@@ -1,5 +1,5 @@
(installation)=
-# Installing JAX
+# Installation
@@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and
cross-platform, and `jaxlib` which contains compiled binaries, and requires
different builds for different operating systems and accelerators.
-**TL;DR** For most users, a typical JAX installation may look something like this:
+**Summary:** For most users, a typical JAX installation may look something like this:
* **CPU-only (Linux/macOS/Windows)**
```
@@ -176,7 +176,7 @@ installation.
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
Make sure that it is present in your CUDA installation.
-Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues)
+Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues)
if you run into any errors or problems with the pre-built wheels.
(docker-containers-nvidia-gpu)=
@@ -216,7 +216,7 @@ refer to
**Note:** There are several caveats with the Metal plugin:
* The Metal plugin is new and experimental and has a number of
- [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
+ [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
Please report any issues on the JAX issue tracker.
* The Metal plugin currently requires very specific versions of `jax` and
`jaxlib`. This restriction will be relaxed over time as the plugin API
@@ -269,22 +269,26 @@ for more details.
Nightly releases reflect the state of the main JAX repository at the time they are
built, and may not pass the full test suite.
+Unlike the instructions for installing a JAX release, here we name all of JAX's
+packages explicitly on the command line, so `pip` will upgrade them if a newer
+version is available.
+
- CPU only:
```bash
-pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```
- Google Cloud TPU:
```bash
-pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
+pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```
- NVIDIA GPU (CUDA 12):
```bash
-pip install -U --pre jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
```
- NVIDIA GPU (CUDA 12) legacy:
diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md
index 4affae3a65d8..61d219d1bae1 100644
--- a/docs/investigating_a_regression.md
+++ b/docs/investigating_a_regression.md
@@ -9,7 +9,7 @@ Let's first make a JAX issue.
But if you can pinpoint the commit that triggered the regression, it will really help us.
This document explains how we identified the commit that caused a
-[15% performance regression](https://github.com/google/jax/issues/17686).
+[15% performance regression](https://github.com/jax-ml/jax/issues/17686).
## Steps
@@ -23,9 +23,9 @@ Here is a suggested investigation strategy:
2. Hourly recompilation while keeping XLA and JAX in sync.
3. Final verification: maybe a manual check of a few commits (or a git bisect).
-## Nightly investigation.
+## Nightly investigation
-This can be done by using [JAX-Toolbox nightly
+This can be done by using the [NVIDIA JAX-Toolbox nightly
containers](https://github.com/NVIDIA/JAX-Toolbox).
- Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days.
@@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox).
- test_runner.sh: will start the containers and the test.
- test.sh: will install missing dependencies and run the test
-Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686
+Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686
- test_runner.sh:
```
for m in 7 8 9; do
@@ -128,7 +128,7 @@ investigate hourly between 8-24 and 8-26. There was a smaller slowdown
earlier, lets ignore it for this example. It would be only another
hourly investigation between those dates.
-## Hourly investigation.
+## Hourly investigation
This does a checkout of JAX and XLA at each hour between the 2 dates,
rebuilds everything and runs the test. The scripts are structured
diff --git a/docs/_tutorials/jax-primitives.md b/docs/jax-primitives.md
similarity index 99%
rename from docs/_tutorials/jax-primitives.md
rename to docs/jax-primitives.md
index 51abe2916693..abdc8be6d0a8 100644
--- a/docs/_tutorials/jax-primitives.md
+++ b/docs/jax-primitives.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -306,7 +306,7 @@ from jax.interpreters import mlir
mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
```
-You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`.
+You will now be able to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`.
```{code-cell}
assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst
index 7b19955d3d78..3a03665b3217 100644
--- a/docs/jax.lax.rst
+++ b/docs/jax.lax.rst
@@ -119,6 +119,7 @@ Operators
ne
neg
nextafter
+ optimization_barrier
pad
platform_dependent
polygamma
@@ -248,6 +249,7 @@ Argument classes
.. autoclass:: ConvDimensionNumbers
.. autoclass:: ConvGeneralDilatedDimensionNumbers
+.. autoclass:: DotAlgorithm
.. autoclass:: GatherDimensionNumbers
.. autoclass:: GatherScatterMode
.. autoclass:: Precision
diff --git a/docs/jax.random.rst b/docs/jax.random.rst
index 9d6369d2d2b1..6c5427c05e66 100644
--- a/docs/jax.random.rst
+++ b/docs/jax.random.rst
@@ -12,13 +12,13 @@ Key Creation & Manipulation
.. autosummary::
:toctree: _autosummary
- PRNGKey
key
key_data
wrap_key_data
fold_in
split
clone
+ PRNGKey
Random Samplers
~~~~~~~~~~~~~~~
diff --git a/docs/jax.rst b/docs/jax.rst
index b112490a0912..ecfeaf29e3c0 100644
--- a/docs/jax.rst
+++ b/docs/jax.rst
@@ -1,7 +1,7 @@
.. currentmodule:: jax
-Public API: jax package
-=======================
+Public API: ``jax`` package
+===========================
Subpackages
-----------
@@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`)
jit
disable_jit
ensure_compile_time_eval
- xla_computation
make_jaxpr
eval_shape
ShapeDtypeStruct
@@ -92,6 +91,7 @@ Automatic differentiation
grad
value_and_grad
+ jacobian
jacfwd
jacrev
hessian
@@ -99,12 +99,29 @@ Automatic differentiation
linearize
linear_transpose
vjp
- custom_jvp
- custom_vjp
custom_gradient
closure_convert
checkpoint
+``custom_jvp``
+~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: _autosummary
+
+ custom_jvp
+ custom_jvp.defjvp
+ custom_jvp.defjvps
+
+``custom_vjp``
+~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: _autosummary
+
+ custom_vjp
+ custom_vjp.defvjp
+
jax.Array (:code:`jax.Array`)
-----------------------------
@@ -116,6 +133,73 @@ jax.Array (:code:`jax.Array`)
make_array_from_single_device_arrays
make_array_from_process_local_data
+Array properties and methods
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autosummary::
+ :toctree: _autosummary
+
+ Array.addressable_shards
+ Array.all
+ Array.any
+ Array.argmax
+ Array.argmin
+ Array.argpartition
+ Array.argsort
+ Array.astype
+ Array.at
+ Array.choose
+ Array.clip
+ Array.compress
+ Array.conj
+ Array.conjugate
+ Array.copy
+ Array.copy_to_host_async
+ Array.cumprod
+ Array.cumsum
+ Array.device
+ Array.diagonal
+ Array.dot
+ Array.dtype
+ Array.flat
+ Array.flatten
+ Array.global_shards
+ Array.imag
+ Array.is_fully_addressable
+ Array.is_fully_replicated
+ Array.item
+ Array.itemsize
+ Array.max
+ Array.mean
+ Array.min
+ Array.nbytes
+ Array.ndim
+ Array.nonzero
+ Array.prod
+ Array.ptp
+ Array.ravel
+ Array.real
+ Array.repeat
+ Array.reshape
+ Array.round
+ Array.searchsorted
+ Array.shape
+ Array.sharding
+ Array.size
+ Array.sort
+ Array.squeeze
+ Array.std
+ Array.sum
+ Array.swapaxes
+ Array.take
+ Array.to_device
+ Array.trace
+ Array.transpose
+ Array.var
+ Array.view
+ Array.T
+ Array.mT
+
Vectorization (:code:`vmap`)
----------------------------
@@ -138,6 +222,7 @@ Parallelization (:code:`pmap`)
device_count
local_device_count
process_count
+ process_indices
Callbacks
---------
diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst
index f6d8a151440b..abdf5069ee08 100644
--- a/docs/jax.scipy.rst
+++ b/docs/jax.scipy.rst
@@ -164,6 +164,7 @@ jax.scipy.special
expit
expn
factorial
+ fresnel
gamma
gammainc
gammaincc
diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst
index fe65054d22c1..1ece596d88ef 100644
--- a/docs/jax_internal_api.rst
+++ b/docs/jax_internal_api.rst
@@ -1,5 +1,5 @@
-Internal APIs
-=============
+Internal API reference
+======================
core
----
diff --git a/docs/_tutorials/jaxpr.md b/docs/jaxpr.md
similarity index 99%
rename from docs/_tutorials/jaxpr.md
rename to docs/jaxpr.md
index 9fe990c0a8ba..974ed39c1663 100644
--- a/docs/_tutorials/jaxpr.md
+++ b/docs/jaxpr.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst
deleted file mode 100644
index 56be62162a9e..000000000000
--- a/docs/jaxpr.rst
+++ /dev/null
@@ -1,472 +0,0 @@
-.. _understanding-jaxprs:
-
-Understanding Jaxprs
-====================
-
-Updated: May 3, 2020 (for commit f1a46fe).
-
-Conceptually, one can think of JAX transformations as first trace-specializing
-the Python function to be transformed into a small and well-behaved
-intermediate form that is then interpreted with transformation-specific
-interpretation rules. One of the reasons JAX can pack so much power into such a
-small software package is that it starts with a familiar and flexible
-programming interface (Python with NumPy) and it uses the actual Python
-interpreter to do most of the heavy lifting to distill the essence of the
-computation into a simple statically-typed expression language with limited
-higher-order features. That language is the jaxpr language.
-
-Not all Python programs can be processed this way, but it turns out that many
-scientific computing and machine learning programs can.
-
-Before we proceed, it is important to point out that not all JAX
-transformations literally materialize a jaxpr as described above; some, e.g.,
-differentiation or batching, will apply transformations incrementally during
-tracing. Nevertheless, if one wants to understand how JAX works internally, or
-to make use of the result of JAX tracing, it is useful to understand jaxprs.
-
-A jaxpr instance represents a function with one or more typed parameters (input
-variables) and one or more typed results. The results depend only on the input
-variables; there are no free variables captured from enclosing scopes. The
-inputs and outputs have types, which in JAX are represented as abstract values.
-There are two related representations in the code for jaxprs,
-:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A
-:py:class:`jax.core.ClosedJaxpr` represents a partially-applied
-:py:class:`jax.core.Jaxpr`, and is what you obtain when you use
-:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields:
-
- * ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual
- computation content of the function (described below).
- * ``consts`` is a list of constants.
-
-The most interesting part of the ClosedJaxpr is the actual execution content,
-represented as a :py:class:`jax.core.Jaxpr` as printed using the following
-grammar::
-
- Jaxpr ::= { lambda Var* ; Var+. let
- Eqn*
- in [Expr+] }
-
-where:
- * The parameters of the jaxpr are shown as two lists of variables separated by
- ``;``. The first set of variables are the ones that have been introduced
- to stand for constants that have been hoisted out. These are called the
- ``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts``
- field holds corresponding values. The second list of variables, called
- ``invars``, correspond to the inputs of the traced Python function.
- * ``Eqn*`` is a list of equations, defining intermediate variables referring to
- intermediate expressions. Each equation defines one or more variables as the
- result of applying a primitive on some atomic expressions. Each equation uses only
- input variables and intermediate variables defined by previous equations.
- * ``Expr+``: is a list of output atomic expressions (literals or variables)
- for the jaxpr.
-
-Equations are printed as follows::
-
- Eqn ::= Var+ = Primitive [ Param* ] Expr+
-
-where:
- * ``Var+`` are one or more intermediate variables to be defined as the output
- of a primitive invocation (some primitives can return multiple values).
- * ``Expr+`` are one or more atomic expressions, each either a variable or a
- literal constant. A special variable ``unitvar`` or literal ``unit``,
- printed as ``*``, represents a value that is not needed
- in the rest of the computation and has been elided. That is, units are just
- placeholders.
- * ``Param*`` are zero or more named parameters to the primitive, printed in
- square brackets. Each parameter is shown as ``Name = Value``.
-
-
-Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments)::
-
- Primitive := add | sub | sin | mul | ...
-
-
-The jaxpr primitives are documented in the :py:mod:`jax.lax` module.
-
-For example, here is the jaxpr produced for the function ``func1`` below
-
->>> from jax import make_jaxpr
->>> import jax.numpy as jnp
->>> def func1(first, second):
-... temp = first + jnp.sin(second) * 3.
-... return jnp.sum(temp)
-...
->>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8)))
-{ lambda ; a:f32[8] b:f32[8]. let
- c:f32[8] = sin b
- d:f32[8] = mul c 3.0
- e:f32[8] = add a d
- f:f32[] = reduce_sum[axes=(0,)] e
- in (f,) }
-
-Here there are no constvars, ``a`` and ``b`` are the input variables
-and they correspond respectively to
-``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept
-inline.
-The ``reduce_sum`` primitive has named parameter ``axes``, in addition to the
-operand ``e``.
-
-Note that even though execution of a program that calls into JAX builds a jaxpr,
-Python-level control-flow and Python-level functions execute normally.
-This means that just because a Python program contains functions and control-flow,
-the resulting jaxpr does not have to contain control-flow or higher-order features.
-
-For example, when tracing the function ``func3`` JAX will inline the call to
-``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same
-jaxpr as before
-
->>> def func2(inner, first, second):
-... temp = first + inner(second) * 3.
-... return jnp.sum(temp)
-...
->>> def inner(second):
-... if second.shape[0] > 4:
-... return jnp.sin(second)
-... else:
-... assert False
-...
->>> def func3(first, second):
-... return func2(inner, first, second)
-...
->>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8)))
-{ lambda ; a:f32[8] b:f32[8]. let
- c:f32[8] = sin b
- d:f32[8] = mul c 3.0
- e:f32[8] = add a d
- f:f32[] = reduce_sum[axes=(0,)] e
- in (f,) }
-
-
-Handling PyTrees
-----------------
-
-In jaxpr there are no tuple types; instead primitives take multiple inputs
-and produce multiple outputs. When processing a function that has structured
-inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists
-of inputs and outputs. For more details, please see the documentation for
-PyTrees (:ref:`pytrees`).
-
-For example, the following code produces an identical jaxpr to what we saw
-before (with two input vars, one for each element of the input tuple)
-
-
->>> def func4(arg): # Arg is a pair
-... temp = arg[0] + jnp.sin(arg[1]) * 3.
-... return jnp.sum(temp)
-...
->>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8))))
-{ lambda ; a:f32[8] b:f32[8]. let
- c:f32[8] = sin b
- d:f32[8] = mul c 3.0
- e:f32[8] = add a d
- f:f32[] = reduce_sum[axes=(0,)] e
- in (f,) }
-
-
-
-Constant Vars
--------------
-
-Some values in jaxprs are constants, in that their value does not depend on the
-jaxpr's arguments. When these values are scalars they are represented directly
-in the jaxpr equations; non-scalar array constants are instead hoisted out to
-the top-level jaxpr, where they correspond to constant variables ("constvars").
-These constvars differ from the other jaxpr parameters ("invars") only as a
-bookkeeping convention.
-
-
-Higher-order primitives
------------------------
-
-jaxpr includes several higher-order primitives. They are more complicated because
-they include sub-jaxprs.
-
-Conditionals
-^^^^^^^^^^^^
-
-JAX traces through normal Python conditionals. To capture a
-conditional expression for dynamic execution, one must use the
-:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors,
-which have the signatures::
-
- lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B
-
- lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B
-
-Both of these will bind a primitive called ``cond`` internally. The
-``cond`` primitive in jaxprs reflects the more general signature of
-:py:func:`lax.switch`: it takes an integer denoting the index of the branch
-to execute (clamped into valid indexing range).
-
-For example:
-
->>> from jax import lax
->>>
->>> def one_of_three(index, arg):
-... return lax.switch(index, [lambda x: x + 1.,
-... lambda x: x - 2.,
-... lambda x: x + 3.],
-... arg)
-...
->>> print(make_jaxpr(one_of_three)(1, 5.))
-{ lambda ; a:i32[] b:f32[]. let
- c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
- d:i32[] = clamp 0 c 2
- e:f32[] = cond[
- branches=(
- { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
- { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
- { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
- )
- ] d b
- in (e,) }
-
-The `branches` parameter to the cond primitive corresponds to the branch
-functionals. In this example, those functionals each take one input variable,
-corresponding to ``x``.
-
-The above instance of the cond primitive takes two operands. The first
-one (``d``) is the branch index, then ``b`` is the operand (``arg``) to
-be passed to whichever jaxpr in ``branches`` is selected by the branch
-index.
-
-Another example, using :py:func:`lax.cond`:
-
->>> from jax import lax
->>>
->>> def func7(arg):
-... return lax.cond(arg >= 0.,
-... lambda xtrue: xtrue + 3.,
-... lambda xfalse: xfalse - 3.,
-... arg)
-...
->>> print(make_jaxpr(func7)(5.))
-{ lambda ; a:f32[]. let
- b:bool[] = ge a 0.0
- c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
- d:f32[] = cond[
- branches=(
- { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) }
- { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) }
- )
- ] c a
- in (d,) }
-
-In this case, the boolean predicate is converted to an integer index
-(0 or 1), and ``branches`` are jaxprs that correspond to the false and
-true branch functionals, in that order. Again, each functional takes
-one input variable, corresponding to ``xfalse`` and ``xtrue``
-respectively.
-
-The following example shows a more complicated situation when the input
-to the branch functionals is a tuple, and the `false` branch functional
-contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar`
-
->>> def func8(arg1, arg2): # arg2 is a pair
-... return lax.cond(arg1 >= 0.,
-... lambda xtrue: xtrue[0],
-... lambda xfalse: jnp.array([1]) + xfalse[1],
-... arg2)
-...
->>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.)))
-{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let
- e:bool[] = ge b 0.0
- f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e
- g:f32[1] = cond[
- branches=(
- { lambda ; h:i32[1] i:f32[1] j:f32[]. let
- k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h
- l:f32[1] = add k j
- in (l,) }
- { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) }
- )
- ] f a c d
- in (g,) }
-
-
-
-While
-^^^^^
-
-Just like for conditionals, Python loops are inlined during tracing.
-If you want to capture a loop for dynamic execution, you must use one of several
-special operations, :py:func:`jax.lax.while_loop` (a primitive)
-and :py:func:`jax.lax.fori_loop`
-(a helper that generates a while_loop primitive)::
-
- lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C
- lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C
-
-
-In the above signature, “C” stands for the type of the loop “carry” value.
-For example, here is an example fori loop
-
->>> import numpy as np
->>>
->>> def func10(arg, n):
-... ones = jnp.ones(arg.shape) # A constant
-... return lax.fori_loop(0, n,
-... lambda i, carry: carry + ones * 3. + arg,
-... arg + ones)
-...
->>> print(make_jaxpr(func10)(np.ones(16), 5))
-{ lambda ; a:f32[16] b:i32[]. let
- c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
- d:f32[16] = add a c
- _:i32[] _:i32[] e:f32[16] = while[
- body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. let
- k:i32[] = add h 1
- l:f32[16] = mul f 3.0
- m:f32[16] = add j l
- n:f32[16] = add m g
- in (k, i, n) }
- body_nconsts=2
- cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let
- r:bool[] = lt o p
- in (r,) }
- cond_nconsts=0
- ] c a 0 b d
- in (e,) }
-
-The while primitive takes 5 arguments: ``c a 0 b d``, as follows:
-
- * 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
- * 2 constants for ``body_jaxpr`` (``c``, and ``a``)
- * 3 parameters for the initial value of carry
-
-Scan
-^^^^
-
-JAX supports a special form of loop over the elements of an array (with
-statically known shape). The fact that there are a fixed number of iterations
-makes this form of looping easily reverse-differentiable. Such loops are
-constructed with the :py:func:`jax.lax.scan` function::
-
- lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B])
-
-This is written in terms of a `Haskell Type Signature`_:
-``C`` is the type of the scan carry, ``A`` is the element type of the
-input array(s), and ``B`` is the element type of the output array(s).
-
-For the example consider the function ``func11`` below
-
->>> def func11(arr, extra):
-... ones = jnp.ones(arr.shape) # A constant
-... def body(carry, aelems):
-... # carry: running dot-product of the two arrays
-... # aelems: a pair with corresponding elements from the two arrays
-... ae1, ae2 = aelems
-... return (carry + ae1 * ae2 + extra, carry)
-... return lax.scan(body, 0., (arr, ones))
-...
->>> print(make_jaxpr(func11)(np.ones(16), 5.))
-{ lambda ; a:f32[16] b:f32[]. let
- c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0
- d:f32[] e:f32[16] = scan[
- _split_transpose=False
- jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let
- j:f32[] = mul h i
- k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g
- l:f32[] = add k j
- m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
- n:f32[] = add l m
- in (n, g) }
- length=16
- linear=(False, False, False, False)
- num_carry=1
- num_consts=1
- reverse=False
- unroll=1
- ] b 0.0 a c
- in (d, e) }
-
-The ``linear`` parameter describes for each of the input variables whether they
-are guaranteed to be used linearly in the body. Once the scan goes through
-linearization, more arguments will be linear.
-
-The scan primitive takes 4 arguments: ``b 0.0 a c``, of which:
-
- * one is the free variable for the body
- * one is the initial value of the carry
- * The next 2 are the arrays over which the scan operates.
-
-XLA_call
-^^^^^^^^
-
-The call primitive arises from JIT compilation, and it encapsulates
-a sub-jaxpr along with parameters that specify the backend and the device on
-which the computation should run. For example
-
->>> from jax import jit
->>>
->>> def func12(arg):
-... @jit
-... def inner(x):
-... return x + arg * jnp.ones(1) # Include a constant in the inner function
-... return arg + inner(arg - 2.)
-...
->>> print(make_jaxpr(func12)(1.)) # doctest:+ELLIPSIS
-{ lambda ; a:f32[]. let
- b:f32[] = sub a 2.0
- c:f32[1] = pjit[
- name=inner
- jaxpr={ lambda ; d:f32[] e:f32[]. let
- f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
- g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
- h:f32[1] = mul g f
- i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e
- j:f32[1] = add i h
- in (j,) }
- ] a b
- k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a
- l:f32[1] = add k c
- in (l,) }
-
-
-XLA_pmap
-^^^^^^^^
-
-If you use the :py:func:`jax.pmap` transformation, the function to be mapped is
-captured using the ``xla_pmap`` primitive. Consider this example
-
->>> from jax import pmap
->>>
->>> def func13(arr, extra):
-... def inner(x):
-... # use a free variable "extra" and a constant jnp.ones(1)
-... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows')
-... return pmap(inner, axis_name='rows')(arr)
-...
->>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.))
-{ lambda ; a:f32[1,3] b:f32[]. let
- c:f32[1,3] = xla_pmap[
- axis_name=rows
- axis_size=1
- backend=None
- call_jaxpr={ lambda ; d:f32[] e:f32[3]. let
- f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d
- g:f32[3] = add e f
- h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0
- i:f32[3] = add g h
- j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e
- k:f32[3] = div i j
- in (k,) }
- devices=None
- donated_invars=(False, False)
- global_axis_size=1
- in_axes=(None, 0)
- is_explicit_global_axis_size=False
- name=inner
- out_axes=(0,)
- ] b a
- in (c,) }
-
-The ``xla_pmap`` primitive specifies the name of the axis (parameter
-``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
-parameter. The value of this parameter is a Jaxpr with 2 input variables.
-
-The parameter ``in_axes`` specifies which of the input variables should be
-mapped and which should be broadcast. In our example, the value of ``extra``
-is broadcast and the value of ``arr`` is mapped.
-
-.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature
diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md
index da0adaf18060..019188349257 100644
--- a/docs/jep/11830-new-remat-checkpoint.md
+++ b/docs/jep/11830-new-remat-checkpoint.md
@@ -14,7 +14,7 @@
## What’s going on?
-As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
+As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade)
## How can I disable the change, and go back to the old behavior for now?
@@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu
As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer
available. If you have an issue, please reach out on [the issue
-tracker](https://github.com/google/jax/issues) so we can help fix it!
+tracker](https://github.com/jax-ml/jax/issues) so we can help fix it!
## Why are we doing this?
@@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi
### Significantly less Python overhead in some cases
-The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
+The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice!
### Enabling new JAX features by simplifying internals
diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md
index 9137e3e71232..7a20958c5cab 100644
--- a/docs/jep/12049-type-annotations.md
+++ b/docs/jep/12049-type-annotations.md
@@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to
This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX.
Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally.
-In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
+In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python.
This document details JAX's goals and recommendations for type annotations within the package.
## Why type annotations?
@@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their
### Level 1: Annotations as documentation
-When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
+When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]:
```python
Array = Any
@@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co
This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development.
-JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
+JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]:
```python
def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray:
...
```
-In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
+In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result:
![VSCode Intellisense Screenshot](../_static/vscode-completion.png)
@@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer
```
Again, there are a couple mechanisms that could be used for this:
-- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
+- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)).
- define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer`
- restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer`
diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md
index bec06000194e..a5625abf8930 100644
--- a/docs/jep/15856-jex.md
+++ b/docs/jep/15856-jex.md
@@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32)
This module could expose our mechanism for defining new RNG
implementations, and functions for working with PRNG key internals
-(see issue [#9263](https://github.com/google/jax/issues/9263)),
+(see issue [#9263](https://github.com/jax-ml/jax/issues/9263)),
such as the current `jax._src.prng.random_wrap` and
`random_unwrap`.
diff --git a/docs/jep/18137-numpy-scipy-scope.md b/docs/jep/18137-numpy-scipy-scope.md
index 2371e11ee07e..eaebe8fb8997 100644
--- a/docs/jep/18137-numpy-scipy-scope.md
+++ b/docs/jep/18137-numpy-scipy-scope.md
@@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali
and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`:
as of the writing of this JEP, its current implementation is a non-straightforward
iterative approximation that has
-[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637),
-and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further
+[convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637),
+and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further
complexity. Had we more carefully weighed the complexity and robustness of the
implementation when accepting the contribution, we may have chosen not to accept this
contribution to the package.
diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md
index aa568adc0d9a..ce149fa6fb35 100644
--- a/docs/jep/2026-custom-derivatives.md
+++ b/docs/jep/2026-custom-derivatives.md
@@ -35,9 +35,9 @@ behavior of their code. This customization
Python control flow and workflows for NaN debugging.
As **JAX developers** we want to write library functions, like
-[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
+[`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83)
and
-[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
+[`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91),
that are defined in terms of other primitives, but for the purposes of
differentiation have primitive-like behavior in the sense that we want to define
custom differentiation rules for them, which may be more numerically stable or
@@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like
want to be confident we’re not going to preclude good solutions to that problem.
That is, our primary goals are
-1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and
+1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and
2. allow Python in custom VJPs, e.g. to debug NaNs
- ([#1275](https://github.com/google/jax/issues/1275)).
+ ([#1275](https://github.com/jax-ml/jax/issues/1275)).
Secondary goals are
3. clean up and simplify user experience (symbolic zeros, kwargs, etc)
@@ -60,18 +60,18 @@ Secondary goals are
`odeint`, `root`, etc.
Overall, we want to close
-[#116](https://github.com/google/jax/issues/116),
-[#1097](https://github.com/google/jax/issues/1097),
-[#1249](https://github.com/google/jax/issues/1249),
-[#1275](https://github.com/google/jax/issues/1275),
-[#1366](https://github.com/google/jax/issues/1366),
-[#1723](https://github.com/google/jax/issues/1723),
-[#1670](https://github.com/google/jax/issues/1670),
-[#1875](https://github.com/google/jax/issues/1875),
-[#1938](https://github.com/google/jax/issues/1938),
+[#116](https://github.com/jax-ml/jax/issues/116),
+[#1097](https://github.com/jax-ml/jax/issues/1097),
+[#1249](https://github.com/jax-ml/jax/issues/1249),
+[#1275](https://github.com/jax-ml/jax/issues/1275),
+[#1366](https://github.com/jax-ml/jax/issues/1366),
+[#1723](https://github.com/jax-ml/jax/issues/1723),
+[#1670](https://github.com/jax-ml/jax/issues/1670),
+[#1875](https://github.com/jax-ml/jax/issues/1875),
+[#1938](https://github.com/jax-ml/jax/issues/1938),
and replace the custom_transforms machinery (from
-[#636](https://github.com/google/jax/issues/636),
-[#818](https://github.com/google/jax/issues/818),
+[#636](https://github.com/jax-ml/jax/issues/636),
+[#818](https://github.com/jax-ml/jax/issues/818),
and others).
## Non-goals
@@ -400,7 +400,7 @@ There are some other bells and whistles to the API:
resolved to positions using the `inspect` module. This is a bit of an experiment
with Python 3’s improved ability to programmatically inspect argument
signatures. I believe it is sound but not complete, which is a fine place to be.
- (See also [#2069](https://github.com/google/jax/issues/2069).)
+ (See also [#2069](https://github.com/jax-ml/jax/issues/2069).)
* Arguments can be marked non-differentiable using `nondiff_argnums`, and as with
`jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to
set a convention for how these arguments are passed to the rules. For a primal
@@ -433,5 +433,5 @@ There are some other bells and whistles to the API:
`custom_lin` to the tangent values; `custom_lin` carries with it the user’s
custom backward-pass function, and as a primitive it only has a transpose
rule.
- * This mechanism is described more in [#636](https://github.com/google/jax/issues/636).
+ * This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636).
* To prevent
diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md
index 65235dc64337..1e2270e052a6 100644
--- a/docs/jep/4008-custom-vjp-update.md
+++ b/docs/jep/4008-custom-vjp-update.md
@@ -9,7 +9,7 @@ notebook.
## What to update
-After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments
+After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments
passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or
containers of `Tracer`s), which basically means to allow for
arbitrarily-transformable code `nondiff_argnums` shouldn't be used for
@@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't
at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing
`nondiff_argnums` that way was a mistake!
-**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure
+**[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure
issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp`
and `custom_vjp` functions and rules can close over `Tracer`s to our hearts'
content. For all non-autodiff transformations, things will Just Work. For
@@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees!
Moreover, that complexity isn't necessary: if user code treats array-like
non-differentiable arguments just like regular arguments and residuals,
everything already works. (Before
-[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about
+[#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about
involving integer-valued inputs and outputs in autodiff, but after
-[#4039](https://github.com/google/jax/pull/4039) those will just work!)
+[#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!)
Unlike `custom_vjp`, it was easy to make `custom_jvp` work with
`nondiff_argnums` arguments that were `Tracer`s. So these updates only need to
diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md
index eb68ee5f0e0a..f95c15f404b6 100644
--- a/docs/jep/4410-omnistaging.md
+++ b/docs/jep/4410-omnistaging.md
@@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc.
### What's going on?
A change to JAX's tracing infrastructure called “omnistaging”
-([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in
+([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in
jax==0.2.0. This change improves memory performance, trace execution time, and
simplifies jax internals, but may cause some existing code to break. Breakage is
usually a result of buggy code, so long-term it’s best to fix the bugs, but
@@ -191,7 +191,7 @@ and potentially even fragmenting memory.
(The `broadcast` that corresponds to the construction of the zeros array for
`jnp.zeros_like(x)` is staged out because JAX is lazy about very simple
-expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After
+expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After
omnistaging, we can remove that lazy sublanguage and simplify JAX internals.)
The reason the creation of `mask` is not staged out is that, before omnistaging,
diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md
index 828b95e8ce00..d520f6f63df9 100644
--- a/docs/jep/9263-typed-keys.md
+++ b/docs/jep/9263-typed-keys.md
@@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same
extended dtype mechanism elsewhere internally. For example, the
`jax._src.core.bint` object, a bounded integer type used for experimental work
on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies
-the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
+the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)).
### PRNG dtypes
PRNG dtypes are defined as a particular case of extended dtypes. Specifically,
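A hedged illustration of the extended-dtype behaviour described above, using the public typed-key API (these calls are standard JAX functions, not part of this JEP's diff):

```python
import jax
import jax.numpy as jnp

key = jax.random.key(0)                  # new-style typed PRNG key
print(key.dtype)                         # an extended dtype, e.g. key<fry>
print(jnp.issubdtype(key.dtype, jax.dtypes.prng_key))  # True
```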
diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb
index 2aef1768112f..a1ede3177a3a 100644
--- a/docs/jep/9407-type-promotion.ipynb
+++ b/docs/jep/9407-type-promotion.ipynb
@@ -8,7 +8,7 @@
"source": [
"# Design of Type Promotion Semantics for JAX\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n",
"\n",
"*Jake VanderPlas, December 2021*\n",
"\n",
@@ -3317,7 +3317,6 @@
],
"source": [
"# @title\n",
- "from jax import dtypes\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
@@ -3802,7 +3801,6 @@
],
"source": [
"# @title\n",
- "from jax import dtypes\n",
"import jax\n",
"import jax.numpy as jnp\n",
"import pandas as pd\n",
diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md
index 107bcd8c968b..ff67a8c21399 100644
--- a/docs/jep/9407-type-promotion.md
+++ b/docs/jep/9407-type-promotion.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -16,7 +16,7 @@ kernelspec:
# Design of Type Promotion Semantics for JAX
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)
*Jake VanderPlas, December 2021*
@@ -908,7 +908,6 @@ display.HTML(table.to_html())
:tags: [hide-input]
# @title
-from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
@@ -963,7 +962,6 @@ display.HTML(table.to_html())
:tags: [hide-input]
# @title
-from jax import dtypes
import jax
import jax.numpy as jnp
import pandas as pd
diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md
index 759a9be86713..b964aa2af45d 100644
--- a/docs/jep/9419-jax-versioning.md
+++ b/docs/jep/9419-jax-versioning.md
@@ -58,11 +58,11 @@ These constraints imply the following rules for releases:
* If a new `jaxlib` is released, a `jax` release must be made at the same time.
These
-[version constraints](https://github.com/google/jax/blob/main/jax/version.py)
+[version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py)
are currently checked by `jax` at import time, instead of being expressed as
Python package version constraints. `jax` checks the `jaxlib` version at
runtime rather than using a `pip` package version constraint because we
-[provide separate `jaxlib` wheels](https://github.com/google/jax#installation)
+[provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation)
for a variety of hardware and software versions (e.g, GPU, TPU, etc.). Since we
do not know which is the right choice for any given user, we do not want `pip`
to install a `jaxlib` package for us automatically.
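A small sketch of the relationship described above: because the constraint is checked at runtime rather than through `pip` metadata, the installed version pair can simply be inspected directly.

```python
# Illustrative only: print the jax / jaxlib versions that the import-time
# compatibility check described above compares.
import jax
import jaxlib

print("jax   :", jax.__version__)
print("jaxlib:", jaxlib.__version__)
```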
@@ -119,7 +119,7 @@ no released `jax` version uses that API.
## How is the source to `jaxlib` laid out?
`jaxlib` is split across two main repositories, namely the
-[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib)
+[`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib)
and in the
[XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla).
The JAX-specific pieces inside XLA are primarily in the
@@ -146,7 +146,7 @@ level.
`jaxlib` is built using Bazel out of the `jax` repository. The pieces of
`jaxlib` from the XLA repository are incorporated into the build
-[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE).
+[as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE).
To update the version of XLA used during the build, one must update the pinned
version in the Bazel `WORKSPACE`. This is done manually on an
as-needed basis, but can be overridden on a build-by-build basis.
diff --git a/docs/jep/index.rst b/docs/jep/index.rst
index 194eb0cb9d69..f9dda2657ced 100644
--- a/docs/jep/index.rst
+++ b/docs/jep/index.rst
@@ -32,7 +32,7 @@ should be linked to this issue.
Then create a pull request that adds a file named
`%d-{short-title}.md` - with the number being the issue number.
-.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP
+.. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP
.. toctree::
:maxdepth: 1
diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md
index 2d442c8411aa..59c7bbd8fb90 100644
--- a/docs/jit-compilation.md
+++ b/docs/jit-compilation.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -51,7 +51,7 @@ def log2(x):
print(jax.make_jaxpr(log2)(3.0))
```
-The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output.
+The {ref}`jax-internals-jaxpr` section of the documentation provides more information on the meaning of the above output.
Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`.
This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
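As a hedged illustration of this point, tracing a function with a Python-level side effect records only the array operations; `impure_fn` and `side_effects` are invented names for this sketch.

```python
import jax
import jax.numpy as jnp

side_effects = []

def impure_fn(x):
  side_effects.append(x)      # Python side effect, invisible to tracing
  return jnp.sin(x) * 2.0

print(jax.make_jaxpr(impure_fn)(1.0))
# The printed jaxpr contains sin and mul, but nothing for the append.
```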
diff --git a/docs/key-concepts.md b/docs/key-concepts.md
index 4b114c857460..daab2c9fdde4 100644
--- a/docs/key-concepts.md
+++ b/docs/key-concepts.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -13,7 +13,7 @@ kernelspec:
---
(key-concepts)=
-# Key Concepts
+# Key concepts
@@ -23,13 +23,13 @@ This section briefly introduces some key concepts of the JAX package.
## JAX arrays ({class}`jax.Array`)
The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to
-the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it
+the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it
has some important differences.
### Array creation
We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions.
-For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality
+For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality
such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
```{code-cell}
@@ -147,10 +147,10 @@ jaxprs later in {ref}`jax-internals-jaxpr`.
## Pytrees
JAX functions and transformations fundamentally operate on arrays, but in practice it is
-convenient to write code that work with collections of arrays: for example, a neural
+convenient to write code that works with collections of arrays: for example, a neural
network might organize its parameters in a dictionary of arrays with meaningful keys.
Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree`
-abstraction to treat such collections in a uniform matter.
+abstraction to treat such collections in a uniform manner.
Here are some examples of objects that can be treated as pytrees:
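Before those examples, a minimal sketch of the pytree abstraction described above, mapping a function over a nested dictionary of arrays (the `params` structure is invented for illustration):

```python
import jax
import jax.numpy as jnp

params = {"dense": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)},
          "scale": jnp.float32(2.0)}

# Apply a function to every leaf, preserving the container structure.
doubled = jax.tree_util.tree_map(lambda leaf: 2 * leaf, params)
print(jax.tree_util.tree_structure(doubled))
```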
diff --git a/docs/multi_process.md b/docs/multi_process.md
index 7d7083bde10f..32cfae126784 100644
--- a/docs/multi_process.md
+++ b/docs/multi_process.md
@@ -1,4 +1,4 @@
-# Using JAX in multi-host and multi-process environments
+# Multi-host and multi-process environments
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
index 7ba192437a32..71bd4527644a 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb
+++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)"
]
},
{
@@ -19,8 +19,6 @@
"id": "4k5PVzEo2uJO"
},
"source": [
- "*levskaya@ mattjj@*\n",
- "\n",
"When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n",
"\n",
"__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n",
@@ -226,7 +224,6 @@
],
"source": [
"import jax.numpy as jnp\n",
- "import jax.lax as lax\n",
"from jax import make_jaxpr\n",
"\n",
"# lax.fori_loop\n",
@@ -258,7 +255,7 @@
"id": "oBdKtkVW8Lha"
},
"source": [
- "## 🔪 In-Place Updates"
+ "## 🔪 In-place updates"
]
},
{
@@ -533,7 +530,7 @@
"id": "oZ_jE2WAypdL"
},
"source": [
- "## 🔪 Out-of-Bounds Indexing"
+ "## 🔪 Out-of-bounds indexing"
]
},
{
@@ -664,7 +661,7 @@
"source": [
"Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n",
"\n",
- "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
+ "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)."
]
},
{
@@ -868,7 +865,7 @@
"id": "MUycRNh6e50W"
},
"source": [
- "## 🔪 Random Numbers"
+ "## 🔪 Random numbers"
]
},
{
@@ -888,7 +885,7 @@
"id": "Qikt9pPW9L5K"
},
"source": [
- "### RNGs and State\n",
+ "### RNGs and state\n",
"You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:"
]
},
@@ -1006,7 +1003,7 @@
"id": "COjzGBpO4tzL"
},
"source": [
- "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
+ "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n",
"\n",
"The random state is described by a special array element that we call a __key__:"
]
@@ -1031,7 +1028,6 @@
}
],
"source": [
- "from jax import random\n",
"key = random.key(0)\n",
"key"
]
@@ -1105,8 +1101,8 @@
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
- "print(\" \\---SPLIT --> new key \", key)\n",
- "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
+ "print(r\" \\---SPLIT --> new key \", key)\n",
+ "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
@@ -1140,8 +1136,8 @@
"print(\"old key\", key)\n",
"key, subkey = random.split(key)\n",
"normal_pseudorandom = random.normal(subkey, shape=(1,))\n",
- "print(\" \\---SPLIT --> new key \", key)\n",
- "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
+ "print(r\" \\---SPLIT --> new key \", key)\n",
+ "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)"
]
},
{
@@ -1183,7 +1179,7 @@
"id": "rg4CpMZ8c3ri"
},
"source": [
- "## 🔪 Control Flow"
+ "## 🔪 Control flow"
]
},
{
@@ -1192,7 +1188,7 @@
"id": "izLTvT24dAq0"
},
"source": [
- "### ✔ python control_flow + autodiff ✔\n",
+ "### ✔ Python control_flow + autodiff ✔\n",
"\n",
"If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)."
]
@@ -1231,7 +1227,7 @@
"id": "hIfPT7WMmZ2H"
},
"source": [
- "### python control flow + JIT\n",
+ "### Python control flow + JIT\n",
"\n",
"Using control flow with `jit` is more complicated, and by default it has more constraints.\n",
"\n",
@@ -1353,7 +1349,7 @@
"\n",
"For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n",
"\n",
- "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
+ "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n",
"\n",
"By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n",
"\n",
@@ -1701,7 +1697,7 @@
],
"source": [
"init_val = 0\n",
- "cond_fun = lambda x: x<10\n",
+ "cond_fun = lambda x: x < 10\n",
"body_fun = lambda x: x+1\n",
"lax.while_loop(cond_fun, body_fun, init_val)\n",
"# --> array(10, dtype=int32)"
@@ -1791,7 +1787,7 @@
"id": "OxLsZUyRt_kF"
},
"source": [
- "## 🔪 Dynamic Shapes"
+ "## 🔪 Dynamic shapes"
]
},
{
@@ -2194,7 +2190,7 @@
"id": "WAHjmL0E2XwO"
},
"source": [
- "## 🔪 Miscellaneous Divergences from NumPy\n",
+ "## 🔪 Miscellaneous divergences from NumPy\n",
"\n",
"While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n",
"Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n",
diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md
index 98c4b391c7ce..741fa3af063c 100644
--- a/docs/notebooks/Common_Gotchas_in_JAX.md
+++ b/docs/notebooks/Common_Gotchas_in_JAX.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -18,12 +18,10 @@ kernelspec:
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)
+++ {"id": "4k5PVzEo2uJO"}
-*levskaya@ mattjj@*
-
When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_"una anima di pura programmazione funzionale"_](https://www.sscardapane.it/iaml-backup/jax-intro/).
__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).
@@ -130,7 +128,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or
:outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903
import jax.numpy as jnp
-import jax.lax as lax
from jax import make_jaxpr
# lax.fori_loop
@@ -158,7 +155,7 @@ iter_operand = iter(range(10))
+++ {"id": "oBdKtkVW8Lha"}
-## 🔪 In-Place Updates
+## 🔪 In-place updates
+++ {"id": "JffAqnEW4JEb"}
@@ -268,7 +265,7 @@ For more details on indexed array updates, see the [documentation for the `.at`
+++ {"id": "oZ_jE2WAypdL"}
-## 🔪 Out-of-Bounds Indexing
+## 🔪 Out-of-bounds indexing
+++ {"id": "btRFwEVzypdN"}
@@ -315,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)
Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.
-Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
+Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior).
+++ {"id": "LwB07Kx5sgHu"}
@@ -385,7 +382,7 @@ jnp.sum(jnp.array(x))
+++ {"id": "MUycRNh6e50W"}
-## 🔪 Random Numbers
+## 🔪 Random numbers
+++ {"id": "O8vvaVt3MRG2"}
@@ -395,7 +392,7 @@ jnp.sum(jnp.array(x))
+++ {"id": "Qikt9pPW9L5K"}
-### RNGs and State
+### RNGs and state
You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:
```{code-cell} ipython3
@@ -463,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha
+++ {"id": "COjzGBpO4tzL"}
-JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
+JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.
The random state is described by a special array element that we call a __key__:
@@ -471,7 +468,6 @@ The random state is described by a special array element that we call a __key__:
:id: yPHE7KTWgAWs
:outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3
-from jax import random
key = random.key(0)
key
```
@@ -504,8 +500,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
-print(" \---SPLIT --> new key ", key)
-print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
+print(r" \---SPLIT --> new key ", key)
+print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "tqtFVE4MthO3"}
@@ -519,8 +515,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random
print("old key", key)
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
-print(" \---SPLIT --> new key ", key)
-print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
+print(r" \---SPLIT --> new key ", key)
+print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom)
```
+++ {"id": "0KLYUluz3lN3"}
@@ -538,11 +534,11 @@ for subkey in subkeys:
+++ {"id": "rg4CpMZ8c3ri"}
-## 🔪 Control Flow
+## 🔪 Control flow
+++ {"id": "izLTvT24dAq0"}
-### ✔ python control_flow + autodiff ✔
+### ✔ Python control_flow + autodiff ✔
If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager).
@@ -562,7 +558,7 @@ print(grad(f)(4.)) # ok!
+++ {"id": "hIfPT7WMmZ2H"}
-### python control flow + JIT
+### Python control flow + JIT
Using control flow with `jit` is more complicated, and by default it has more constraints.
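One hedged sketch of working within those constraints: mark the argument that drives Python control flow as static, so `jit` retraces for each distinct value (`use_square` is an invented parameter name):

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnums=(1,))
def f(x, use_square):
  if use_square:            # ordinary Python branch on a static argument
    return x ** 2
  return jnp.abs(x)

print(f(3.0, True))   # 9.0
print(f(3.0, False))  # 3.0  (a separate trace/compile for this value)
```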
@@ -627,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu
For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.
-To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
+To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.
By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.
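A small illustration of that abstraction level: the function can be evaluated on shape and dtype information alone, without any concrete data (a sketch using the public `jax.eval_shape` helper):

```python
import jax
import jax.numpy as jnp

out = jax.eval_shape(lambda x: jnp.sin(x) + 1.0,
                     jax.ShapeDtypeStruct((3,), jnp.float32))
print(out)   # ShapeDtypeStruct(shape=(3,), dtype=float32)
```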
@@ -805,7 +801,7 @@ def while_loop(cond_fun, body_fun, init_val):
:outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e
init_val = 0
-cond_fun = lambda x: x<10
+cond_fun = lambda x: x < 10
body_fun = lambda x: x+1
lax.while_loop(cond_fun, body_fun, init_val)
# --> array(10, dtype=int32)
@@ -865,7 +861,7 @@ $\ast$ = argument-value-independent loop condition - unrolls the loop
+++ {"id": "OxLsZUyRt_kF"}
-## 🔪 Dynamic Shapes
+## 🔪 Dynamic shapes
+++ {"id": "1tKXcAMduDR1"}
@@ -1130,7 +1126,7 @@ x.dtype # --> dtype('float64')
+++ {"id": "WAHjmL0E2XwO"}
-## 🔪 Miscellaneous Divergences from NumPy
+## 🔪 Miscellaneous divergences from NumPy
While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.
Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
index 3abb6d9cbaec..5c09a0a4f732 100644
--- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb
@@ -6,13 +6,11 @@
"id": "LqiaKasFjH82"
},
"source": [
- "# Custom derivative rules for JAX-transformable Python functions\n",
+ "# Custom derivative rules\n",
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
- "\n",
- "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n",
"\n",
"There are two ways to define differentiation rules in JAX:\n",
"\n",
@@ -30,7 +28,7 @@
"id": "9Fg3NFNY-2RY"
},
"source": [
- "## TL;DR"
+ "## Summary"
]
},
{
@@ -247,7 +245,6 @@
}
],
"source": [
- "import jax.numpy as jnp\n",
"\n",
"def log1pexp(x):\n",
" return jnp.log(1. + jnp.exp(x))\n",
@@ -984,7 +981,7 @@
" (a, x_star, x_star_bar),\n",
" x_star_bar))\n",
" return a_bar, jnp.zeros_like(x_star)\n",
- " \n",
+ "\n",
"def rev_iter(f, packed, u):\n",
" a, x_star, x_star_bar = packed\n",
" _, vjp_x = vjp(lambda x: f(a, x), x_star)\n",
@@ -1884,7 +1881,6 @@
}
],
"source": [
- "from jax import vjp\n",
"\n",
"y, f_vjp = vjp(f, 3.)\n",
"print(y)"
@@ -1983,7 +1979,7 @@
" return x, x\n",
"\n",
"def debug_bwd(x, g):\n",
- " import pdb; pdb.set_trace()\n",
+ " pdb.set_trace()\n",
" return g\n",
"\n",
"debug.defvjp(debug_fwd, debug_bwd)"
diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md
index ad577d55cd0d..8a9b931552d9 100644
--- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md
+++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
@@ -13,13 +13,11 @@ kernelspec:
+++ {"id": "LqiaKasFjH82"}
-# Custom derivative rules for JAX-transformable Python functions
+# Custom derivative rules
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
-
-*mattjj@ Mar 19 2020, last updated Oct 14 2020*
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)
There are two ways to define differentiation rules in JAX:
@@ -32,7 +30,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co
+++ {"id": "9Fg3NFNY-2RY"}
-## TL;DR
+## Summary
+++ {"id": "ZgMNRtXyWIW8"}
@@ -145,7 +143,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo
:id: 6lWbTvs40ET-
:outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e
-import jax.numpy as jnp
def log1pexp(x):
return jnp.log(1. + jnp.exp(x))
@@ -524,7 +521,7 @@ def fixed_point_rev(f, res, x_star_bar):
(a, x_star, x_star_bar),
x_star_bar))
return a_bar, jnp.zeros_like(x_star)
-
+
def rev_iter(f, packed, u):
a, x_star, x_star_bar = packed
_, vjp_x = vjp(lambda x: f(a, x), x_star)
@@ -965,7 +962,6 @@ print(grad(f)(3.))
:id: s1Pn_qCIODcF
:outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9
-from jax import vjp
y, f_vjp = vjp(f, 3.)
print(y)
@@ -1015,7 +1011,7 @@ def debug_fwd(x):
return x, x
def debug_bwd(x, g):
- import pdb; pdb.set_trace()
+ pdb.set_trace()
return g
debug.defvjp(debug_fwd, debug_bwd)
diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
index 2face1d4a0b2..32d332d9ac7e 100644
--- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
+++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb
@@ -17,22 +17,20 @@
"id": "pFtQjv4SzHRj"
},
"source": [
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n",
"\n",
"This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer."
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
"id": "FNxScTfq3vGF"
},
"outputs": [],
"source": [
- "import os\n",
"\n",
- "import functools\n",
"from typing import Optional\n",
"\n",
"import numpy as np\n",
@@ -52,7 +50,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
"metadata": {
"id": "IZMLqOUV3vGG"
},
@@ -70,7 +68,7 @@
"source": [
"## Intro and a quick example\n",
"\n",
- "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified\n",
+ "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified \n",
"datatype for representing arrays, even with physical storage spanning multiple\n",
"devices. You'll also learn about how using `jax.Array`s together with `jax.jit`\n",
"can provide automatic compiler-based parallelization.\n",
@@ -81,57 +79,76 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"metadata": {
"id": "Gf2lO4ii3vGG"
},
"outputs": [],
"source": [
"from jax.experimental import mesh_utils\n",
- "from jax.sharding import PositionalSharding"
+ "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"metadata": {
"id": "q-XBTEoy3vGG"
},
"outputs": [],
"source": [
"# Create a Sharding object to distribute a value across devices:\n",
- "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))"
+ "mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n",
+ " axis_names=('x', 'y'))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 166
+ },
"id": "vI39znW93vGH",
- "outputId": "3b518df8-5c29-4848-acc3-e41df939f30b"
+ "outputId": "4f702753-8add-4b65-a4af-0f18f098cc46"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
"# Create an array of random values:\n",
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
"# and use jax.device_put to distribute it across devices:\n",
- "y = jax.device_put(x, sharding.reshape(4, 2))\n",
+ "y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))\n",
"jax.debug.visualize_array_sharding(y)"
]
},
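A hedged recap of the `NamedSharding` pattern the updated cells use, including replication by omitting a mesh axis from the `PartitionSpec` (assumes eight local devices; the axis names 'x' and 'y' are just labels):

```python
import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), axis_names=('x', 'y'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))

# Shard rows over 'x' and columns over 'y' ...
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
# ... or replicate across 'y' by leaving that mesh axis out of the spec.
y_rep = jax.device_put(x, NamedSharding(mesh, P('x', None)))

jax.debug.visualize_array_sharding(y)
jax.debug.visualize_array_sharding(y_rep)
```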
@@ -147,26 +164,44 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 7,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 166
+ },
"id": "-qCnHZl83vGI",
- "outputId": "9da9c29e-ce88-4425-e1ec-e93e5bcf3106"
+ "outputId": "0e131c23-5765-43ae-f232-6417ae1acbb2"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -186,18 +221,21 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 8,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "_VTzN0r03vGI",
- "outputId": "c9208010-984b-442b-d105-c8c6a3a010e6"
+ "outputId": "c03eecab-4c86-4dac-d776-5fc72cbb5273"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached \n",
- "5 loops, best of 5: 9.69 ms per loop\n"
+ "The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
+ "25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
]
}
],
@@ -208,17 +246,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 9,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "QuzhU1g63vGI",
- "outputId": "d48fc76e-79a7-47b9-d392-b18a1c33c798"
+ "outputId": "8135cca0-871b-4b6a-a7e5-02e78c2028c7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "5 loops, best of 5: 1.86 ms per loop\n"
+ "2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
]
}
],
@@ -245,7 +286,7 @@
"id": "W6HsXauGxL6w"
},
"source": [
- "### Sharding basics, and the `PositionalSharding` subclass"
+ "### Sharding basics, and the `NamedSharding` subclass"
]
},
{
@@ -263,7 +304,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {
"id": "VmoX4SUp3vGJ"
},
@@ -275,511 +316,109 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "vNRabO2J3vGJ",
- "outputId": "73db7b6e-c2e7-467d-a0ef-c35e29e582dd"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ TPU 0 │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "└───────────────────────┘\n"
- ]
- }
- ],
- "source": [
- "jax.debug.visualize_array_sharding(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HhCjhK0zXIqX"
- },
- "source": [
- "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n",
- "\n",
- "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "VUIEIzRp3vGK"
- },
- "outputs": [],
- "source": [
- "from jax.experimental import mesh_utils\n",
- "devices = mesh_utils.create_device_mesh((8,))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "lbOKFWmBX1iv"
- },
- "source": [
- "Then, we create a `PositionalSharding` and use it with `device_put`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "jwrWfZeB3vGK",
- "outputId": "e6f126bd-f6bd-48c7-c130-6f02757e3342"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ TPU 0 │\n",
- "├───────────────────────┤\n",
- "│ TPU 1 │\n",
- "├───────────────────────┤\n",
- "│ TPU 2 │\n",
- "├───────────────────────┤\n",
- "│ TPU 3 │\n",
- "├───────────────────────┤\n",
- "│ TPU 6 │\n",
- "├───────────────────────┤\n",
- "│ TPU 7 │\n",
- "├───────────────────────┤\n",
- "│ TPU 4 │\n",
- "├───────────────────────┤\n",
- "│ TPU 5 │\n",
- "└───────────────────────┘\n"
- ]
- }
- ],
- "source": [
- "from jax.sharding import PositionalSharding\n",
- "\n",
- "sharding = PositionalSharding(devices)\n",
- "\n",
- "x = jax.device_put(x, sharding.reshape(8, 1))\n",
- "jax.debug.visualize_array_sharding(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "TUu69IWXZdTm"
- },
- "source": [
- "Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zxWB82Kz3vGK",
- "outputId": "11384a6b-fabc-4c4c-bcad-a3be51eb0465"
+ "outputId": "40fd7172-a16c-4dd8-e2e1-17bb3afe5409"
},
"outputs": [
{
"data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ TPU 0 │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
"text/plain": [
- "PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])"
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n"
]
},
- "execution_count": 13,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "sharding"
+ "jax.debug.visualize_array_sharding(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
- "id": "uRLpOcmNj_Vt"
+ "id": "HhCjhK0zXIqX"
},
"source": [
- "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n",
+ "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n",
"\n",
- "By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "PLsnpSzc3vGL",
- "outputId": "9f4db733-cafe-46ae-c057-dc31046a6f66"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "PositionalSharding([[{TPU 0}]\n",
- " [{TPU 1}]\n",
- " [{TPU 2}]\n",
- " [{TPU 3}]\n",
- " [{TPU 6}]\n",
- " [{TPU 7}]\n",
- " [{TPU 4}]\n",
- " [{TPU 5}]])"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "sharding.reshape(8, 1)"
+ "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 16,
"metadata": {
- "id": "iqKdI4LO3vGL",
- "outputId": "6aa10fc2-cec4-4401-a0df-343e71646e0a"
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 166
+ },
+ "id": "zpB1JxyK3vGN",
+ "outputId": "8e385462-1c2c-4256-c38a-84299d3bd02c"
},
"outputs": [
{
"data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
"text/plain": [
- "PositionalSharding([[{TPU 0} {TPU 1}]\n",
- " [{TPU 2} {TPU 3}]\n",
- " [{TPU 6} {TPU 7}]\n",
- " [{TPU 4} {TPU 5}]])"
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
]
},
- "execution_count": 15,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "sharding.reshape(4, 2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "KBu6WLfhm7ra"
- },
- "source": [
- "To use `device_put` with a data array `x`, we can reshape the `sharding` into a shape that is _congruent_ with `x.shape`, meaning a shape with the same length as `x.shape` and where each element evenly divides the corresponding element of `x.shape`:\n",
- "```python\n",
- "def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:\n",
- " return (len(x_shape) == len(sharding_shape) and\n",
- " all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))\n",
- "```\n",
- "\n",
- "For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "SELr4xNi3vGL",
- "outputId": "b2f4acec-0cd3-4829-ca16-cae2e0e8ca60"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "PositionalSharding([[{TPU 0} {TPU 1}]\n",
- " [{TPU 2} {TPU 3}]\n",
- " [{TPU 6} {TPU 7}]\n",
- " [{TPU 4} {TPU 5}]])\n"
- ]
- }
- ],
- "source": [
- "sharding = sharding.reshape(4, 2)\n",
- "print(sharding)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "8IVIsqfX3vGL",
- "outputId": "033d0e02-a643-4f4c-9d24-9cd8465bc69a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
- ]
- }
- ],
- "source": [
- "y = jax.device_put(x, sharding)\n",
- "jax.debug.visualize_array_sharding(y)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tyg9F-UIsU__"
- },
- "source": [
- "Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories.\n",
- "\n",
- "Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "cCjt6QCz3vGM",
- "outputId": "4ad8a611-596d-424f-b6c5-fc00f1adc306"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])\n"
- ]
- }
- ],
- "source": [
- "sharding = sharding.reshape(1, 8)\n",
- "print(sharding)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "yTK4Nz3u3vGM",
- "outputId": "e445c6bc-4fe3-4e9d-cc9e-d82858f58312"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "│ │ │ │ │ │ │ │ │\n",
- "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
- ]
- }
- ],
- "source": [
- "y = jax.device_put(x, sharding)\n",
- "jax.debug.visualize_array_sharding(y)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "0PuamOvXubcf"
- },
- "source": [
- "In some cases, we don't just want to store each slice of `x` in a single device's memory; we might want to _replicate_ some slices, meaning storing copies of a slice's values in multiple devices' memories.\n",
- "\n",
- "With `PositionalSharding`, we can express replication by calling the reducer method `replicate`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "_jr6XYKx3vGM",
- "outputId": "59c8b9a4-b8af-493a-ba8d-da5931e88f93"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])\n"
- ]
- }
- ],
- "source": [
- "sharding = sharding.reshape(4, 2)\n",
- "print(sharding.replicate(axis=0, keepdims=True))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "S5vzjFuH3vGN",
- "outputId": "b6ce2675-7261-4e57-fa8c-b4e87abf7e52"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────┬───────────┐\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "└───────────┴───────────┘\n"
- ]
- }
- ],
- "source": [
- "y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))\n",
- "jax.debug.visualize_array_sharding(y)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "FzeP0kpTvJv-"
- },
- "source": [
- "Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).\n",
- "\n",
- "The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "DR7VV-6e3vGN",
- "outputId": "f879fc2c-5723-4199-b306-295bc1b3681e"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(1, 2)\n",
- "(4, 1)\n"
- ]
- }
- ],
- "source": [
- "print(sharding.replicate(0).shape)\n",
- "print(sharding.replicate(1).shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "agUtVUVx3vGN",
- "outputId": "0e9789ef-ce52-4ed6-8bd5-c876b95f66e6"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ TPU 0,1 │\n",
- "├───────────────────────┤\n",
- "│ TPU 2,3 │\n",
- "├───────────────────────┤\n",
- "│ TPU 6,7 │\n",
- "├───────────────────────┤\n",
- "│ TPU 4,5 │\n",
- "└───────────────────────┘\n"
- ]
- }
- ],
- "source": [
- "y = jax.device_put(x, sharding.replicate(1))\n",
- "jax.debug.visualize_array_sharding(y)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "D31t5POXxHHJ"
- },
- "source": [
- "### `NamedSharding` gives a way to express shardings with names"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ayMKWeTmxl-X"
- },
- "source": [
- "So far we've worked with `PositionalSharding`, but there are alternative ways to express shardings. In fact, `Sharding` is an interface, and any class that implements that interface can be used with functions like `device_put`.\n",
- "\n",
- "Another convenient way to express sharding is with the `NamedSharding`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zpB1JxyK3vGN",
- "outputId": "46d5da37-840c-49d8-8380-a162811bae8a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
- ]
- }
- ],
- "source": [
- "from jax.sharding import Mesh\n",
- "from jax.sharding import PartitionSpec\n",
- "from jax.sharding import NamedSharding\n",
+ "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n",
"from jax.experimental import mesh_utils\n",
"\n",
"P = PartitionSpec\n",
@@ -801,7 +440,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 17,
"metadata": {
"id": "8g0Md2Gd3vGO"
},
@@ -820,26 +459,44 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 18,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 166
+ },
"id": "zp3MfS4Y3vGO",
- "outputId": "2c2f7201-c2c1-49e5-f8a5-0730c124d89a"
+ "outputId": "032fdd7e-19a1-45da-e1ad-b3227fa43ee6"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -858,28 +515,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 19,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "FigK5Zsa3vGO",
- "outputId": "eca784e8-33fe-4e9b-a41d-21e9ee781a35"
+ "outputId": "e488d073-9d02-4376-a6af-19d6d5509c7d"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────┬───────┬───────┬───────┐\n",
- "│ │ │ │ │\n",
- "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "├───────┼───────┼───────┼───────┤\n",
- "│ │ │ │ │\n",
- "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "└───────┴───────┴───────┴───────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────┬───────┬───────┬───────┐\n",
+ "│ │ │ │ │\n",
+ "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "├───────┼───────┼───────┼───────┤\n",
+ "│ │ │ │ │\n",
+ "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "└───────┴───────┴───────┴───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┬───────┬───────┬───────┐\n",
+ "│ │ │ │ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "├───────┼───────┼───────┼───────┤\n",
+ "│ │ │ │ │\n",
+ "│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "└───────┴───────┴───────┴───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -889,26 +566,44 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 20,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 166
+ },
"id": "hI-HD0xN3vGO",
- "outputId": "c3e7dc3c-4048-448a-ef0b-50683532fcdc"
+ "outputId": "b0c2e863-3aee-4417-b45f-21b2187f6ef7"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ TPU 0,1 │\n",
- "├───────────────────────┤\n",
- "│ TPU 2,3 │\n",
- "├───────────────────────┤\n",
- "│ TPU 6,7 │\n",
- "├───────────────────────┤\n",
- "│ TPU 4,5 │\n",
- "└───────────────────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ TPU 0,1 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 2,3 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 6,7 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 4,5 │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -932,28 +627,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "EXBExMQC3vGP",
- "outputId": "fe1c8d7e-3345-4438-b9d2-780e7854b4eb"
+ "outputId": "c80e6177-12a6-40ef-b4e4-934dad22da3d"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────┬───────────┐\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "└───────────┴───────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -963,28 +678,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 22,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "PjUpG8uz3vGP",
- "outputId": "64d8224d-15d9-4ad4-d613-f7f85b1dc1af"
+ "outputId": "a0f59dc5-b509-4b8b-bd22-bcd69f696763"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────┬───────┬───────┬───────┐\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "└───────┴───────┴───────┴───────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────┬───────┬───────┬───────┐\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "└───────┴───────┴───────┴───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┬───────┬───────┬───────┐\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "└───────┴───────┴───────┴───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1003,34 +738,60 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 23,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 298
+ },
"id": "fVcPbDUA3vGP",
- "outputId": "7f524ba5-a6d8-4490-cda9-685ad11416f9"
+ "outputId": "da3f435d-dfc1-4a41-ec90-691cd7c748a0"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ TPU 0 │\n",
- "├───────────────────────┤\n",
- "│ TPU 1 │\n",
- "├───────────────────────┤\n",
- "│ TPU 2 │\n",
- "├───────────────────────┤\n",
- "│ TPU 3 │\n",
- "├───────────────────────┤\n",
- "│ TPU 6 │\n",
- "├───────────────────────┤\n",
- "│ TPU 7 │\n",
- "├───────────────────────┤\n",
- "│ TPU 4 │\n",
- "├───────────────────────┤\n",
- "│ TPU 5 │\n",
- "└───────────────────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ TPU 0 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 1 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 2 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 3 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 6 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 7 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 4 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 5 │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1069,54 +830,103 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 24,
"metadata": {
"id": "_EmQwggc3vGQ"
},
"outputs": [],
"source": [
- "from jax.experimental import mesh_utils\n",
- "from jax.sharding import PositionalSharding\n",
- "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))"
+ "devices = mesh_utils.create_device_mesh((4, 2))\n",
+ "mesh = Mesh(devices, axis_names=('a', 'b'))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 25,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 349
+ },
"id": "LnT0vWjc3vGQ",
- "outputId": "8089effc-aa4c-49e3-dd19-7064881dbad0"
+ "outputId": "8e642049-61eb-458d-af79-ac449b58d11b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "input sharding:\n",
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n",
- "output sharding:\n",
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
+ "input sharding:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "output sharding:\n"
]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
- "x = jax.device_put(x, sharding.reshape(4, 2))\n",
+ "x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n",
"print('input sharding:')\n",
"jax.debug.visualize_array_sharding(x)\n",
"\n",
@@ -1140,54 +950,132 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 26,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 548
+ },
"id": "Dq043GkP3vGQ",
- "outputId": "350219a8-1e4a-4404-fe14-50f97ea3e7ba"
+ "outputId": "3eff7b67-d7f0-4212-c9d3-2cc271ac1f98"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "lhs sharding:\n",
- "┌───────────────────────┐\n",
- "│ TPU 0,1 │\n",
- "├───────────────────────┤\n",
- "│ TPU 2,3 │\n",
- "├───────────────────────┤\n",
- "│ TPU 6,7 │\n",
- "├───────────────────────┤\n",
- "│ TPU 4,5 │\n",
- "└───────────────────────┘\n",
- "rhs sharding:\n",
- "┌───────────┬───────────┐\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "└───────────┴───────────┘\n",
- "out sharding:\n",
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
+ "lhs sharding:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ TPU 0,1 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 2,3 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 6,7 │\n",
+ "├───────────────────────┤\n",
+ "│ TPU 4,5 │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
+ "├───────────────────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rhs sharding:\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "out sharding:\n"
]
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
- "y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))\n",
- "z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))\n",
+ "y = jax.device_put(x, NamedSharding(mesh, P('a', None)))\n",
+ "z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))\n",
"print('lhs sharding:')\n",
"jax.debug.visualize_array_sharding(y)\n",
"print('rhs sharding:')\n",
@@ -1211,28 +1099,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 27,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "QjQ5u8qh3vGQ",
- "outputId": "bd29edcd-b87c-486e-c568-906f06ae16be"
+ "outputId": "0aefc170-833c-4a6a-e003-5990d3db31d9"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ TPU 0 │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "└───────────────────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ TPU 0 │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1242,10 +1150,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 28,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "8tn8lOj73vGR",
- "outputId": "5809b3c8-7333-4cd3-db97-a7aede943dce"
+ "outputId": "d9898c93-7afc-416b-8c40-4d9551613cd0"
},
"outputs": [
{
@@ -1254,7 +1165,7 @@
"True"
]
},
- "execution_count": 36,
+ "execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
@@ -1266,17 +1177,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 29,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "D7PpZwhR3vGR",
- "outputId": "4f0bd43d-0b32-4089-d3da-c8f1449e3526"
+ "outputId": "4901a11b-2354-4d26-a897-b88def07a716"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "5 loops, best of 5: 19.3 ms per loop\n"
+ "49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
]
}
],
@@ -1286,17 +1200,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 30,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "rgo_yVHF3vGR",
- "outputId": "97f19052-f1c9-4d30-f453-07b3a7208aa9"
+ "outputId": "e51216cf-b073-4250-d422-67f9fd72f6aa"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "5 loops, best of 5: 3.25 ms per loop\n"
+ "7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
]
}
],
@@ -1315,26 +1232,44 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 31,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 166
+ },
"id": "f1Zw-2lH3vGR",
- "outputId": "a796bed4-07b0-497d-8fd8-31a22ab9762e"
+ "outputId": "43d7a642-fde4-47a6-901f-dfdc64d6a613"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1365,7 +1300,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 94,
"metadata": {
"id": "1vAkZAOY3vGR"
},
@@ -1375,54 +1310,63 @@
"from termcolor import colored\n",
"\n",
"def print_exception(e):\n",
- " name = colored(f'{type(e).__name__}', 'red')\n",
+ " name = colored(f'{type(e).__name__}', 'red', force_color=True)\n",
" print(textwrap.fill(f'{name}: {str(e)}'))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 95,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "DHh0N3vn3vGS",
- "outputId": "e7741882-0ebf-4237-e5d1-e48c9b9c178f"
+ "outputId": "8c4652f7-c484-423b-ad78-182134280187"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[31mValueError\u001b[0m: Devices of all `Array` inputs and outputs should\n",
- "be the same. Got array device ids [0, 1, 2, 3] on platform TPU and\n",
- "another array's device ids [4, 5, 6, 7] on platform TPU\n"
+ "\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n",
+ "computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n",
+ "device ids [0, 1, 2, 3] on platform TPU and argument x2 of\n",
+ "jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on\n",
+ "platform TPU\n"
]
}
],
"source": [
- "sharding1 = PositionalSharding(jax.devices()[:4])\n",
- "sharding2 = PositionalSharding(jax.devices()[4:])\n",
+ "sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))\n",
+ "sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))\n",
"\n",
- "y = jax.device_put(x, sharding1.reshape(2, 2))\n",
- "z = jax.device_put(x, sharding2.reshape(2, 2))\n",
+ "y = jax.device_put(x, sharding1)\n",
+ "z = jax.device_put(x, sharding2)\n",
"try: y + z\n",
"except ValueError as e: print_exception(e)"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 96,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "Im7DkoOl3vGS",
- "outputId": "3adfe1cb-db52-4a9d-e98e-62c6455c3100"
+ "outputId": "1b6fcd7a-762b-4366-a96d-aea63bad7fe0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "\u001b[31mValueError\u001b[0m: Devices of all `Array` inputs and outputs should\n",
- "be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform\n",
- "TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on\n",
- "platform TPU\n"
+ "\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n",
+ "computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n",
+ "device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of\n",
+ "jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,\n",
+ "4, 5] on platform TPU\n"
]
}
],
@@ -1430,11 +1374,11 @@
"devices = jax.devices()\n",
"permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]\n",
"\n",
- "sharding1 = PositionalSharding(devices)\n",
- "sharding2 = PositionalSharding(permuted_devices)\n",
+ "sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))\n",
+ "sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))\n",
"\n",
- "y = jax.device_put(x, sharding1.reshape(4, 2))\n",
- "z = jax.device_put(x, sharding2.reshape(4, 2))\n",
+ "y = jax.device_put(x, sharding1)\n",
+ "z = jax.device_put(x, sharding2)\n",
"try: y + z\n",
"except ValueError as e: print_exception(e)"
]
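If you do hit this error, one way out is to reshard one operand onto the other operand's sharding before combining them. A sketch that repeats the setup of the failing cell (the `arange(24)` input is an assumption consistent with the `int32[24]` shape in the error message):

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P, NamedSharding

x = jnp.arange(24)
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
y = jax.device_put(x, NamedSharding(Mesh(devices, 'x'), P('x')))
z = jax.device_put(x, NamedSharding(Mesh(permuted_devices, 'x'), P('x')))

z_on_y = jax.device_put(z, y.sharding)  # move z onto y's sharding explicitly
w = y + z_on_y                          # both operands now agree on devices
```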
@@ -1455,10 +1399,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 40,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "_QvtKL8r3vGS",
- "outputId": "e0078805-bdfd-436e-f94f-7cd256d2574f"
+ "outputId": "761b1208-fe4b-4c09-a7d2-f62152183ef0"
},
"outputs": [
{
@@ -1470,7 +1417,7 @@
}
],
"source": [
- "y = jax.device_put(x, sharding1.reshape(4, 2))\n",
+ "y = jax.device_put(x, sharding1)\n",
"y + jnp.ones_like(y)\n",
"y + jnp.arange(y.size).reshape(y.shape)\n",
"print('no error!')"
@@ -1496,30 +1443,30 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 41,
"metadata": {
"id": "jniSFm5V3vGT"
},
"outputs": [],
"source": [
- "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))"
+ "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 42,
"metadata": {
"id": "Q1wuDp-L3vGT"
},
"outputs": [],
"source": [
"x = jax.random.normal(jax.random.key(0), (8192, 8192))\n",
- "x = jax.device_put(x, sharding.reshape(4, 2))"
+ "x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 44,
"metadata": {
"id": "rqEDj0wB3vGT"
},
@@ -1528,43 +1475,83 @@
"@jax.jit\n",
"def f(x):\n",
" x = x + 1\n",
- " y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))\n",
+ " y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))\n",
" return y"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 45,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 347
+ },
"id": "zYFS-n4r3vGT",
- "outputId": "d23a7938-cb7d-44b4-b9c7-83edf1d1145e"
+ "outputId": "0ac96b8f-ed23-4413-aed9-edd00a841c37"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n",
- "┌───────┬───────┬───────┬───────┐\n",
- "│ │ │ │ │\n",
- "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "├───────┼───────┼───────┼───────┤\n",
- "│ │ │ │ │\n",
- "│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │\n",
- "│ │ │ │ │\n",
- "│ │ │ │ │\n",
- "└───────┴───────┴───────┴───────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌───────┬───────┬───────┬───────┐\n",
+ "│ │ │ │ │\n",
+ "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "├───────┼───────┼───────┼───────┤\n",
+ "│ │ │ │ │\n",
+ "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "└───────┴───────┴───────┴───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┬───────┬───────┬───────┐\n",
+ "│ │ │ │ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "├───────┼───────┼───────┼───────┤\n",
+ "│ │ │ │ │\n",
+ "│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "│ │ │ │ │\n",
+ "│ │ │ │ │\n",
+ "└───────┴───────┴───────┴───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1575,7 +1562,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 46,
"metadata": {
"id": "8g_2Y8wp3vGT"
},
@@ -1584,43 +1571,83 @@
"@jax.jit\n",
"def f(x):\n",
" x = x + 1\n",
- " y = jax.lax.with_sharding_constraint(x, sharding.replicate())\n",
+ " y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))\n",
" return y"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 47,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 347
+ },
"id": "AiRFtVsR3vGT",
- "outputId": "f3e28a70-46cf-46fb-c801-82f0ddb447e4"
+ "outputId": "2edacc2c-ac80-4519-c9d1-bee364a22b31"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌──────────┬──────────┐\n",
- "│ TPU 0 │ TPU 1 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 2 │ TPU 3 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 6 │ TPU 7 │\n",
- "├──────────┼──────────┤\n",
- "│ TPU 4 │ TPU 5 │\n",
- "└──────────┴──────────┘\n",
- "┌───────────────────────┐\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ TPU 0,1,2,3,4,5,6,7 │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "│ │\n",
- "└───────────────────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU 0 │ TPU 1 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 2 │ TPU 3 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 6 │ TPU 7 │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU 4 │ TPU 5 │\n",
+ "└──────────┴──────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌──────────┬──────────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "├──────────┼──────────┤\n",
+ "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n",
+ "└──────────┴──────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ TPU 0,1,2,3,4,5,6,7 │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1669,7 +1696,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 48,
"metadata": {
"id": "mEKF3zIF3vGU"
},
@@ -1681,7 +1708,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 49,
"metadata": {
"id": "Mocs3oGe3vGU"
},
@@ -1701,7 +1728,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 50,
"metadata": {
"id": "glBB8tzW3vGU"
},
@@ -1713,27 +1740,27 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 51,
"metadata": {
"id": "R0x62AIa3vGU"
},
"outputs": [],
"source": [
"def init_layer(key, n_in, n_out):\n",
- " k1, k2 = jax.random.split(key)\n",
- " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n",
- " b = jax.random.normal(k2, (n_out,))\n",
- " return W, b\n",
+ " k1, k2 = jax.random.split(key)\n",
+ " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n",
+ " b = jax.random.normal(k2, (n_out,))\n",
+ " return W, b\n",
"\n",
"def init_model(key, layer_sizes, batch_size):\n",
- " key, *keys = jax.random.split(key, len(layer_sizes))\n",
- " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
+ " key, *keys = jax.random.split(key, len(layer_sizes))\n",
+ " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
"\n",
- " key, *keys = jax.random.split(key, 3)\n",
- " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n",
- " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n",
+ " key, *keys = jax.random.split(key, 3)\n",
+ " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n",
+ " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n",
"\n",
- " return params, (inputs, targets)\n",
+ " return params, (inputs, targets)\n",
"\n",
"layer_sizes = [784, 8192, 8192, 8192, 10]\n",
"batch_size = 8192\n",
@@ -1752,33 +1779,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 52,
+ "metadata": {
+ "id": "mJLqRPpSDX0i"
+ },
+ "outputs": [],
+ "source": [
+ "mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
"metadata": {
"id": "_Q5NbdOn3vGV"
},
"outputs": [],
"source": [
- "sharding = PositionalSharding(jax.devices()).reshape(8, 1)"
+ "sharding = NamedSharding(mesh, P('batch'))\n",
+ "replicated_sharding = NamedSharding(mesh, P())"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 55,
"metadata": {
"id": "3KC6ieEe3vGV"
},
"outputs": [],
"source": [
"batch = jax.device_put(batch, sharding)\n",
- "params = jax.device_put(params, sharding.replicate())"
+ "params = jax.device_put(params, replicated_sharding)"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 56,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "MUb-QE2b3vGV",
- "outputId": "1f831ea5-5a30-49ad-8195-977ff7ed476a"
+ "outputId": "5a27f007-c572-44f8-9f49-6e745ee739e8"
},
"outputs": [
{
@@ -1787,7 +1829,7 @@
"Array(23.469475, dtype=float32)"
]
},
- "execution_count": 57,
+ "execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
@@ -1798,17 +1840,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 57,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "HUkw0u413vGV",
- "outputId": "dfa2599c-9440-4657-9035-0dc3bbf625e1"
+ "outputId": "07e481a1-97fb-4bd0-d754-cb6d8317bff6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "10.760101\n"
+ "10.760109\n"
]
}
],
@@ -1825,17 +1870,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 58,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "paCw6Zaj3vGV",
- "outputId": "8ab1c32c-f2b1-465c-df71-f5a599e7f19e"
+ "outputId": "ad4cce34-3a6a-4d44-9a86-477a7fee4841"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "5 loops, best of 5: 26.3 ms per loop\n"
+ "53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
]
}
],
@@ -1845,7 +1893,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 59,
"metadata": {
"id": "BF86UWpg3vGV"
},
@@ -1857,17 +1905,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 60,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "Z1wgUKXk3vGV",
- "outputId": "74df8892-c349-41dc-cb1b-e0843ec5c994"
+ "outputId": "d66767b7-3f17-482f-b811-919bb1793277"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "5 loops, best of 5: 122 ms per loop\n"
+ "351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n"
]
}
],
@@ -1886,50 +1937,88 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 61,
"metadata": {
- "id": "N5-zzgW03vGW"
+ "id": "k1hxOfgRDwo0"
},
"outputs": [],
"source": [
- "sharding = sharding.reshape(4, 2)"
+ "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 62,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 314
+ },
"id": "sgIWCjJK3vGW",
- "outputId": "b2fdc556-05cc-4e68-fa04-48643d194dee"
+ "outputId": "8cb0f19f-3942-415c-c57a-31bb81784f46"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────┐\n",
- "│TPU 0,1│\n",
- "├───────┤\n",
- "│TPU 2,3│\n",
- "├───────┤\n",
- "│TPU 4,5│\n",
- "├───────┤\n",
- "│TPU 6,7│\n",
- "└───────┘\n",
- "┌───────┐\n",
- "│TPU 0,1│\n",
- "├───────┤\n",
- "│TPU 2,3│\n",
- "├───────┤\n",
- "│TPU 4,5│\n",
- "├───────┤\n",
- "│TPU 6,7│\n",
- "└───────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────┐\n",
+ "│TPU 0,1│\n",
+ "├───────┤\n",
+ "│TPU 2,3│\n",
+ "├───────┤\n",
+ "│TPU 6,7│\n",
+ "├───────┤\n",
+ "│TPU 4,5│\n",
+ "└───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┐\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n",
+ "├───────┤\n",
+ "│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n",
+ "├───────┤\n",
+ "│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
+ "├───────┤\n",
+ "│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n",
+ "└───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌───────┐\n",
+ "│TPU 0,1│\n",
+ "├───────┤\n",
+ "│TPU 2,3│\n",
+ "├───────┤\n",
+ "│TPU 6,7│\n",
+ "├───────┤\n",
+ "│TPU 4,5│\n",
+ "└───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┐\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n",
+ "├───────┤\n",
+ "│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n",
+ "├───────┤\n",
+ "│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
+ "├───────┤\n",
+ "│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n",
+ "└───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
- "batch = jax.device_put(batch, sharding.replicate(1))\n",
+ "batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))\n",
"jax.debug.visualize_array_sharding(batch[0])\n",
"jax.debug.visualize_array_sharding(batch[1])"
]
@@ -1937,6 +2026,17 @@
{
"cell_type": "code",
"execution_count": null,
+ "metadata": {
+ "id": "q9PQP-0eEAO6"
+ },
+ "outputs": [],
+ "source": [
+ "replicated_sharding = NamedSharding(mesh, P())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 67,
"metadata": {
"id": "BqCjYCgg3vGW"
},
@@ -1944,45 +2044,65 @@
"source": [
"(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params\n",
"\n",
- "W1 = jax.device_put(W1, sharding.replicate())\n",
- "b1 = jax.device_put(b1, sharding.replicate())\n",
+ "W1 = jax.device_put(W1, replicated_sharding)\n",
+ "b1 = jax.device_put(b1, replicated_sharding)\n",
"\n",
- "W2 = jax.device_put(W2, sharding.replicate(0))\n",
- "b2 = jax.device_put(b2, sharding.replicate(0))\n",
+ "W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))\n",
+ "b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))\n",
"\n",
- "W3 = jax.device_put(W3, sharding.replicate(0).T)\n",
- "b3 = jax.device_put(b3, sharding.replicate())\n",
+ "W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))\n",
+ "b3 = jax.device_put(b3, replicated_sharding)\n",
"\n",
- "W4 = jax.device_put(W4, sharding.replicate())\n",
- "b4 = jax.device_put(b4, sharding.replicate())\n",
+ "W4 = jax.device_put(W4, replicated_sharding)\n",
+ "b4 = jax.device_put(b4, replicated_sharding)\n",
"\n",
"params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 68,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "_lSJ63sh3vGW",
- "outputId": "5b37aa8b-3226-4805-8282-876e8d06edda"
+ "outputId": "bcd3e33e-36b5-4787-9cd2-60623fd6e5fa"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────┬───────────┐\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "└───────────┴───────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -1991,28 +2111,48 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 69,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 199
+ },
"id": "fxkfWYkk3vGW",
- "outputId": "8a1063c3-540b-47c1-d990-a6845da861f7"
+ "outputId": "59e60b16-fe37-47d4-8214-96096ffbd79c"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────────────────┐\n",
- "│ │\n",
- "│ TPU 0,2,4,6 │\n",
- "│ │\n",
- "│ │\n",
- "├───────────────────────┤\n",
- "│ │\n",
- "│ TPU 1,3,5,7 │\n",
- "│ │\n",
- "│ │\n",
- "└───────────────────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ TPU 0,2,4,6 │\n",
+ "│ │\n",
+ "│ │\n",
+ "├───────────────────────┤\n",
+ "│ │\n",
+ "│ TPU 1,3,5,7 │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "├───────────────────────┤\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -2021,17 +2161,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 70,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "uPCVs-_k3vGW",
- "outputId": "de01cdfc-36cb-4823-c692-22c692ef4220"
+ "outputId": "618516e9-9736-4ca0-dd22-09d094ce57a2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "10.760103\n"
+ "10.760109\n"
]
}
],
@@ -2041,7 +2184,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 71,
"metadata": {
"id": "L9JebLK_3vGW"
},
@@ -2057,17 +2200,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 72,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "c9Sbl69e3vGX",
- "outputId": "8272c5fa-e59f-4953-c2d5-658c42a28712"
+ "outputId": "2ee3d432-7172-46ca-e01a-614e83345808"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "10.752466\n"
+ "10.752513\n"
]
}
],
@@ -2077,39 +2223,81 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 73,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 380
+ },
"id": "lkAF0dAb3vGX",
- "outputId": "acf0df31-c5e1-4683-b73f-b0cd1b0929f8"
+ "outputId": "6c1e317e-cded-4af4-8080-0de835fa4c71"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────────┬───────────┐\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "│ │ │\n",
- "└───────────┴───────────┘\n",
- "┌───────────────────────┐\n",
- "│ │\n",
- "│ TPU 0,2,4,6 │\n",
- "│ │\n",
- "│ │\n",
- "├───────────────────────┤\n",
- "│ │\n",
- "│ TPU 1,3,5,7 │\n",
- "│ │\n",
- "│ │\n",
- "└───────────────────────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────┬───────────┐\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "│ │ │\n",
+ "└───────────┴───────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/html": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ TPU 0,2,4,6 │\n",
+ "│ │\n",
+ "│ │\n",
+ "├───────────────────────┤\n",
+ "│ │\n",
+ "│ TPU 1,3,5,7 │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────────────────────┐\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "├───────────────────────┤\n",
+ "│ │\n",
+ "│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n",
+ "│ │\n",
+ "│ │\n",
+ "└───────────────────────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -2120,17 +2308,20 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 74,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "I1Npor3i3vGX",
- "outputId": "4099f6dd-7b46-4123-c1cb-5173c3d3278e"
+ "outputId": "479c4d81-cb0b-40a5-89ba-394c10dc3297"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "10 loops, best of 10: 30.5 ms per loop\n"
+ "51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)\n"
]
}
],
@@ -2173,7 +2364,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 75,
"metadata": {
"id": "kwS-aQE_3vGX"
},
@@ -2185,7 +2376,8 @@
" return x + numbers\n",
"\n",
"key = jax.random.key(42)\n",
- "x_sharding = jax.sharding.PositionalSharding(jax.devices())\n",
+ "mesh = Mesh(jax.devices(), 'x')\n",
+ "x_sharding = NamedSharding(mesh, P('x'))\n",
"x = jax.device_put(jnp.arange(24), x_sharding)"
]
},
@@ -2200,20 +2392,32 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 76,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 67
+ },
"id": "Oi97rpLz3vGY",
- "outputId": "204a7e8d-dc88-4b77-b7e3-0e72f306c5d3"
+ "outputId": "9dd63254-a483-4847-c0f5-5a4367bf08e9"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
- "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
- "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+ "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
+ "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -2231,10 +2435,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 77,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "64wIZuSJ3vGY",
- "outputId": "1054fe99-0476-44ec-9693-b0d8f98bf6a8"
+ "outputId": "fa166d45-ca9c-457a-be84-bcc9236d0730"
},
"outputs": [
{
@@ -2261,10 +2468,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 78,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "1I7bqxA63vGY",
- "outputId": "ec4c579d-f446-4b48-ceda-785c09ba299b"
+ "outputId": "756e0a36-ff14-438f-bbd4-3ef03f97a47b"
},
"outputs": [
{
@@ -2292,20 +2502,32 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 79,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 67
+ },
"id": "zHPJzdn23vGY",
- "outputId": "a8904d20-4d04-4f59-8eae-281e47d29246"
+ "outputId": "3332de0f-4827-4f0b-b9ef-69249b7c6bc6"
},
"outputs": [
{
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
- "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
- "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
- ]
+ "data": {
+ "text/html": [
+ "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+ "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
+ "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+ "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n",
+ "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
}
],
"source": [
@@ -2323,10 +2545,13 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 80,
"metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
"id": "nBUHBBal3vGY",
- "outputId": "f194c213-0688-4b7a-ffb8-c4453b82b1f1"
+ "outputId": "4b9be948-ccab-4a31-a06f-37ec9c7b5235"
},
"outputs": [
{
@@ -2371,10 +2596,10 @@
"metadata": {
"accelerator": "TPU",
"colab": {
+ "gpuType": "V28",
"provenance": [],
"toc_visible": true
},
- "gpuClass": "standard",
"jupytext": {
"formats": "ipynb,md:myst"
},
diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
index b9ec9dc694d2..2142db9866ae 100644
--- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
+++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
@@ -19,16 +19,14 @@ kernelspec:
+++ {"id": "pFtQjv4SzHRj"}
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)
This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer.
```{code-cell}
:id: FNxScTfq3vGF
-import os
-import functools
from typing import Optional
import numpy as np
@@ -52,7 +50,7 @@ if len(jax.local_devices()) < 8:
## Intro and a quick example
-By reading this tutorial notebook, you'll learn about `jax.Array`, a unified
+By reading this tutorial notebook, you'll learn about `jax.Array`, a unified
datatype for representing arrays, even with physical storage spanning multiple
devices. You'll also learn about how using `jax.Array`s together with `jax.jit`
can provide automatic compiler-based parallelization.
@@ -64,24 +62,29 @@ First, we'll create a `jax.Array` sharded across multiple devices:
:id: Gf2lO4ii3vGG
from jax.experimental import mesh_utils
-from jax.sharding import PositionalSharding
+from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
```
```{code-cell}
:id: q-XBTEoy3vGG
# Create a Sharding object to distribute a value across devices:
-sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
+mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),
+ axis_names=('x', 'y'))
```
```{code-cell}
-:id: vI39znW93vGH
-:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 166
+id: vI39znW93vGH
+outputId: 4f702753-8add-4b65-a4af-0f18f098cc46
+---
# Create an array of random values:
x = jax.random.normal(jax.random.key(0), (8192, 8192))
# and use jax.device_put to distribute it across devices:
-y = jax.device_put(x, sharding.reshape(4, 2))
+y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
jax.debug.visualize_array_sharding(y)
```
@@ -91,9 +94,13 @@ Next, we'll apply a computation to it and visualize how the result values are
stored across multiple devices too:
```{code-cell}
-:id: -qCnHZl83vGI
-:outputId: 9da9c29e-ce88-4425-e1ec-e93e5bcf3106
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 166
+id: -qCnHZl83vGI
+outputId: 0e131c23-5765-43ae-f232-6417ae1acbb2
+---
z = jnp.sin(y)
jax.debug.visualize_array_sharding(z)
```
@@ -104,17 +111,23 @@ The evaluation of the `jnp.sin` application was automatically parallelized
across the devices on which the input values (and output values) are stored:
```{code-cell}
-:id: _VTzN0r03vGI
-:outputId: c9208010-984b-442b-d105-c8c6a3a010e6
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: _VTzN0r03vGI
+outputId: c03eecab-4c86-4dac-d776-5fc72cbb5273
+---
# `x` is present on a single device
%timeit -n 5 -r 5 jnp.sin(x).block_until_ready()
```
```{code-cell}
-:id: QuzhU1g63vGI
-:outputId: d48fc76e-79a7-47b9-d392-b18a1c33c798
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: QuzhU1g63vGI
+outputId: 8135cca0-871b-4b6a-a7e5-02e78c2028c7
+---
# `y` is sharded across 8 devices.
%timeit -n 5 -r 5 jnp.sin(y).block_until_ready()
```
@@ -128,7 +141,7 @@ Now let's look at each of these pieces in more detail!
+++ {"id": "W6HsXauGxL6w"}
-### Sharding basics, and the `PositionalSharding` subclass
+### Sharding basics, and the `NamedSharding` subclass
+++ {"id": "NWDyp_EjVHkg"}
@@ -146,9 +159,13 @@ x = jax.random.normal(jax.random.key(0), (8192, 8192))
```
```{code-cell}
-:id: vNRabO2J3vGJ
-:outputId: 73db7b6e-c2e7-467d-a0ef-c35e29e582dd
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: vNRabO2J3vGJ
+outputId: 40fd7172-a16c-4dd8-e2e1-17bb3afe5409
+---
jax.debug.visualize_array_sharding(x)
```
@@ -159,169 +176,14 @@ Here, we're using the `jax.debug.visualize_array_sharding` function to show wher
But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:
```{code-cell}
-:id: VUIEIzRp3vGK
-
-from jax.experimental import mesh_utils
-devices = mesh_utils.create_device_mesh((8,))
-```
-
-+++ {"id": "lbOKFWmBX1iv"}
-
-Then, we create a `PositionalSharding` and use it with `device_put`:
-
-```{code-cell}
-:id: jwrWfZeB3vGK
-:outputId: e6f126bd-f6bd-48c7-c130-6f02757e3342
-
-from jax.sharding import PositionalSharding
-
-sharding = PositionalSharding(devices)
-
-x = jax.device_put(x, sharding.reshape(8, 1))
-jax.debug.visualize_array_sharding(x)
-```
-
-+++ {"id": "TUu69IWXZdTm"}
-
-Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:
-
-```{code-cell}
-:id: zxWB82Kz3vGK
-:outputId: 11384a6b-fabc-4c4c-bcad-a3be51eb0465
-
-sharding
-```
-
-+++ {"id": "uRLpOcmNj_Vt"}
-
-The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.
-
-By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:
-
-```{code-cell}
-:id: PLsnpSzc3vGL
-:outputId: 9f4db733-cafe-46ae-c057-dc31046a6f66
-
-sharding.reshape(8, 1)
-```
-
-```{code-cell}
-:id: iqKdI4LO3vGL
-:outputId: 6aa10fc2-cec4-4401-a0df-343e71646e0a
-
-sharding.reshape(4, 2)
-```
-
-+++ {"id": "KBu6WLfhm7ra"}
-
-To use `device_put` with a data array `x`, we can reshape the `sharding` into a shape that is _congruent_ with `x.shape`, meaning a shape with the same length as `x.shape` and where each element evenly divides the corresponding element of `x.shape`:
-```python
-def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:
- return (len(x_shape) == len(sharding_shape) and
- all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))
-```
-
-For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`:
-
-```{code-cell}
-:id: SELr4xNi3vGL
-:outputId: b2f4acec-0cd3-4829-ca16-cae2e0e8ca60
-
-sharding = sharding.reshape(4, 2)
-print(sharding)
-```
-
-```{code-cell}
-:id: 8IVIsqfX3vGL
-:outputId: 033d0e02-a643-4f4c-9d24-9cd8465bc69a
-
-y = jax.device_put(x, sharding)
-jax.debug.visualize_array_sharding(y)
-```
-
-+++ {"id": "tyg9F-UIsU__"}
-
-Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories.
-
-Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result:
-
-```{code-cell}
-:id: cCjt6QCz3vGM
-:outputId: 4ad8a611-596d-424f-b6c5-fc00f1adc306
-
-sharding = sharding.reshape(1, 8)
-print(sharding)
-```
-
-```{code-cell}
-:id: yTK4Nz3u3vGM
-:outputId: e445c6bc-4fe3-4e9d-cc9e-d82858f58312
-
-y = jax.device_put(x, sharding)
-jax.debug.visualize_array_sharding(y)
-```
-
-+++ {"id": "0PuamOvXubcf"}
-
-In some cases, we don't just want to store each slice of `x` in a single device's memory; we might want to _replicate_ some slices, meaning storing copies of a slice's values in multiple devices' memories.
-
-With `PositionalSharding`, we can express replication by calling the reducer method `replicate`:
-
-```{code-cell}
-:id: _jr6XYKx3vGM
-:outputId: 59c8b9a4-b8af-493a-ba8d-da5931e88f93
-
-sharding = sharding.reshape(4, 2)
-print(sharding.replicate(axis=0, keepdims=True))
-```
-
-```{code-cell}
-:id: S5vzjFuH3vGN
-:outputId: b6ce2675-7261-4e57-fa8c-b4e87abf7e52
-
-y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))
-jax.debug.visualize_array_sharding(y)
-```
-
-+++ {"id": "FzeP0kpTvJv-"}
-
-Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).
-
-The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:
-
-```{code-cell}
-:id: DR7VV-6e3vGN
-:outputId: f879fc2c-5723-4199-b306-295bc1b3681e
-
-print(sharding.replicate(0).shape)
-print(sharding.replicate(1).shape)
-```
-
-```{code-cell}
-:id: agUtVUVx3vGN
-:outputId: 0e9789ef-ce52-4ed6-8bd5-c876b95f66e6
-
-y = jax.device_put(x, sharding.replicate(1))
-jax.debug.visualize_array_sharding(y)
-```
-
-+++ {"id": "D31t5POXxHHJ"}
-
-### `NamedSharding` gives a way to express shardings with names
-
-+++ {"id": "ayMKWeTmxl-X"}
-
-So far we've worked with `PositionalSharding`, but there are alternative ways to express shardings. In fact, `Sharding` is an interface, and any class that implements that interface can be used with functions like `device_put`.
-
-Another convenient way to express sharding is with the `NamedSharding`:
-
-```{code-cell}
-:id: zpB1JxyK3vGN
-:outputId: 46d5da37-840c-49d8-8380-a162811bae8a
-
-from jax.sharding import Mesh
-from jax.sharding import PartitionSpec
-from jax.sharding import NamedSharding
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 166
+id: zpB1JxyK3vGN
+outputId: 8e385462-1c2c-4256-c38a-84299d3bd02c
+---
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental import mesh_utils
P = PartitionSpec
@@ -351,9 +213,13 @@ def mesh_sharding(
```
```{code-cell}
-:id: zp3MfS4Y3vGO
-:outputId: 2c2f7201-c2c1-49e5-f8a5-0730c124d89a
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 166
+id: zp3MfS4Y3vGO
+outputId: 032fdd7e-19a1-45da-e1ad-b3227fa43ee6
+---
y = jax.device_put(x, mesh_sharding(P('a', 'b')))
jax.debug.visualize_array_sharding(y)
```
@@ -363,17 +229,25 @@ jax.debug.visualize_array_sharding(y)
Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices:
```{code-cell}
-:id: FigK5Zsa3vGO
-:outputId: eca784e8-33fe-4e9b-a41d-21e9ee781a35
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: FigK5Zsa3vGO
+outputId: e488d073-9d02-4376-a6af-19d6d5509c7d
+---
y = jax.device_put(x, mesh_sharding(P('b', 'a')))
jax.debug.visualize_array_sharding(y)
```
```{code-cell}
-:id: hI-HD0xN3vGO
-:outputId: c3e7dc3c-4048-448a-ef0b-50683532fcdc
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 166
+id: hI-HD0xN3vGO
+outputId: b0c2e863-3aee-4417-b45f-21b2187f6ef7
+---
# This `None` means that `x` is not sharded on its second dimension,
# and since the Mesh axis name 'b' is not mentioned, shards are
# replicated across it.
@@ -388,17 +262,25 @@ Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get
To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`:
```{code-cell}
-:id: EXBExMQC3vGP
-:outputId: fe1c8d7e-3345-4438-b9d2-780e7854b4eb
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: EXBExMQC3vGP
+outputId: c80e6177-12a6-40ef-b4e4-934dad22da3d
+---
y = jax.device_put(x, mesh_sharding(P(None, 'b')))
jax.debug.visualize_array_sharding(y)
```
```{code-cell}
-:id: PjUpG8uz3vGP
-:outputId: 64d8224d-15d9-4ad4-d613-f7f85b1dc1af
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: PjUpG8uz3vGP
+outputId: a0f59dc5-b509-4b8b-bd22-bcd69f696763
+---
y = jax.device_put(x, mesh_sharding(P(None, 'a')))
jax.debug.visualize_array_sharding(y)
```
@@ -408,9 +290,13 @@ jax.debug.visualize_array_sharding(y)
For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes:
```{code-cell}
-:id: fVcPbDUA3vGP
-:outputId: 7f524ba5-a6d8-4490-cda9-685ad11416f9
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 298
+id: fVcPbDUA3vGP
+outputId: da3f435d-dfc1-4a41-ec90-691cd7c748a0
+---
y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None)))
jax.debug.visualize_array_sharding(y)
```
@@ -432,16 +318,19 @@ For example, the simplest computation is an elementwise one:
```{code-cell}
:id: _EmQwggc3vGQ
-from jax.experimental import mesh_utils
-from jax.sharding import PositionalSharding
-sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
+devices = mesh_utils.create_device_mesh((4, 2))
+mesh = Mesh(devices, axis_names=('a', 'b'))
```
```{code-cell}
-:id: LnT0vWjc3vGQ
-:outputId: 8089effc-aa4c-49e3-dd19-7064881dbad0
-
-x = jax.device_put(x, sharding.reshape(4, 2))
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 349
+id: LnT0vWjc3vGQ
+outputId: 8e642049-61eb-458d-af79-ac449b58d11b
+---
+x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))
print('input sharding:')
jax.debug.visualize_array_sharding(x)
@@ -459,11 +348,15 @@ In other words, even though we wrote the `jnp.sin` computation as if a single ma
We can do the same for more than just elementwise operations too. Consider a matrix multiplication with sharded inputs:
```{code-cell}
-:id: Dq043GkP3vGQ
-:outputId: 350219a8-1e4a-4404-fe14-50f97ea3e7ba
-
-y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))
-z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 548
+id: Dq043GkP3vGQ
+outputId: 3eff7b67-d7f0-4212-c9d3-2cc271ac1f98
+---
+y = jax.device_put(x, NamedSharding(mesh, P('a', None)))
+z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))
print('lhs sharding:')
jax.debug.visualize_array_sharding(y)
print('rhs sharding:')
@@ -481,32 +374,45 @@ Here the compiler chose the output sharding so that it could maximally paralleli
How can we be sure it's actually running in parallel? We can do a simple timing experiment:
```{code-cell}
-:id: QjQ5u8qh3vGQ
-:outputId: bd29edcd-b87c-486e-c568-906f06ae16be
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: QjQ5u8qh3vGQ
+outputId: 0aefc170-833c-4a6a-e003-5990d3db31d9
+---
x_single = jax.device_put(x, jax.devices()[0])
jax.debug.visualize_array_sharding(x_single)
```
```{code-cell}
-:id: 8tn8lOj73vGR
-:outputId: 5809b3c8-7333-4cd3-db97-a7aede943dce
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: 8tn8lOj73vGR
+outputId: d9898c93-7afc-416b-8c40-4d9551613cd0
+---
np.allclose(jnp.dot(x_single, x_single),
jnp.dot(y, z))
```
```{code-cell}
-:id: D7PpZwhR3vGR
-:outputId: 4f0bd43d-0b32-4089-d3da-c8f1449e3526
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: D7PpZwhR3vGR
+outputId: 4901a11b-2354-4d26-a897-b88def07a716
+---
%timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready()
```
```{code-cell}
-:id: rgo_yVHF3vGR
-:outputId: 97f19052-f1c9-4d30-f453-07b3a7208aa9
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: rgo_yVHF3vGR
+outputId: e51216cf-b073-4250-d422-67f9fd72f6aa
+---
%timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready()
```
@@ -515,9 +421,13 @@ np.allclose(jnp.dot(x_single, x_single),
Even copying a sharded `Array` produces a result with the sharding of the input:
```{code-cell}
-:id: f1Zw-2lH3vGR
-:outputId: a796bed4-07b0-497d-8fd8-31a22ab9762e
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 166
+id: f1Zw-2lH3vGR
+outputId: 43d7a642-fde4-47a6-901f-dfdc64d6a613
+---
w_copy = jnp.copy(w)
jax.debug.visualize_array_sharding(w_copy)
```
@@ -540,35 +450,41 @@ import textwrap
from termcolor import colored
def print_exception(e):
- name = colored(f'{type(e).__name__}', 'red')
+ name = colored(f'{type(e).__name__}', 'red', force_color=True)
print(textwrap.fill(f'{name}: {str(e)}'))
```
```{code-cell}
-:id: DHh0N3vn3vGS
-:outputId: e7741882-0ebf-4237-e5d1-e48c9b9c178f
-
-sharding1 = PositionalSharding(jax.devices()[:4])
-sharding2 = PositionalSharding(jax.devices()[4:])
+---
+colab:
+ base_uri: https://localhost:8080/
+id: DHh0N3vn3vGS
+outputId: 8c4652f7-c484-423b-ad78-182134280187
+---
+sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))
+sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))
-y = jax.device_put(x, sharding1.reshape(2, 2))
-z = jax.device_put(x, sharding2.reshape(2, 2))
+y = jax.device_put(x, sharding1)
+z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
```
```{code-cell}
-:id: Im7DkoOl3vGS
-:outputId: 3adfe1cb-db52-4a9d-e98e-62c6455c3100
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: Im7DkoOl3vGS
+outputId: 1b6fcd7a-762b-4366-a96d-aea63bad7fe0
+---
devices = jax.devices()
permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]
-sharding1 = PositionalSharding(devices)
-sharding2 = PositionalSharding(permuted_devices)
+sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))
+sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))
-y = jax.device_put(x, sharding1.reshape(4, 2))
-z = jax.device_put(x, sharding2.reshape(4, 2))
+y = jax.device_put(x, sharding1)
+z = jax.device_put(x, sharding2)
try: y + z
except ValueError as e: print_exception(e)
```
@@ -583,10 +499,13 @@ Unlike committed arrays, uncommitted arrays can be moved and resharded automatic
For example, the output of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted:
```{code-cell}
-:id: _QvtKL8r3vGS
-:outputId: e0078805-bdfd-436e-f94f-7cd256d2574f
-
-y = jax.device_put(x, sharding1.reshape(4, 2))
+---
+colab:
+ base_uri: https://localhost:8080/
+id: _QvtKL8r3vGS
+outputId: 761b1208-fe4b-4c09-a7d2-f62152183ef0
+---
+y = jax.device_put(x, sharding1)
y + jnp.ones_like(y)
y + jnp.arange(y.size).reshape(y.shape)
print('no error!')
@@ -603,14 +522,14 @@ While the compiler will attempt to decide how a function's intermediate values a
```{code-cell}
:id: jniSFm5V3vGT
-sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
+mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))
```
```{code-cell}
:id: Q1wuDp-L3vGT
x = jax.random.normal(jax.random.key(0), (8192, 8192))
-x = jax.device_put(x, sharding.reshape(4, 2))
+x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))
```
```{code-cell}
@@ -619,14 +538,18 @@ x = jax.device_put(x, sharding.reshape(4, 2))
@jax.jit
def f(x):
x = x + 1
- y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))
+ y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))
return y
```
```{code-cell}
-:id: zYFS-n4r3vGT
-:outputId: d23a7938-cb7d-44b4-b9c7-83edf1d1145e
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 347
+id: zYFS-n4r3vGT
+outputId: 0ac96b8f-ed23-4413-aed9-edd00a841c37
+---
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
@@ -638,14 +561,18 @@ jax.debug.visualize_array_sharding(y)
@jax.jit
def f(x):
x = x + 1
- y = jax.lax.with_sharding_constraint(x, sharding.replicate())
+ y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))
return y
```
```{code-cell}
-:id: AiRFtVsR3vGT
-:outputId: f3e28a70-46cf-46fb-c801-82f0ddb447e4
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 347
+id: AiRFtVsR3vGT
+outputId: 2edacc2c-ac80-4519-c9d1-bee364a22b31
+---
jax.debug.visualize_array_sharding(x)
y = f(x)
jax.debug.visualize_array_sharding(y)
@@ -702,20 +629,20 @@ gradfun = jax.jit(jax.grad(loss))
:id: R0x62AIa3vGU
def init_layer(key, n_in, n_out):
- k1, k2 = jax.random.split(key)
- W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
- b = jax.random.normal(k2, (n_out,))
- return W, b
+ k1, k2 = jax.random.split(key)
+ W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
+ b = jax.random.normal(k2, (n_out,))
+ return W, b
def init_model(key, layer_sizes, batch_size):
- key, *keys = jax.random.split(key, len(layer_sizes))
- params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
+ key, *keys = jax.random.split(key, len(layer_sizes))
+ params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
- key, *keys = jax.random.split(key, 3)
- inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
- targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
+ key, *keys = jax.random.split(key, 3)
+ inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
+ targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
- return params, (inputs, targets)
+ return params, (inputs, targets)
layer_sizes = [784, 8192, 8192, 8192, 10]
batch_size = 8192
@@ -727,30 +654,43 @@ params, batch = init_model(jax.random.key(0), layer_sizes, batch_size)
### 8-way batch data parallelism
+```{code-cell}
+:id: mJLqRPpSDX0i
+
+mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')
+```
+
```{code-cell}
:id: _Q5NbdOn3vGV
-sharding = PositionalSharding(jax.devices()).reshape(8, 1)
+sharding = NamedSharding(mesh, P('batch'))
+replicated_sharding = NamedSharding(mesh, P())
```
```{code-cell}
:id: 3KC6ieEe3vGV
batch = jax.device_put(batch, sharding)
-params = jax.device_put(params, sharding.replicate())
+params = jax.device_put(params, replicated_sharding)
```
```{code-cell}
-:id: MUb-QE2b3vGV
-:outputId: 1f831ea5-5a30-49ad-8195-977ff7ed476a
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: MUb-QE2b3vGV
+outputId: 5a27f007-c572-44f8-9f49-6e745ee739e8
+---
loss_jit(params, batch)
```
```{code-cell}
-:id: HUkw0u413vGV
-:outputId: dfa2599c-9440-4657-9035-0dc3bbf625e1
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: HUkw0u413vGV
+outputId: 07e481a1-97fb-4bd0-d754-cb6d8317bff6
+---
step_size = 1e-5
for _ in range(30):
@@ -762,9 +702,12 @@ print(loss_jit(params, batch))
```
```{code-cell}
-:id: paCw6Zaj3vGV
-:outputId: 8ab1c32c-f2b1-465c-df71-f5a599e7f19e
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: paCw6Zaj3vGV
+outputId: ad4cce34-3a6a-4d44-9a86-477a7fee4841
+---
%timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready()
```
@@ -776,9 +719,12 @@ params_single = jax.device_put(params, jax.devices()[0])
```
```{code-cell}
-:id: Z1wgUKXk3vGV
-:outputId: 74df8892-c349-41dc-cb1b-e0843ec5c994
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: Z1wgUKXk3vGV
+outputId: d66767b7-3f17-482f-b811-919bb1793277
+---
%timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready()
```
@@ -787,58 +733,79 @@ params_single = jax.device_put(params, jax.devices()[0])
### 4-way batch data parallelism and 2-way model tensor parallelism
```{code-cell}
-:id: N5-zzgW03vGW
+:id: k1hxOfgRDwo0
-sharding = sharding.reshape(4, 2)
+mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
```
```{code-cell}
-:id: sgIWCjJK3vGW
-:outputId: b2fdc556-05cc-4e68-fa04-48643d194dee
-
-batch = jax.device_put(batch, sharding.replicate(1))
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 314
+id: sgIWCjJK3vGW
+outputId: 8cb0f19f-3942-415c-c57a-31bb81784f46
+---
+batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))
jax.debug.visualize_array_sharding(batch[0])
jax.debug.visualize_array_sharding(batch[1])
```
+```{code-cell}
+:id: q9PQP-0eEAO6
+
+replicated_sharding = NamedSharding(mesh, P())
+```
+
```{code-cell}
:id: BqCjYCgg3vGW
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
-W1 = jax.device_put(W1, sharding.replicate())
-b1 = jax.device_put(b1, sharding.replicate())
+W1 = jax.device_put(W1, replicated_sharding)
+b1 = jax.device_put(b1, replicated_sharding)
-W2 = jax.device_put(W2, sharding.replicate(0))
-b2 = jax.device_put(b2, sharding.replicate(0))
+W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
+b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))
-W3 = jax.device_put(W3, sharding.replicate(0).T)
-b3 = jax.device_put(b3, sharding.replicate())
+W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))
+b3 = jax.device_put(b3, replicated_sharding)
-W4 = jax.device_put(W4, sharding.replicate())
-b4 = jax.device_put(b4, sharding.replicate())
+W4 = jax.device_put(W4, replicated_sharding)
+b4 = jax.device_put(b4, replicated_sharding)
params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)
```
```{code-cell}
-:id: _lSJ63sh3vGW
-:outputId: 5b37aa8b-3226-4805-8282-876e8d06edda
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: _lSJ63sh3vGW
+outputId: bcd3e33e-36b5-4787-9cd2-60623fd6e5fa
+---
jax.debug.visualize_array_sharding(W2)
```
```{code-cell}
-:id: fxkfWYkk3vGW
-:outputId: 8a1063c3-540b-47c1-d990-a6845da861f7
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 199
+id: fxkfWYkk3vGW
+outputId: 59e60b16-fe37-47d4-8214-96096ffbd79c
+---
jax.debug.visualize_array_sharding(W3)
```
```{code-cell}
-:id: uPCVs-_k3vGW
-:outputId: de01cdfc-36cb-4823-c692-22c692ef4220
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: uPCVs-_k3vGW
+outputId: 618516e9-9736-4ca0-dd22-09d094ce57a2
+---
print(loss_jit(params, batch))
```
@@ -854,25 +821,35 @@ for _ in range(30):
```
```{code-cell}
-:id: c9Sbl69e3vGX
-:outputId: 8272c5fa-e59f-4953-c2d5-658c42a28712
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: c9Sbl69e3vGX
+outputId: 2ee3d432-7172-46ca-e01a-614e83345808
+---
print(loss_jit(params, batch))
```
```{code-cell}
-:id: lkAF0dAb3vGX
-:outputId: acf0df31-c5e1-4683-b73f-b0cd1b0929f8
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 380
+id: lkAF0dAb3vGX
+outputId: 6c1e317e-cded-4af4-8080-0de835fa4c71
+---
(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params
jax.debug.visualize_array_sharding(W2)
jax.debug.visualize_array_sharding(W3)
```
```{code-cell}
-:id: I1Npor3i3vGX
-:outputId: 4099f6dd-7b46-4123-c1cb-5173c3d3278e
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: I1Npor3i3vGX
+outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297
+---
%timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready()
```
@@ -903,7 +880,8 @@ def f(key, x):
return x + numbers
key = jax.random.key(42)
-x_sharding = jax.sharding.PositionalSharding(jax.devices())
+mesh = Mesh(jax.devices(), 'x')
+x_sharding = NamedSharding(mesh, P('x'))
x = jax.device_put(jnp.arange(24), x_sharding)
```
@@ -912,9 +890,13 @@ x = jax.device_put(jnp.arange(24), x_sharding)
On a partitioned input, the function `f` produces output that is also partitioned:
```{code-cell}
-:id: Oi97rpLz3vGY
-:outputId: 204a7e8d-dc88-4b77-b7e3-0e72f306c5d3
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 67
+id: Oi97rpLz3vGY
+outputId: 9dd63254-a483-4847-c0f5-5a4367bf08e9
+---
jax.debug.visualize_array_sharding(f(key, x))
```
@@ -923,9 +905,12 @@ jax.debug.visualize_array_sharding(f(key, x))
But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication:
```{code-cell}
-:id: 64wIZuSJ3vGY
-:outputId: 1054fe99-0476-44ec-9693-b0d8f98bf6a8
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: 64wIZuSJ3vGY
+outputId: fa166d45-ca9c-457a-be84-bcc9236d0730
+---
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
```
@@ -935,9 +920,12 @@ print('Communicating?', 'collective-permute' in f_exe.as_text())
One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. With the flag on, the "collective permute" operation is now gone from the compiled computation:
```{code-cell}
-:id: 1I7bqxA63vGY
-:outputId: ec4c579d-f446-4b48-ceda-785c09ba299b
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: 1I7bqxA63vGY
+outputId: 756e0a36-ff14-438f-bbd4-3ef03f97a47b
+---
jax.config.update('jax_threefry_partitionable', True)
f_exe = f.lower(key, x).compile()
print('Communicating?', 'collective-permute' in f_exe.as_text())
@@ -948,9 +936,13 @@ print('Communicating?', 'collective-permute' in f_exe.as_text())
The output is still partitioned:
```{code-cell}
-:id: zHPJzdn23vGY
-:outputId: a8904d20-4d04-4f59-8eae-281e47d29246
-
+---
+colab:
+ base_uri: https://localhost:8080/
+ height: 67
+id: zHPJzdn23vGY
+outputId: 3332de0f-4827-4f0b-b9ef-69249b7c6bc6
+---
jax.debug.visualize_array_sharding(f(key, x))
```
@@ -959,9 +951,12 @@ jax.debug.visualize_array_sharding(f(key, x))
One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key:
```{code-cell}
-:id: nBUHBBal3vGY
-:outputId: f194c213-0688-4b7a-ffb8-c4453b82b1f1
-
+---
+colab:
+ base_uri: https://localhost:8080/
+id: nBUHBBal3vGY
+outputId: 4b9be948-ccab-4a31-a06f-37ec9c7b5235
+---
jax.config.update('jax_threefry_partitionable', False)
print('Stable:')
print(f(key, x))
diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb
deleted file mode 100644
index f42e3f74b4e3..000000000000
--- a/docs/notebooks/How_JAX_primitives_work.ipynb
+++ /dev/null
@@ -1,1532 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "vfxqky4PCUnh"
- },
- "source": [
- "# How JAX primitives work\n",
- "\n",
- "\n",
- "\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n",
- "\n",
- "*necula@google.com*, October 2019.\n",
- "\n",
- "JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n",
- "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n",
- "which means that as the Python function executes\n",
- "the only operations it applies to the data are either inspections of data\n",
- "attributes such as shape or type, or special operations called JAX primitives.\n",
- "In particular, a JAX-traceable function is sometimes invoked by JAX with\n",
- "abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n",
- "which captures the type and the shape of values, but not the concrete data values.\n",
- "JAX primitives know how to operate on both concrete data\n",
- "values and on the JAX abstract values.\n",
- "\n",
- "\n",
- "The JAX-transformed functions must themselves be JAX-traceable functions,\n",
- "to ensure that these transformations\n",
- "can be composed, e.g., `jit(jacfwd(grad(f)))`.\n",
- "\n",
- "There are pre-defined JAX primitives corresponding to most XLA operations, \n",
- "e.g., add, matmul, sin, cos, indexing.\n",
- "JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n",
- "using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n",
- "Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.\n",
- "\n",
- "The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,\n",
- "one can define a new primitive that encapsulates the behavior of the function.\n",
- "\n",
- "**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n",
- "\n",
- "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n",
- "as \"multiply_add(x, y, z) = x * y + z\". \n",
- "This function operates on 3 identically-shaped tensors of floating point \n",
- "values and performs the operations pointwise."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "HIJYIHNTD1yI"
- },
- "source": [
- "## Using existing primitives\n",
- "\n",
- "The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n",
- "functions that are themselves written using JAX primitives, e.g., those \n",
- "defined in the `jax.lax` module:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "id": "tbOF0LB0EMne",
- "outputId": "3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "square_add_lax = 14.0\n",
- "grad(square_add_lax) = 4.0\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.\n",
- " warnings.warn('No GPU/TPU found, falling back to CPU.')\n"
- ]
- }
- ],
- "source": [
- "from jax import lax\n",
- "from jax._src import api\n",
- "\n",
- "def multiply_add_lax(x, y, z):\n",
- " \"\"\"Implementation of multiply-add using the jax.lax primitives.\"\"\"\n",
- " return lax.add(lax.mul(x, y), z)\n",
- "\n",
- "\n",
- "def square_add_lax(a, b):\n",
- " \"\"\"A square-add function using the newly defined multiply-add.\"\"\"\n",
- " return multiply_add_lax(a, a, b)\n",
- "\n",
- "print(\"square_add_lax = \", square_add_lax(2., 10.))\n",
- "# Differentiate w.r.t. the first argument\n",
- "print(\"grad(square_add_lax) = \", api.grad(square_add_lax, argnums=0)(2.0, 10.))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Cgv60Wm3E_D5"
- },
- "source": [
- "In order to understand how JAX is internally using the primitives,\n",
- "we add some helpers for tracing function calls."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "cellView": "form",
- "id": "mQRQGEGiE53K"
- },
- "outputs": [],
- "source": [
- "#@title Helper functions (execute this cell)\n",
- "import functools\n",
- "import traceback\n",
- "\n",
- "_indentation = 0\n",
- "def _trace(msg=None):\n",
- " \"\"\"Print a message at current indentation.\"\"\"\n",
- " if msg is not None:\n",
- " print(\" \" * _indentation + msg)\n",
- "\n",
- "def _trace_indent(msg=None):\n",
- " \"\"\"Print a message and then indent the rest.\"\"\"\n",
- " global _indentation\n",
- " _trace(msg)\n",
- " _indentation = 1 + _indentation\n",
- "\n",
- "def _trace_unindent(msg=None):\n",
- " \"\"\"Unindent then print a message.\"\"\"\n",
- " global _indentation\n",
- " _indentation = _indentation - 1\n",
- " _trace(msg)\n",
- "\n",
- "def trace(name):\n",
- " \"\"\"A decorator for functions to trace arguments and results.\"\"\"\n",
- "\n",
- " def trace_func(func): # pylint: disable=missing-docstring\n",
- " def pp(v):\n",
- " \"\"\"Print certain values more succinctly\"\"\"\n",
- " vtype = str(type(v))\n",
- " if \"jax._src.xla_bridge._JaxComputationBuilder\" in vtype:\n",
- " return \"\"\n",
- " elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n",
- " return \"\".format(id(v))\n",
- " elif (\"partial_eval.JaxprTracer\" in vtype or\n",
- " \"batching.BatchTracer\" in vtype or\n",
- " \"ad.JVPTracer\" in vtype):\n",
- " return \"Traced<{}>\".format(v.aval)\n",
- " elif isinstance(v, tuple):\n",
- " return \"({})\".format(pp_values(v))\n",
- " else:\n",
- " return str(v)\n",
- " def pp_values(args):\n",
- " return \", \".join([pp(arg) for arg in args])\n",
- " \n",
- " @functools.wraps(func)\n",
- " def func_wrapper(*args):\n",
- " _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n",
- " res = func(*args)\n",
- " _trace_unindent(\"|<- {} = {}\".format(name, pp(res)))\n",
- " return res\n",
- "\n",
- " return func_wrapper\n",
- "\n",
- " return trace_func\n",
- "\n",
- "class expectNotImplementedError(object):\n",
- " \"\"\"Context manager to check for NotImplementedError.\"\"\"\n",
- " def __enter__(self): pass\n",
- " def __exit__(self, type, value, tb):\n",
- " global _indentation\n",
- " _indentation = 0\n",
- " if type is NotImplementedError:\n",
- " print(\"\\nFound expected exception:\")\n",
- " traceback.print_exc(limit=3)\n",
- " return True\n",
- " elif type is None: # No exception\n",
- " assert False, \"Expected NotImplementedError\"\n",
- " else:\n",
- " return False"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Qf4eLrLCFYDl"
- },
- "source": [
- "Instead of using `jax.lax` primitives directly, we can use other functions \n",
- "that are already written in terms of those primitives, such as those in `jax.numpy`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "id": "QhKorz6cFRJb",
- "outputId": "aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "\n",
- "Normal evaluation:\n",
- "call square_add_numpy(2.0, 10.0)\n",
- " call multiply_add_numpy(2.0, 2.0, 10.0)\n",
- " |<- multiply_add_numpy = 14.0\n",
- "|<- square_add_numpy = 14.0\n",
- "square_add_numpy = 14.0\n",
- "\n",
- "Gradient evaluation:\n",
- "call square_add_numpy(Traced, 10.0)\n",
- " call multiply_add_numpy(Traced, Traced, 10.0)\n",
- " |<- multiply_add_numpy = Traced\n",
- "|<- square_add_numpy = Traced\n",
- "grad(square_add_numpy) = 4.0\n"
- ]
- }
- ],
- "source": [
- "import jax.numpy as jnp\n",
- "import numpy as np\n",
- "\n",
- "@trace(\"multiply_add_numpy\")\n",
- "def multiply_add_numpy(x, y, z):\n",
- " return jnp.add(jnp.multiply(x, y), z)\n",
- "\n",
- "@trace(\"square_add_numpy\")\n",
- "def square_add_numpy(a, b):\n",
- " return multiply_add_numpy(a, a, b)\n",
- "\n",
- "print(\"\\nNormal evaluation:\") \n",
- "print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n",
- "print(\"\\nGradient evaluation:\")\n",
- "print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Sg-D8EdeFn4a"
- },
- "source": [
- "Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n",
- "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n",
- "below in this colab). \n",
- "It is important to remember that a JAX-traceable function must be able to \n",
- "operate not only on concrete arguments but also on special abstract arguments\n",
- "that JAX may use to abstract the function execution.\n",
- "\n",
- "The JAX traceability property is satisfied as long as the function is written \n",
- "in terms of JAX primitives."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "WxrQO7-XGLcg"
- },
- "source": [
- "## Defining new JAX primitives\n",
- "\n",
- "The right way to add support for multiply-add is in terms of existing\n",
- "JAX primitives, as shown above. However, in order to demonstrate how JAX\n",
- "primitives work let us pretend that we want to add a new primitive to \n",
- "JAX for the multiply-add functionality."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "cPqAH1XOGTN4"
- },
- "outputs": [],
- "source": [
- "from jax import core\n",
- "multiply_add_p = core.Primitive(\"multiply_add\") # Create the primitive\n",
- "\n",
- "@trace(\"multiply_add_prim\")\n",
- "def multiply_add_prim(x, y, z):\n",
- " \"\"\"The JAX-traceable way to use the JAX primitive.\n",
- " \n",
- " Note that the traced arguments must be passed as positional arguments\n",
- " to `bind`. \n",
- " \"\"\"\n",
- " return multiply_add_p.bind(x, y, z)\n",
- "\n",
- "@trace(\"square_add_prim\")\n",
- "def square_add_prim(a, b):\n",
- " \"\"\"A square-add function implemented using the new JAX-primitive.\"\"\"\n",
- " return multiply_add_prim(a, a, b)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "LMzs5PAKGr-4"
- },
- "source": [
- "If we try to call the newly defined functions we get an error, because\n",
- "we have not yet told JAX anything about the semantics of the new primitive."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "_X3PAYxhGpWd",
- "outputId": "90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(2.0, 10.0)\n",
- " call multiply_add_prim(2.0, 2.0, 10.0)\n",
- "\n",
- "Found expected exception:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Traceback (most recent call last):\n",
- " File \"\", line 2, in \n",
- " square_add_prim(2., 10.)\n",
- " File \"\", line 47, in func_wrapper\n",
- " res = func(*args)\n",
- " File \"\", line 16, in square_add_prim\n",
- " return multiply_add_prim(a, a, b)\n",
- "NotImplementedError: Evaluation rule for 'multiply_add' not implemented\n"
- ]
- }
- ],
- "source": [
- "with expectNotImplementedError():\n",
- " square_add_prim(2., 10.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "elha0FdgHSEF"
- },
- "source": [
- "### Primal evaluation rules"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "FT34FFAGHARU",
- "outputId": "4c54f1c2-8a50-4788-90e1-06aee412c43b"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 6,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "@trace(\"multiply_add_impl\")\n",
- "def multiply_add_impl(x, y, z):\n",
- " \"\"\"Concrete implementation of the primitive.\n",
- "\n",
- " This function does not need to be JAX traceable.\n",
- " Args:\n",
- " x, y, z: the concrete arguments of the primitive. Will only be called with \n",
- " concrete values.\n",
- " Returns:\n",
- " the concrete result of the primitive.\n",
- " \"\"\"\n",
- " # Note that we can use the original numpy, which is not JAX traceable\n",
- " return np.add(np.multiply(x, y), z)\n",
- "\n",
- "# Now we register the primal implementation with JAX\n",
- "multiply_add_p.def_impl(multiply_add_impl)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "id": "G5bstKaeNAVV",
- "outputId": "deb94d5b-dfea-4e6f-9ec2-70b416c996c5"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(2.0, 10.0)\n",
- " call multiply_add_prim(2.0, 2.0, 10.0)\n",
- " call multiply_add_impl(2.0, 2.0, 10.0)\n",
- " |<- multiply_add_impl = 14.0\n",
- " |<- multiply_add_prim = 14.0\n",
- "|<- square_add_prim = 14.0\n"
- ]
- }
- ],
- "source": [
- "assert square_add_prim(2., 10.) == 14."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "upBf-uAuHhPJ"
- },
- "source": [
- "### JIT\n",
- "\n",
- "If we now try to use `jit` we get a `NotImplementedError`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "id": "QG-LULjiHk4b",
- "outputId": "d4ef4406-8dae-4c96-97ca-b662340474ee"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- "\n",
- "Found expected exception:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Traceback (most recent call last):\n",
- " File \"\", line 2, in \n",
- " api.jit(square_add_prim)(2., 10.)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n",
- " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n",
- " outs = primitive.impl(f, *args, **params)\n",
- "NotImplementedError: Abstract evaluation for 'multiply_add' not implemented\n"
- ]
- }
- ],
- "source": [
- "with expectNotImplementedError():\n",
- " api.jit(square_add_prim)(2., 10.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "rHS1bAGHH44E"
- },
- "source": [
- "#### Abstract evaluation rules\n",
- "In order to JIT the function, and for other transformations as well, \n",
- "JAX first evaluates it abstractly using only the \n",
- "shape and type of the arguments. This abstract evaluation serves multiple\n",
- "purposes:\n",
- "\n",
- " * Gets the sequence of JAX primitives that are used in the computation. This \n",
- " sequence will be compiled. \n",
- " * Computes the shape and type of all vectors and operations used in the computation. \n",
- "\n",
- "\n",
- "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n",
- "In the latter case, JAX uses the actual concrete value wrapped as an abstract value."
- ]
- },
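
As a brief editorial aside (not part of the original notebook), one way to watch abstract evaluation in action is `jax.eval_shape`, which traces a function using only shape and dtype information; the sketch below assumes a recent JAX release:

```python
# Minimal sketch: jax.eval_shape evaluates a function abstractly, with no concrete data.
import jax
import jax.numpy as jnp

abstract_result = jax.eval_shape(
    lambda x, y, z: jnp.add(jnp.multiply(x, y), z),
    jax.ShapeDtypeStruct((3,), jnp.float32),
    jax.ShapeDtypeStruct((3,), jnp.float32),
    jax.ShapeDtypeStruct((3,), jnp.float32),
)
print(abstract_result)  # ShapeDtypeStruct(shape=(3,), dtype=float32)
```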
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "id": "ctQmEeckIbdo",
- "outputId": "e751d0cc-460e-4ffd-df2e-fdabf9cffdc2"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- ""
- ]
- },
- "execution_count": 9,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from jax import core\n",
- "@trace(\"multiply_add_abstract_eval\")\n",
- "def multiply_add_abstract_eval(xs, ys, zs):\n",
- " \"\"\"Abstract evaluation of the primitive.\n",
- "\n",
- " This function does not need to be JAX traceable. It will be invoked with\n",
- " abstractions of the actual arguments. \n",
- " Args:\n",
- " xs, ys, zs: abstractions of the arguments.\n",
- " Result:\n",
- " a ShapedArray for the result of the primitive.\n",
- " \"\"\"\n",
- " assert xs.shape == ys.shape\n",
- " assert xs.shape == zs.shape\n",
- " return core.ShapedArray(xs.shape, xs.dtype)\n",
- "\n",
- "# Now we register the abstract evaluation with JAX\n",
- "multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "RPN88X6YI43A"
- },
- "source": [
- "If we re-attempt to JIT, we see how the abstract evaluation proceeds, but\n",
- "we get another error, about missing the actual XLA compilation rule:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "eOcNR92SI2h-",
- "outputId": "356ef229-3703-4696-cc3d-7c05de405fb0"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "\n",
- "Found expected exception:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Traceback (most recent call last):\n",
- " File \"\", line 2, in \n",
- " api.jit(square_add_prim)(2., 10.)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n",
- " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n",
- " outs = primitive.impl(f, *args, **params)\n",
- "NotImplementedError: XLA translation rule for primitive 'multiply_add' not found\n"
- ]
- }
- ],
- "source": [
- "with expectNotImplementedError():\n",
- " api.jit(square_add_prim)(2., 10.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "9IOV1R-fJMHp"
- },
- "source": [
- "#### XLA Compilation rules\n",
- "\n",
- "JAX compilation works by compiling each primitive into a graph of XLA operations.\n",
- "\n",
- "This is the biggest hurdle to adding new functionality to JAX, because the \n",
- "set of XLA operations is limited, and JAX already has pre-defined primitives\n",
- "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "FYQWSSjKJaWP"
- },
- "outputs": [],
- "source": [
- "from jax._src.lib.mlir.dialects import hlo\n",
- "@trace(\"multiply_add_lowering\")\n",
- "def multiply_add_lowering(ctx, xc, yc, zc):\n",
- " \"\"\"The compilation to XLA of the primitive.\n",
- "\n",
- " Given an mlir.ir.Value for each argument, return the mlir.ir.Values for\n",
- " the results of the function.\n",
- "\n",
- " Does not need to be a JAX-traceable function.\n",
- " \"\"\"\n",
- " return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]\n",
- "\n",
- "# Now we register the lowering rule with JAX\n",
- "# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)\n",
- "# TODO: TPU?\n",
- "from jax.interpreters import mlir\n",
- "mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "K98LX-VaJkFu"
- },
- "source": [
- "Now we succeed to JIT. Notice below that JAX first evaluates the function\n",
- "abstractly, which triggers the `multiply_add_abstract_eval` function, and \n",
- "then compiles the set of primitives it has encountered, including `multiply_add`.\n",
- "At this point JAX invokes `multiply_add_xla_translation`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "id": "rj3TLsolJgEc",
- "outputId": "e384bee4-1e9c-4344-f49c-d3b5ec08eb32"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n"
- ]
- }
- ],
- "source": [
- "assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Omrez-2_KFfo"
- },
- "source": [
- "Below is another use of `jit` where we compile only\n",
- "with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n",
- "in the third argument to `multiply_add_abstract_eval` being \n",
- "`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n",
- "both `ShapedArray` and `ConcreteArray`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "id": "mPfTwIBoKOEK",
- "outputId": "b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, 10.0)\n",
- " call multiply_add_prim(Traced, Traced, 10.0)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(10.0))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n"
- ]
- }
- ],
- "source": [
- "assert api.jit(lambda x, y: square_add_prim(x, y), \n",
- " static_argnums=1)(2., 10.) == 14."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "_Ya3B5l4J1VA"
- },
- "source": [
- "### Forward differentiation\n",
- "\n",
- "JAX implements forward differentiation in the form of\n",
- "a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).\n",
- "\n",
- "If we attempt now to compute the `jvp` function we get an\n",
- "error because we have not yet told JAX how to differentiate\n",
- "the `multiply_add` primitive."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "id": "OxDx6NQnKwMI",
- "outputId": "ce659ef3-c03c-4856-f252-49ec4b6eb964"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- "\n",
- "Found expected exception:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Traceback (most recent call last):\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 217, in process_primitive\n",
- " jvp = primitive_jvps[primitive]\n",
- "KeyError: multiply_add\n",
- "\n",
- "During handling of the above exception, another exception occurred:\n",
- "\n",
- "Traceback (most recent call last):\n",
- " File \"\", line 2, in \n",
- " api.jvp(square_add_prim, (2., 10.), (1., 1.))\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 978, in jvp\n",
- " out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\", line 165, in call_wrapped\n",
- " ans = self.f(*args, **dict(self.params, **kwargs))\n",
- "NotImplementedError: Forward-mode differentiation rule for 'multiply_add' not implemented\n"
- ]
- }
- ],
- "source": [
- "# The second argument `(2., 10.)` are the argument values\n",
- "# where we evaluate the Jacobian, and the third `(1., 1.)`\n",
- "# are the values of the tangents for the arguments.\n",
- "with expectNotImplementedError():\n",
- " api.jvp(square_add_prim, (2., 10.), (1., 1.))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "zxG24C1JMIMM"
- },
- "outputs": [],
- "source": [
- "from jax.interpreters import ad\n",
- "\n",
- "\n",
- "@trace(\"multiply_add_value_and_jvp\")\n",
- "def multiply_add_value_and_jvp(arg_values, arg_tangents):\n",
- " \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n",
- "\n",
- " Given values of the arguments and perturbation of the arguments (tangents), \n",
- " compute the output of the primitive and the perturbation of the output.\n",
- "\n",
- " This method must be JAX-traceable. JAX may invoke it with abstract values \n",
- " for the arguments and tangents.\n",
- "\n",
- " Args:\n",
- " arg_values: a tuple of arguments\n",
- " arg_tangents: a tuple with the tangents of the arguments. The tuple has \n",
- " the same length as the arg_values. Some of the tangents may also be the \n",
- " special value ad.Zero to specify a zero tangent.\n",
- " Returns:\n",
- " a pair of the primal output and the tangent.\n",
- " \"\"\"\n",
- " x, y, z = arg_values\n",
- " xt, yt, zt = arg_tangents\n",
- " _trace(\"Primal evaluation:\")\n",
- " # Now we have a JAX-traceable computation of the output. \n",
- " # Normally, we can use the ma primitive itself to compute the primal output. \n",
- " primal_out = multiply_add_prim(x, y, z)\n",
- " \n",
- " _trace(\"Tangent evaluation:\")\n",
- " # We must use a JAX-traceable way to compute the tangent. It turns out that \n",
- " # the output tangent can be computed as (xt * y + x * yt + zt),\n",
- " # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n",
- " \n",
- " # We do need to deal specially with Zero. Here we just turn it into a \n",
- " # proper tensor of 0s (of the same shape as 'x'). \n",
- " # An alternative would be to check for Zero and perform algebraic \n",
- " # simplification of the output tangent computation.\n",
- " def make_zero(tan):\n",
- " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n",
- " \n",
- " output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n",
- " return (primal_out, output_tangent)\n",
- "\n",
- "# Register the forward differentiation rule with JAX \n",
- "ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "id": "ma3KBkiAMfW1",
- "outputId": "f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))\n",
- " Primal evaluation:\n",
- " call multiply_add_prim(2.0, 2.0, 10.0)\n",
- " call multiply_add_impl(2.0, 2.0, 10.0)\n",
- " |<- multiply_add_impl = 14.0\n",
- " |<- multiply_add_prim = 14.0\n",
- " Tangent evaluation:\n",
- " call multiply_add_prim(2.0, 1.0, 1.0)\n",
- " call multiply_add_impl(2.0, 1.0, 1.0)\n",
- " |<- multiply_add_impl = 3.0\n",
- " |<- multiply_add_prim = 3.0\n",
- " call multiply_add_prim(1.0, 2.0, 3.0)\n",
- " call multiply_add_impl(1.0, 2.0, 3.0)\n",
- " |<- multiply_add_impl = 5.0\n",
- " |<- multiply_add_prim = 5.0\n",
- " |<- multiply_add_value_and_jvp = (14.0, 5.0)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n"
- ]
- }
- ],
- "source": [
- "# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.\n",
- "assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "69QsEcu-lP4u"
- },
- "source": [
- "TO EXPLAIN: \n",
- "\n",
- " * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.\n",
- " * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n",
- " we do not call the multiply_add_abstract_eval.\n",
- " * I think it would be useful to show the jaxpr here"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Sb6e3ZAHOPHv"
- },
- "source": [
- "#### JIT of forward differentiation\n",
- "\n",
- "We can apply JIT to the forward differentiation function:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "id": "hg-hzVu-N-hv",
- "outputId": "38d32067-e152-4046-ad80-7f95a31ba628"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Traced))\n",
- " Primal evaluation:\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " Tangent evaluation:\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " |<- multiply_add_value_and_jvp = (Traced, Traced)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n"
- ]
- }
- ],
- "source": [
- "assert api.jit(lambda arg_values, arg_tangents: \n",
- " api.jvp(square_add_prim, arg_values, arg_tangents))(\n",
- " (2., 10.), (1., 1.)) == (14., 5.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "jlZt1_v2mU88"
- },
- "source": [
- "Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n",
- "evaluates abstractly both the primal and the tangent evaluation (a total of \n",
- "3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n",
- "of the primitive."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "555yt6ZIOePB"
- },
- "source": [
- "### Reverse differentiation\n",
- "\n",
- "If we attempt now to use reverse differentiation we\n",
- "see that JAX starts by using the `multiply_add_value_and_jvp` to \n",
- "compute the forward differentiation for abstract values, but then runs\n",
- "into a `NotImplementedError`. \n",
- "\n",
- "When computing the reverse differentiation JAX first does abstract evaluation\n",
- "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n",
- "trace of primitives that compute the output tangent. \n",
- "Observe that JAX performs this abstract evaluation with concrete values\n",
- "for the differentiation point, and abstract values for the tangents. \n",
- "Observe also that JAX uses the special abstract tangent value `Zero` for\n",
- "the tangent corresponding to the 3rd argument of `ma`. This reflects the \n",
- "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n",
- "which flows to the 3rd argument to `multiply_add_prim`.\n",
- "\n",
- "Observe also that during the abstract evaluation of the tangent we pass the \n",
- "value 0.0 as the tangent for the 3rd argument. This is due to the use\n",
- "of the `make_zero` function in the definition of `multiply_add_value_and_jvp`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "id": "8eAVnexaOjBn",
- "outputId": "e4ee89cf-ab4a-4505-9817-fa978a2865ab"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, 10.0)\n",
- " call multiply_add_prim(Traced, Traced, 10.0)\n",
- " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n",
- " Primal evaluation:\n",
- " call multiply_add_prim(Traced, Traced, 10.0)\n",
- " call multiply_add_impl(2.0, 2.0, 10.0)\n",
- " |<- multiply_add_impl = 14.0\n",
- " |<- multiply_add_prim = 14.0\n",
- " Tangent evaluation:\n",
- " call multiply_add_prim(Traced, Traced, 0.0)\n",
- " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " |<- multiply_add_value_and_jvp = (14.0, Traced)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "\n",
- "Found expected exception:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Traceback (most recent call last):\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 198, in get_primitive_transpose\n",
- " return primitive_transposes[p]\n",
- "KeyError: multiply_add\n",
- "\n",
- "During handling of the above exception, another exception occurred:\n",
- "\n",
- "Traceback (most recent call last):\n",
- " File \"\", line 2, in \n",
- " api.grad(square_add_prim)(2., 10.)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n",
- " _, g = value_and_grad_f(*args, **kwargs)\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n",
- " g = vjp_py(np.ones((), dtype=dtype))\n",
- "NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not implemented\n"
- ]
- }
- ],
- "source": [
- "# This is reverse differentiation w.r.t. the first argument of square_add_prim\n",
- "with expectNotImplementedError():\n",
- " api.grad(square_add_prim)(2., 10.)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "fSHLUMDN26AY"
- },
- "source": [
- "The above error is because there is a missing piece for JAX to be able\n",
- "to use the forward differentiation code to compute reverse differentiation."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "3ibDbGF-PjK9"
- },
- "source": [
- "#### Transposition\n",
- "\n",
- "\n",
- "As explained above, when computing reverse differentiation JAX obtains\n",
- "a trace of primitives that compute the tangent using forward differentiation.\n",
- "Then, **JAX interprets this trace abstractly backwards** and for each \n",
- "primitive it applies a **transposition** rule.\n",
- "\n",
- "To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n",
- "```\n",
- " a = xt * 4.\n",
- " b = 2. * yt\n",
- " c = a + b\n",
- " ft = c + yt\n",
- "```\n",
- "\n",
- "By construction, the tangent calculation is always linear in the input tangents. \n",
- "The only non-linear operator that may arise in the tangent calculation is multiplication,\n",
- "but then one of the operands is constant.\n",
- "\n",
- "JAX will produce the reverse differentiation computation by processing the\n",
- "JVP computation backwards. For each operation in the tangent computation,\n",
- "it accumulates the cotangents\n",
- "of the variables used by the operation, using the cotangent of the result\n",
- "of the operation:\n",
- "```\n",
- " # Initialize cotangents of inputs and intermediate vars\n",
- " xct = yct = act = bct = cct = 0.\n",
- " # Initialize cotangent of the output\n",
- " fct = 1.\n",
- " # Process \"ft = c + yt\"\n",
- " cct += fct\n",
- " yct += fct\n",
- " # Process \"c = a + b\"\n",
- " act += cct\n",
- " bct += cct\n",
- " # Process \"b = 2. * yt\"\n",
- " yct += 2. * bct\n",
- " # Process \"a = xt * 4.\"\n",
- " xct += act * 4.\n",
- "```\n",
- "\n",
- "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n",
- "are the partial derivatives of the function `f`. \n",
- "\n",
- "JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n",
- "```\n",
- "p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)\n",
- "```\n",
- "\n",
- "Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n",
- "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n",
- "for the constant arguments.\n",
- "\n",
- "In particular, \n",
- "```\n",
- " add_transpose(out_ct, _, _) = (out_ct, out_ct)\n",
- " mult_transpose(out_ct, x, _) = (None, x * out_ct)\n",
- " mult_transpose(out_ct, _, y) = (out_ct * y, None)\n",
- "```"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "JaHxFdkRO42r"
- },
- "outputs": [],
- "source": [
- "@trace(\"multiply_add_transpose\")\n",
- "def multiply_add_transpose(ct, x, y, z):\n",
- " \"\"\"Evaluates the transpose of a linear primitive.\n",
- "\n",
- " This method is only used when computing the backward gradient following \n",
- " value_and_jvp, and is only needed for primitives that are used in the JVP \n",
- " calculation for some other primitive. We need transposition for multiply_add_prim, \n",
- " because we have used multiply_add_prim in the computation of the output_tangent in \n",
- " multiply_add_value_and_jvp.\n",
- "\n",
- " In our case, multiply_add is not a linear primitive. However, it is used linearly \n",
- " w.r.t. tangents in multiply_add_value_and_jvp:\n",
- " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n",
- " \n",
- " Always one of the first two multiplicative arguments is a constant.\n",
- "\n",
- " Args:\n",
- " ct: the cotangent of the output of the primitive.\n",
- " x, y, z: values of the arguments. The arguments that are used linearly\n",
- " get an ad.UndefinedPrimal value. The other arguments get a constant\n",
- " value.\n",
- " Returns:\n",
- " a tuple with the cotangent of the inputs, with the value None\n",
- " corresponding to the constant arguments.\n",
- " \"\"\"\n",
- " if not ad.is_undefined_primal(x):\n",
- " # This use of multiply_add is with a constant \"x\"\n",
- " assert ad.is_undefined_primal(y)\n",
- " ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))\n",
- " res = None, ct_y, ct\n",
- " else:\n",
- " # This use of multiply_add is with a constant \"y\"\n",
- " assert ad.is_undefined_primal(x)\n",
- " ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))\n",
- " res = ct_x, None, ct\n",
- " return res\n",
- "\n",
- "\n",
- "ad.primitive_transposes[multiply_add_p] = multiply_add_transpose"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PpChox-Jp7wb"
- },
- "source": [
- "Now we can complete the run of the `grad`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 20,
- "metadata": {
- "id": "PogPKS4MPevd",
- "outputId": "d33328d4-3e87-45b5-9b31-21ad624b67af"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, 10.0)\n",
- " call multiply_add_prim(Traced, Traced, 10.0)\n",
- " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n",
- " Primal evaluation:\n",
- " call multiply_add_prim(Traced, Traced, 10.0)\n",
- " call multiply_add_impl(2.0, 2.0, 10.0)\n",
- " |<- multiply_add_impl = 14.0\n",
- " |<- multiply_add_prim = 14.0\n",
- " Tangent evaluation:\n",
- " call multiply_add_prim(Traced, Traced, 0.0)\n",
- " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " |<- multiply_add_value_and_jvp = (14.0, Traced)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "call multiply_add_transpose(1.0, _, 2.0, _)\n",
- " call multiply_add_prim(1.0, 2.0, 0.0)\n",
- " call multiply_add_impl(1.0, 2.0, 0.0)\n",
- " |<- multiply_add_impl = 2.0\n",
- " |<- multiply_add_prim = 2.0\n",
- "|<- multiply_add_transpose = (2.0, None, 1.0)\n",
- "call multiply_add_transpose(1.0, 2.0, _, 0.0)\n",
- " call multiply_add_prim(2.0, 1.0, 0.0)\n",
- " call multiply_add_impl(2.0, 1.0, 0.0)\n",
- " |<- multiply_add_impl = 2.0\n",
- " |<- multiply_add_prim = 2.0\n",
- "|<- multiply_add_transpose = (None, 2.0, 1.0)\n"
- ]
- }
- ],
- "source": [
- "assert api.grad(square_add_prim)(2., 10.) == 4."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "8M1xLCXW4fK7"
- },
- "source": [
- "Notice the two calls to `multiply_add_transpose`. They correspond to the two\n",
- "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n",
- "last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "EIJs6FYmPg6c"
- },
- "source": [
- "#### JIT of reverse differentiation \n",
- "\n",
- "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n",
- "abstract values, while in the absence of JIT we used `ConcreteArray`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {
- "id": "FZ-JGbWZPq2-",
- "outputId": "e42b5222-9c3e-4853-e13a-874f6605d178"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Zero))\n",
- " Primal evaluation:\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " Tangent evaluation:\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(0.0))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- " |<- multiply_add_value_and_jvp = (Traced, Traced)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "call multiply_add_transpose(1.0, _, Traced, _)\n",
- " call multiply_add_prim(1.0, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ConcreteArray(1.0), ShapedArray(float32[]), ConcreteArray(0.0))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- multiply_add_transpose = (Traced, None, 1.0)\n",
- "call multiply_add_transpose(1.0, Traced, _, Traced)\n",
- " call multiply_add_prim(Traced, 1.0, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(1.0), ConcreteArray(0.0))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- multiply_add_transpose = (None, Traced, 1.0)\n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n"
- ]
- }
- ],
- "source": [
- "assert api.jit(api.grad(square_add_prim))(2., 10.) == 4."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-3lqPkdQPvl5"
- },
- "source": [
- "### Batching\n",
- "\n",
- "The batching transformation takes a point-wise computation and turns it\n",
- "into a computation on vectors. If we try it right now, we get a `NotImplementedError`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {
- "id": "hFvBR3I9Pzh3",
- "outputId": "434608bc-281f-4d3b-83bd-eaaf3b51b1cd"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- "\n",
- "Found expected exception:\n"
- ]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "Traceback (most recent call last):\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 163, in get_primitive_batcher\n",
- " return primitive_batchers[p]\n",
- "KeyError: multiply_add\n",
- "\n",
- "During handling of the above exception, another exception occurred:\n",
- "\n",
- "Traceback (most recent call last):\n",
- " File \"\", line 3, in \n",
- " np.array([10., 20.]))\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n",
- " lambda: _flatten_axes(out_tree(), out_axes))\n",
- " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n",
- " out_vals, out_dims = batch2(fun, in_vals, in_dims)\n",
- "NotImplementedError: Batching rule for 'multiply_add' not implemented\n"
- ]
- }
- ],
- "source": [
- "# The arguments are two vectors instead of two scalars\n",
- "with expectNotImplementedError():\n",
- " api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n",
- " np.array([10., 20.]))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gILasMiP6elR"
- },
- "source": [
- "We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "KQfeqRIrP7zg"
- },
- "outputs": [],
- "source": [
- "from jax.interpreters import batching\n",
- "\n",
- "\n",
- "@trace(\"multiply_add_batch\")\n",
- "def multiply_add_batch(vector_arg_values, batch_axes):\n",
- " \"\"\"Computes the batched version of the primitive.\n",
- " \n",
- " This must be a JAX-traceable function.\n",
- " \n",
- " Since the multiply_add primitive already operates pointwise on arbitrary\n",
- " dimension tensors, to batch it we can use the primitive itself. This works as\n",
- " long as both the inputs have the same dimensions and are batched along the\n",
- " same axes. The result is batched along the axis that the inputs are batched.\n",
- " \n",
- " Args:\n",
- " vector_arg_values: a tuple of two arguments, each being a tensor of matching\n",
- " shape.\n",
- " batch_axes: the axes that are being batched. See vmap documentation.\n",
- " Returns:\n",
- " a tuple of the result, and the result axis that was batched. \n",
- " \"\"\"\n",
- " assert batch_axes[0] == batch_axes[1]\n",
- " assert batch_axes[0] == batch_axes[2]\n",
- " _trace(\"Using multiply_add to compute the batch:\")\n",
- " res = multiply_add_prim(*vector_arg_values)\n",
- " return res, batch_axes[0]\n",
- "\n",
- "\n",
- "batching.primitive_batchers[multiply_add_p] = multiply_add_batch"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {
- "id": "VwxNk869P_YG",
- "outputId": "9d22c921-5803-4d33-9e88-b6e439ba9738"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))\n",
- " Using multiply_add to compute the batch:\n",
- " call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])\n",
- " call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])\n",
- " |<- multiply_add_impl = [14. 29.]\n",
- " |<- multiply_add_prim = [14. 29.]\n",
- " |<- multiply_add_batch = ([14. 29.], 0)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n"
- ]
- }
- ],
- "source": [
- "assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n",
- " np.array([2., 3.]),\n",
- " np.array([10., 20.])),\n",
- " [14., 29.])"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "NmqLlV1TQDCC"
- },
- "source": [
- "#### JIT of batching"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {
- "id": "xqEdXVUgQCTt",
- "outputId": "9c22fd9c-919c-491d-bbeb-32c241b808fa"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "call square_add_prim(Traced, Traced)\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_batch((Traced, Traced, Traced), (0, 0, 0))\n",
- " Using multiply_add to compute the batch:\n",
- " call multiply_add_prim(Traced, Traced, Traced)\n",
- " call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))\n",
- " |<- multiply_add_abstract_eval = ShapedArray(float32[2])\n",
- " |<- multiply_add_prim = Traced\n",
- " |<- multiply_add_batch = (Traced, 0)\n",
- " |<- multiply_add_prim = Traced\n",
- "|<- square_add_prim = Traced\n",
- "call multiply_add_xla_translation(, , , )\n",
- "|<- multiply_add_xla_translation = \n"
- ]
- }
- ],
- "source": [
- "assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n",
- " (np.array([2., 3.]),\n",
- " np.array([10., 20.])),\n",
- " [14., 29.])"
- ]
- }
- ],
- "metadata": {
- "colab": {
- "collapsed_sections": [],
- "name": "How JAX primitives work.ipynb",
- "provenance": [],
- "toc_visible": true
- },
- "jupytext": {
- "formats": "ipynb,md:myst"
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.6"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md
deleted file mode 100644
index 0ebf202f2258..000000000000
--- a/docs/notebooks/How_JAX_primitives_work.md
+++ /dev/null
@@ -1,771 +0,0 @@
----
-jupytext:
- formats: ipynb,md:myst
- text_representation:
- extension: .md
- format_name: myst
- format_version: 0.13
- jupytext_version: 1.16.1
-kernelspec:
- display_name: Python 3
- name: python3
----
-
-+++ {"id": "vfxqky4PCUnh"}
-
-# How JAX primitives work
-
-
-
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)
-
-*necula@google.com*, October 2019.
-
-JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,
-`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable,
-which means that as the Python function executes
-the only operations it applies to the data are either inspections of data
-attributes such as shape or type, or special operations called JAX primitives.
-In particular, a JAX-traceable function is sometimes invoked by JAX with
-abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`,
-which captures the type and the shape of values, but not the concrete data values.
-JAX primitives know how to operate on both concrete data
-values and on the JAX abstract values.
-
-
-The JAX-transformed functions must themselves be JAX-traceable functions,
-to ensure that these transformations
-can be composed, e.g., `jit(jacfwd(grad(f)))`.
-
-There are pre-defined JAX primitives corresponding to most XLA operations,
-e.g., add, matmul, sin, cos, indexing.
-JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs
-using JAX’s implementation of numpy are JAX-traceable and therefore transformable.
-Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.
-
-The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,
-one can define a new primitive that encapsulates the behavior of the function.
-
-**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**
-
-Suppose that we want to add support to JAX for a multiply-add function with three arguments, defined mathematically
-as "multiply_add(x, y, z) = x * y + z".
-This function operates on 3 identically-shaped tensors of floating point
-values and performs the operations pointwise.
-
-+++ {"id": "HIJYIHNTD1yI"}
-
-## Using existing primitives
-
-The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other
-functions that are themselves written using JAX primitives, e.g., those
-defined in the `jax.lax` module:
-
-```{code-cell} ipython3
-:id: tbOF0LB0EMne
-:outputId: 3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528
-
-from jax import lax
-from jax._src import api
-
-def multiply_add_lax(x, y, z):
- """Implementation of multiply-add using the jax.lax primitives."""
- return lax.add(lax.mul(x, y), z)
-
-
-def square_add_lax(a, b):
- """A square-add function using the newly defined multiply-add."""
- return multiply_add_lax(a, a, b)
-
-print("square_add_lax = ", square_add_lax(2., 10.))
-# Differentiate w.r.t. the first argument
-print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.))
-```
-
-+++ {"id": "Cgv60Wm3E_D5"}
-
-In order to understand how JAX uses the primitives internally,
-we add some helpers for tracing function calls.
-
-```{code-cell} ipython3
-:cellView: form
-:id: mQRQGEGiE53K
-
-#@title Helper functions (execute this cell)
-import functools
-import traceback
-
-_indentation = 0
-def _trace(msg=None):
- """Print a message at current indentation."""
- if msg is not None:
- print(" " * _indentation + msg)
-
-def _trace_indent(msg=None):
- """Print a message and then indent the rest."""
- global _indentation
- _trace(msg)
- _indentation = 1 + _indentation
-
-def _trace_unindent(msg=None):
- """Unindent then print a message."""
- global _indentation
- _indentation = _indentation - 1
- _trace(msg)
-
-def trace(name):
- """A decorator for functions to trace arguments and results."""
-
- def trace_func(func): # pylint: disable=missing-docstring
- def pp(v):
- """Print certain values more succinctly"""
- vtype = str(type(v))
- if "jax._src.xla_bridge._JaxComputationBuilder" in vtype:
- return ""
- elif "jaxlib.xla_extension.XlaOp" in vtype:
- return "".format(id(v))
- elif ("partial_eval.JaxprTracer" in vtype or
- "batching.BatchTracer" in vtype or
- "ad.JVPTracer" in vtype):
- return "Traced<{}>".format(v.aval)
- elif isinstance(v, tuple):
- return "({})".format(pp_values(v))
- else:
- return str(v)
- def pp_values(args):
- return ", ".join([pp(arg) for arg in args])
-
- @functools.wraps(func)
- def func_wrapper(*args):
- _trace_indent("call {}({})".format(name, pp_values(args)))
- res = func(*args)
- _trace_unindent("|<- {} = {}".format(name, pp(res)))
- return res
-
- return func_wrapper
-
- return trace_func
-
-class expectNotImplementedError(object):
- """Context manager to check for NotImplementedError."""
- def __enter__(self): pass
- def __exit__(self, type, value, tb):
- global _indentation
- _indentation = 0
- if type is NotImplementedError:
- print("\nFound expected exception:")
- traceback.print_exc(limit=3)
- return True
- elif type is None: # No exception
- assert False, "Expected NotImplementedError"
- else:
- return False
-```
-
-+++ {"id": "Qf4eLrLCFYDl"}
-
-Instead of using `jax.lax` primitives directly, we can use other functions
-that are already written in terms of those primitives, such as those in `jax.numpy`:
-
-```{code-cell} ipython3
-:id: QhKorz6cFRJb
-:outputId: aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a
-
-import jax.numpy as jnp
-import numpy as np
-
-@trace("multiply_add_numpy")
-def multiply_add_numpy(x, y, z):
- return jnp.add(jnp.multiply(x, y), z)
-
-@trace("square_add_numpy")
-def square_add_numpy(a, b):
- return multiply_add_numpy(a, a, b)
-
-print("\nNormal evaluation:")
-print("square_add_numpy = ", square_add_numpy(2., 10.))
-print("\nGradient evaluation:")
-print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.))
-```
-
-+++ {"id": "Sg-D8EdeFn4a"}
-
-Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and
-`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further
-below in this colab).
-It is important to remember that a JAX-traceable function must be able to
-operate not only on concrete arguments but also on special abstract arguments
-that JAX may use to abstract the function execution.
-
-The JAX traceability property is satisfied as long as the function is written
-in terms of JAX primitives.
-
-+++ {"id": "WxrQO7-XGLcg"}
-
-## Defining new JAX primitives
-
-The right way to add support for multiply-add is in terms of existing
-JAX primitives, as shown above. However, in order to demonstrate how JAX
-primitives work let us pretend that we want to add a new primitive to
-JAX for the multiply-add functionality.
-
-```{code-cell} ipython3
-:id: cPqAH1XOGTN4
-
-from jax import core
-multiply_add_p = core.Primitive("multiply_add") # Create the primitive
-
-@trace("multiply_add_prim")
-def multiply_add_prim(x, y, z):
- """The JAX-traceable way to use the JAX primitive.
-
- Note that the traced arguments must be passed as positional arguments
- to `bind`.
- """
- return multiply_add_p.bind(x, y, z)
-
-@trace("square_add_prim")
-def square_add_prim(a, b):
- """A square-add function implemented using the new JAX-primitive."""
- return multiply_add_prim(a, a, b)
-```
-
-+++ {"id": "LMzs5PAKGr-4"}
-
-If we try to call the newly defined functions we get an error, because
-we have not yet told JAX anything about the semantics of the new primitive.
-
-```{code-cell} ipython3
-:id: _X3PAYxhGpWd
-:outputId: 90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8
-
-with expectNotImplementedError():
- square_add_prim(2., 10.)
-```
-
-+++ {"id": "elha0FdgHSEF"}
-
-### Primal evaluation rules
-
-```{code-cell} ipython3
-:id: FT34FFAGHARU
-:outputId: 4c54f1c2-8a50-4788-90e1-06aee412c43b
-
-@trace("multiply_add_impl")
-def multiply_add_impl(x, y, z):
- """Concrete implementation of the primitive.
-
- This function does not need to be JAX traceable.
- Args:
- x, y, z: the concrete arguments of the primitive. Will only be called with
- concrete values.
- Returns:
- the concrete result of the primitive.
- """
- # Note that we can use the original numpy, which is not JAX traceable
- return np.add(np.multiply(x, y), z)
-
-# Now we register the primal implementation with JAX
-multiply_add_p.def_impl(multiply_add_impl)
-```
-
-```{code-cell} ipython3
-:id: G5bstKaeNAVV
-:outputId: deb94d5b-dfea-4e6f-9ec2-70b416c996c5
-
-assert square_add_prim(2., 10.) == 14.
-```
-
-+++ {"id": "upBf-uAuHhPJ"}
-
-### JIT
-
-If we now try to use `jit` we get a `NotImplementedError`:
-
-```{code-cell} ipython3
-:id: QG-LULjiHk4b
-:outputId: d4ef4406-8dae-4c96-97ca-b662340474ee
-
-with expectNotImplementedError():
- api.jit(square_add_prim)(2., 10.)
-```
-
-+++ {"id": "rHS1bAGHH44E"}
-
-#### Abstract evaluation rules
-
-In order to JIT the function, and for other transformations as well,
-JAX first evaluates it abstractly using only the
-shape and type of the arguments. This abstract evaluation serves multiple
-purposes:
-
- * Gets the sequence of JAX primitives that are used in the computation. This
- sequence will be compiled.
- * Computes the shape and type of all vectors and operations used in the computation.
-
-
-For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`.
-In the latter case, JAX uses the actual concrete value wrapped as an abstract value.
-
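-As a small aside, `jax.make_jaxpr` performs exactly this kind of abstract evaluation and
-records the sequence of primitives it encounters. A minimal sketch, assuming the cells above
-have been executed (we use `square_add_lax` from the beginning of the notebook, since our new
-primitive does not yet have an abstract evaluation rule registered):
-
-```python
-# Abstract evaluation in action: trace square_add_lax and print the recorded
-# sequence of primitives (a jaxpr) without running any concrete computation.
-import jax
-
-print(jax.make_jaxpr(square_add_lax)(2., 10.))
-# Prints something like:
-#   { lambda ; a:f32[] b:f32[]. let c:f32[] = mul a a; d:f32[] = add c b in (d,) }
-```
-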
-```{code-cell} ipython3
-:id: ctQmEeckIbdo
-:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2
-
-from jax import core
-@trace("multiply_add_abstract_eval")
-def multiply_add_abstract_eval(xs, ys, zs):
- """Abstract evaluation of the primitive.
-
- This function does not need to be JAX traceable. It will be invoked with
- abstractions of the actual arguments.
- Args:
- xs, ys, zs: abstractions of the arguments.
- Result:
- a ShapedArray for the result of the primitive.
- """
- assert xs.shape == ys.shape
- assert xs.shape == zs.shape
- return core.ShapedArray(xs.shape, xs.dtype)
-
-# Now we register the abstract evaluation with JAX
-multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)
-```
-
-+++ {"id": "RPN88X6YI43A"}
-
-If we re-attempt to JIT, we see how the abstract evaluation proceeds, but
-we get another error, about missing the actual XLA compilation rule:
-
-```{code-cell} ipython3
-:id: eOcNR92SI2h-
-:outputId: 356ef229-3703-4696-cc3d-7c05de405fb0
-
-with expectNotImplementedError():
- api.jit(square_add_prim)(2., 10.)
-```
-
-+++ {"id": "9IOV1R-fJMHp"}
-
-#### XLA Compilation rules
-
-JAX compilation works by compiling each primitive into a graph of XLA operations.
-
-This is the biggest hurdle to adding new functionality to JAX, because the
-set of XLA operations is limited, and JAX already has pre-defined primitives
-for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
-
-```{code-cell} ipython3
-:id: FYQWSSjKJaWP
-
-from jax._src.lib.mlir.dialects import hlo
-@trace("multiply_add_lowering")
-def multiply_add_lowering(ctx, xc, yc, zc):
- """The compilation to XLA of the primitive.
-
- Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
- the results of the function.
-
- Does not need to be a JAX-traceable function.
- """
- return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]
-
-# Now we register the lowering rule with JAX
-# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
-# TODO: TPU?
-from jax.interpreters import mlir
-mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')
-```
-
-+++ {"id": "K98LX-VaJkFu"}
-
-Now we can JIT successfully. Notice below that JAX first evaluates the function
-abstractly, which triggers the `multiply_add_abstract_eval` function, and
-then compiles the set of primitives it has encountered, including `multiply_add`.
-At this point JAX invokes the lowering rule `multiply_add_lowering`.
-
-```{code-cell} ipython3
-:id: rj3TLsolJgEc
-:outputId: e384bee4-1e9c-4344-f49c-d3b5ec08eb32
-
-assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14.
-```
-
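-As a quick way to see what the compiler receives, we can also inspect the lowered module.
-This is a minimal sketch, assuming a JAX version that provides the ahead-of-time
-`jax.jit(...).lower(...)` API; the printed text contains the multiply and add operations
-emitted by `multiply_add_lowering`:
-
-```python
-# Lower (but do not execute) the jitted function and print the resulting module.
-import jax
-
-lowered = jax.jit(lambda x, y: square_add_prim(x, y)).lower(2., 10.)
-print(lowered.as_text())
-```
-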
-+++ {"id": "Omrez-2_KFfo"}
-
-Below is another use of `jit`, where we compile only
-with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads
-to the third argument of `multiply_add_abstract_eval` being a
-`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with
-both `ShapedArray` and `ConcreteArray`.
-
-```{code-cell} ipython3
-:id: mPfTwIBoKOEK
-:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b
-
-assert api.jit(lambda x, y: square_add_prim(x, y),
- static_argnums=1)(2., 10.) == 14.
-```
-
-+++ {"id": "_Ya3B5l4J1VA"}
-
-### Forward differentiation
-
-JAX implements forward differentiation in the form of
-a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).
-
-If we attempt now to compute the `jvp` function we get an
-error because we have not yet told JAX how to differentiate
-the `multiply_add` primitive.
-
-```{code-cell} ipython3
-:id: OxDx6NQnKwMI
-:outputId: ce659ef3-c03c-4856-f252-49ec4b6eb964
-
-# The second argument `(2., 10.)` gives the argument values
-# at which we evaluate the Jacobian, and the third `(1., 1.)`
-# gives the tangent values for the arguments.
-with expectNotImplementedError():
- api.jvp(square_add_prim, (2., 10.), (1., 1.))
-```
-
-```{code-cell} ipython3
-:id: zxG24C1JMIMM
-
-from jax.interpreters import ad
-
-
-@trace("multiply_add_value_and_jvp")
-def multiply_add_value_and_jvp(arg_values, arg_tangents):
- """Evaluates the primal output and the tangents (Jacobian-vector product).
-
- Given values of the arguments and perturbation of the arguments (tangents),
- compute the output of the primitive and the perturbation of the output.
-
- This method must be JAX-traceable. JAX may invoke it with abstract values
- for the arguments and tangents.
-
- Args:
- arg_values: a tuple of arguments
- arg_tangents: a tuple with the tangents of the arguments. The tuple has
- the same length as the arg_values. Some of the tangents may also be the
- special value ad.Zero to specify a zero tangent.
- Returns:
- a pair of the primal output and the tangent.
- """
- x, y, z = arg_values
- xt, yt, zt = arg_tangents
- _trace("Primal evaluation:")
- # Now we have a JAX-traceable computation of the output.
- # Normally, we can use the ma primitive itself to compute the primal output.
- primal_out = multiply_add_prim(x, y, z)
-
- _trace("Tangent evaluation:")
- # We must use a JAX-traceable way to compute the tangent. It turns out that
- # the output tangent can be computed as (xt * y + x * yt + zt),
- # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive.
-
- # We do need to deal specially with Zero. Here we just turn it into a
- # proper tensor of 0s (of the same shape as 'x').
- # An alternative would be to check for Zero and perform algebraic
- # simplification of the output tangent computation.
- def make_zero(tan):
- return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan
-
- output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))
- return (primal_out, output_tangent)
-
-# Register the forward differentiation rule with JAX
-ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp
-```
-
-```{code-cell} ipython3
-:id: ma3KBkiAMfW1
-:outputId: f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd
-
-# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.
-assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)
-```
-
-+++ {"id": "69QsEcu-lP4u"}
-
-TO EXPLAIN:
-
- * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here.
- * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet
- we do not call the multiply_add_abstract_eval.
- * I think it would be useful to show the jaxpr here
-
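-Following up on the last note above, a minimal sketch (assuming the cells above have run) is to
-look at the jaxpr of the JVP computation, which makes the three uses of the `multiply_add`
-primitive visible:
-
-```python
-# The jaxpr of the JVP: one multiply_add for the primal and two for the tangent,
-# exactly as constructed in multiply_add_value_and_jvp.
-import jax
-
-print(jax.make_jaxpr(
-    lambda x, y: api.jvp(square_add_prim, (x, y), (1., 1.)))(2., 10.))
-```
-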
-+++ {"id": "Sb6e3ZAHOPHv"}
-
-#### JIT of forward differentiation
-
-We can apply JIT to the forward differentiation function:
-
-```{code-cell} ipython3
-:id: hg-hzVu-N-hv
-:outputId: 38d32067-e152-4046-ad80-7f95a31ba628
-
-assert api.jit(lambda arg_values, arg_tangents:
- api.jvp(square_add_prim, arg_values, arg_tangents))(
- (2., 10.), (1., 1.)) == (14., 5.)
-```
-
-+++ {"id": "jlZt1_v2mU88"}
-
-Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn
-abstractly evaluates both the primal and the tangent computation (a total of
-3 invocations of the `ma` primitive). Then we compile the 3 occurrences
-of the primitive.
-
-+++ {"id": "555yt6ZIOePB"}
-
-### Reverse differentiation
-
-If we now attempt to use reverse differentiation, we
-see that JAX starts by using `multiply_add_value_and_jvp` to
-compute the forward differentiation for abstract values, but then runs
-into a `NotImplementedError`.
-
-When computing the reverse differentiation JAX first does abstract evaluation
-of the forward differentiation code `multiply_add_value_and_jvp` to obtain a
-trace of primitives that compute the output tangent.
-Observe that JAX performs this abstract evaluation with concrete values
-for the differentiation point, and abstract values for the tangents.
-Observe also that JAX uses the special abstract tangent value `Zero` for
-the tangent corresponding to the 3rd argument of `ma`. This reflects the
-fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
-which flows to the 3rd argument to `multiply_add_prim`.
-
-Observe also that during the abstract evaluation of the tangent we pass the
-value 0.0 as the tangent for the 3rd argument. This is due to the use
-of the `make_zero` function in the definition of `multiply_add_value_and_jvp`.
-
-```{code-cell} ipython3
-:id: 8eAVnexaOjBn
-:outputId: e4ee89cf-ab4a-4505-9817-fa978a2865ab
-
-# This is reverse differentiation w.r.t. the first argument of square_add_prim
-with expectNotImplementedError():
- api.grad(square_add_prim)(2., 10.)
-```
-
-+++ {"id": "fSHLUMDN26AY"}
-
-The above error occurs because JAX is still missing the piece it needs
-to reuse the forward differentiation code for reverse differentiation.
-
-+++ {"id": "3ibDbGF-PjK9"}
-
-#### Transposition
-
-
-As explained above, when computing reverse differentiation JAX obtains
-a trace of primitives that compute the tangent using forward differentiation.
-Then, **JAX interprets this trace abstractly backwards** and for each
-primitive it applies a **transposition** rule.
-
-To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:
-```
- a = xt * 4.
- b = 2. * yt
- c = a + b
- ft = c + yt
-```
-
-By construction, the tangent calculation is always linear in the input tangents.
-The only non-linear operator that may arise in the tangent calculation is multiplication,
-but then one of the operands is constant.
-
-JAX will produce the reverse differentiation computation by processing the
-JVP computation backwards. For each operation in the tangent computation,
-it accumulates the cotangents
-of the variables used by the operation, using the cotangent of the result
-of the operation:
-```
- # Initialize cotangents of inputs and intermediate vars
- xct = yct = act = bct = cct = 0.
- # Initialize cotangent of the output
- fct = 1.
- # Process "ft = c + yt"
- cct += fct
- yct += fct
- # Process "c = a + b"
- act += cct
- bct += cct
- # Process "b = 2. * yt"
- yct += 2. * bct
- # Process "a = xt * 4."
- xct += act * 4.
-```
-
-One can verify that this computation produces `xct = 4.` and `yct = 3.`, which
-are the partial derivatives of the function `f`.
-
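-As a quick sketch of that check (assuming only that `jax` is importable), the hand-derived
-cotangents agree with what `jax.grad` computes for `f` at the point `(2., 4.)`:
-
-```python
-# Numerical check of the cotangents derived above:
-# df/dx = y = 4. and df/dy = x + 1 = 3. at the point (2., 4.).
-import jax
-
-def f(x, y):
-    return x * y + y
-
-assert jax.grad(f, argnums=(0, 1))(2., 4.) == (4., 3.)
-```
-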
-JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:
-```
-p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)
-```
-
-Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other
-arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned
-for the constant arguments.
-
-In particular,
-```
- add_transpose(out_ct, _, _) = (out_ct, out_ct)
- mult_transpose(out_ct, x, _) = (None, x * out_ct)
- mult_transpose(out_ct, _, y) = (out_ct * y, None)
-```
-
-```{code-cell} ipython3
-:id: JaHxFdkRO42r
-
-@trace("multiply_add_transpose")
-def multiply_add_transpose(ct, x, y, z):
- """Evaluates the transpose of a linear primitive.
-
- This method is only used when computing the backward gradient following
- value_and_jvp, and is only needed for primitives that are used in the JVP
- calculation for some other primitive. We need transposition for multiply_add_prim,
- because we have used multiply_add_prim in the computation of the output_tangent in
- multiply_add_value_and_jvp.
-
- In our case, multiply_add is not a linear primitive. However, it is used linearly
- w.r.t. tangents in multiply_add_value_and_jvp:
- output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
-
- One of the first two multiplicative arguments is always a constant.
-
- Args:
- ct: the cotangent of the output of the primitive.
- x, y, z: values of the arguments. The arguments that are used linearly
- get an ad.UndefinedPrimal value. The other arguments get a constant
- value.
- Returns:
- a tuple with the cotangent of the inputs, with the value None
- corresponding to the constant arguments.
- """
- if not ad.is_undefined_primal(x):
- # This use of multiply_add is with a constant "x"
- assert ad.is_undefined_primal(y)
- ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))
- res = None, ct_y, ct
- else:
- # This use of multiply_add is with a constant "y"
- assert ad.is_undefined_primal(x)
- ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))
- res = ct_x, None, ct
- return res
-
-
-ad.primitive_transposes[multiply_add_p] = multiply_add_transpose
-```
-
-+++ {"id": "PpChox-Jp7wb"}
-
-Now we can complete the run of the `grad`:
-
-```{code-cell} ipython3
-:id: PogPKS4MPevd
-:outputId: d33328d4-3e87-45b5-9b31-21ad624b67af
-
-assert api.grad(square_add_prim)(2., 10.) == 4.
-```
-
-+++ {"id": "8M1xLCXW4fK7"}
-
-Notice the two calls to `multiply_add_transpose`. They correspond to the two
-uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the
-last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0.
-
-+++ {"id": "EIJs6FYmPg6c"}
-
-#### JIT of reverse differentiation
-
-Notice that the abstract evaluation of `multiply_add_value_and_jvp` now uses only
-abstract values, while in the absence of JIT we used `ConcreteArray`.
-
-```{code-cell} ipython3
-:id: FZ-JGbWZPq2-
-:outputId: e42b5222-9c3e-4853-e13a-874f6605d178
-
-assert api.jit(api.grad(square_add_prim))(2., 10.) == 4.
-```
-
-+++ {"id": "-3lqPkdQPvl5"}
-
-### Batching
-
-The batching transformation takes a point-wise computation and turns it
-into a computation on vectors. If we try it right now, we get a `NotImplementedError`:
-
-```{code-cell} ipython3
-:id: hFvBR3I9Pzh3
-:outputId: 434608bc-281f-4d3b-83bd-eaaf3b51b1cd
-
-# The arguments are two vectors instead of two scalars
-with expectNotImplementedError():
- api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),
- np.array([10., 20.]))
-```
-
-+++ {"id": "gILasMiP6elR"}
-
-We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, `multiply_add_prim` already operates pointwise on inputs of any shape, so the batched version can use the same `multiply_add_prim` implementation.
-
-```{code-cell} ipython3
-:id: KQfeqRIrP7zg
-
-from jax.interpreters import batching
-
-
-@trace("multiply_add_batch")
-def multiply_add_batch(vector_arg_values, batch_axes):
- """Computes the batched version of the primitive.
-
- This must be a JAX-traceable function.
-
- Since the multiply_add primitive already operates pointwise on arbitrary
- dimension tensors, to batch it we can use the primitive itself. This works as
- long as all the inputs have the same dimensions and are batched along the
- same axes. The result is batched along the axis that the inputs are batched.
-
- Args:
- vector_arg_values: a tuple of three arguments, each being a tensor of matching
- shape.
- batch_axes: the axes that are being batched. See vmap documentation.
- Returns:
- a tuple of the result, and the result axis that was batched.
- """
- assert batch_axes[0] == batch_axes[1]
- assert batch_axes[0] == batch_axes[2]
- _trace("Using multiply_add to compute the batch:")
- res = multiply_add_prim(*vector_arg_values)
- return res, batch_axes[0]
-
-
-batching.primitive_batchers[multiply_add_p] = multiply_add_batch
-```
-
-```{code-cell} ipython3
-:id: VwxNk869P_YG
-:outputId: 9d22c921-5803-4d33-9e88-b6e439ba9738
-
-assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(
- np.array([2., 3.]),
- np.array([10., 20.])),
- [14., 29.])
-```
-
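-As a short sketch (assuming the cells above have run), the jaxpr of the vmapped function
-shows that the batching rule reuses `multiply_add` directly on the whole batch:
-
-```python
-# The batched computation is a single multiply_add equation applied to the
-# full batch, exactly as multiply_add_batch constructs it.
-import jax
-
-print(jax.make_jaxpr(api.vmap(square_add_prim, in_axes=0, out_axes=0))(
-    np.array([2., 3.]), np.array([10., 20.])))
-```
-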
-+++ {"id": "NmqLlV1TQDCC"}
-
-#### JIT of batching
-
-```{code-cell} ipython3
-:id: xqEdXVUgQCTt
-:outputId: 9c22fd9c-919c-491d-bbeb-32c241b808fa
-
-assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))
- (np.array([2., 3.]),
- np.array([10., 20.])),
- [14., 29.])
-```
diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
index f0c157655790..a7ef2a017048 100644
--- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
+++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb
@@ -6,11 +6,11 @@
"id": "18AF5Ab4p6VL"
},
"source": [
- "# Training a Simple Neural Network, with PyTorch Data Loading\n",
+ "# Training a simple neural network, with PyTorch data loading\n",
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n",
"\n",
"**Copyright 2018 The JAX Authors.**\n",
"\n",
@@ -32,9 +32,9 @@
"id": "B_XlLLpcWjkA"
},
"source": [
- "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n",
+ "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
"\n",
- "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
+ "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]
@@ -119,7 +119,7 @@
" for w, b in params[:-1]:\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
- " \n",
+ "\n",
" final_w, final_b = params[-1]\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"
@@ -238,7 +238,7 @@
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
- " \n",
+ "\n",
"def accuracy(params, images, targets):\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
@@ -261,7 +261,7 @@
"id": "umJJGZCC2oKl"
},
"source": [
- "## Data Loading with PyTorch\n",
+ "## Data loading with PyTorch\n",
"\n",
"JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays."
]
@@ -494,7 +494,7 @@
"id": "xxPd6Qw3Z98v"
},
"source": [
- "## Training Loop"
+ "## Training loop"
]
},
{
diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md
index 2c53bb1e4ab5..cd98022e7421 100644
--- a/docs/notebooks/Neural_Network_and_Data_Loading.md
+++ b/docs/notebooks/Neural_Network_and_Data_Loading.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -14,11 +14,11 @@ kernelspec:
+++ {"id": "18AF5Ab4p6VL"}
-# Training a Simple Neural Network, with PyTorch Data Loading
+# Training a simple neural network, with PyTorch data loading
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
**Copyright 2018 The JAX Authors.**
@@ -35,9 +35,9 @@ limitations under the License.
+++ {"id": "B_XlLLpcWjkA"}
-![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)
+![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)
-Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
+Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
@@ -96,7 +96,7 @@ def predict(params, image):
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
-
+
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
@@ -156,7 +156,7 @@ At this point, we have all the ingredients we need to define our neural network
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
-
+
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
@@ -175,7 +175,7 @@ def update(params, x, y):
+++ {"id": "umJJGZCC2oKl"}
-## Data Loading with PyTorch
+## Data loading with PyTorch
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays.
@@ -245,7 +245,7 @@ test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets)
+++ {"id": "xxPd6Qw3Z98v"}
-## Training Loop
+## Training loop
```{code-cell} ipython3
:id: X2DnZo3iYj18
diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
index 7e65aefe359c..00ba9186eeec 100644
--- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
+++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb
@@ -10,7 +10,7 @@
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)"
]
},
{
@@ -35,7 +35,6 @@
},
"outputs": [],
"source": [
- "import numpy as np\n",
"import jax\n",
"import jax.numpy as jnp\n",
"from jax import jit, grad, vmap\n",
@@ -80,7 +79,7 @@
"id": "gA8V51wZdsjh"
},
"source": [
- "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README."
+ "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README."
]
},
{
@@ -214,7 +213,6 @@
"outputs": [],
"source": [
"# Importing Jax functions useful for tracing/interpreting.\n",
- "import numpy as np\n",
"from functools import wraps\n",
"\n",
"from jax import core\n",
@@ -273,7 +271,7 @@
"def eval_jaxpr(jaxpr, consts, *args):\n",
" # Mapping from variable -> value\n",
" env = {}\n",
- " \n",
+ "\n",
" def read(var):\n",
" # Literals are values baked into the Jaxpr\n",
" if type(var) is core.Literal:\n",
@@ -290,16 +288,16 @@
" # Loop through equations and evaluate primitives using `bind`\n",
" for eqn in jaxpr.eqns:\n",
" # Read inputs to equation from environment\n",
- " invals = safe_map(read, eqn.invars) \n",
+ " invals = safe_map(read, eqn.invars)\n",
" # `bind` is how a primitive is called\n",
" outvals = eqn.primitive.bind(*invals, **eqn.params)\n",
" # Primitives may return multiple outputs or not\n",
- " if not eqn.primitive.multiple_results: \n",
+ " if not eqn.primitive.multiple_results:\n",
" outvals = [outvals]\n",
" # Write the results of the primitive into the environment\n",
- " safe_map(write, eqn.outvars, outvals) \n",
+ " safe_map(write, eqn.outvars, outvals)\n",
" # Read the final result of the Jaxpr from the environment\n",
- " return safe_map(read, jaxpr.outvars) "
+ " return safe_map(read, jaxpr.outvars)"
]
},
{
@@ -322,7 +320,7 @@
"source": [
"Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n",
"\n",
- "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
+ "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover."
]
},
{
@@ -335,7 +333,7 @@
"\n",
"An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
"\n",
- "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
+ "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
]
},
{
@@ -417,7 +415,7 @@
"source": [
"def inverse_jaxpr(jaxpr, consts, *args):\n",
" env = {}\n",
- " \n",
+ "\n",
" def read(var):\n",
" if type(var) is core.Literal:\n",
" return var.val\n",
@@ -431,12 +429,12 @@
"\n",
" # Looping backward\n",
" for eqn in jaxpr.eqns[::-1]:\n",
- " # outvars are now invars \n",
+ " # outvars are now invars\n",
" invals = safe_map(read, eqn.outvars)\n",
" if eqn.primitive not in inverse_registry:\n",
" raise NotImplementedError(\n",
" f\"{eqn.primitive} does not have registered inverse.\")\n",
- " # Assuming a unary function \n",
+ " # Assuming a unary function\n",
" outval = inverse_registry[eqn.primitive](*invals)\n",
" safe_map(write, eqn.invars, [outval])\n",
" return safe_map(read, jaxpr.invars)"
diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md
index e52c6a5f8742..10c4e7cb6e3b 100644
--- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md
+++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -18,7 +18,7 @@ kernelspec:
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)
+++ {"id": "r-3vMiKRYXPJ"}
@@ -32,7 +32,6 @@ Here we show how to add your own function transformations to the system, by writ
```{code-cell} ipython3
:id: s27RDKvKXFL8
-import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, grad, vmap
@@ -58,7 +57,7 @@ fast_f = jit(f)
+++ {"id": "gA8V51wZdsjh"}
-When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README.
+When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about JAX's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README.
+++ {"id": "2Th1vYLVaFBz"}
@@ -146,7 +145,6 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr.
:id: BHkg_3P1pXJj
# Importing Jax functions useful for tracing/interpreting.
-import numpy as np
from functools import wraps
from jax import core
@@ -185,7 +183,7 @@ To do this, we first create an environment to store the values for each of the v
def eval_jaxpr(jaxpr, consts, *args):
# Mapping from variable -> value
env = {}
-
+
def read(var):
# Literals are values baked into the Jaxpr
if type(var) is core.Literal:
@@ -202,16 +200,16 @@ def eval_jaxpr(jaxpr, consts, *args):
# Loop through equations and evaluate primitives using `bind`
for eqn in jaxpr.eqns:
# Read inputs to equation from environment
- invals = safe_map(read, eqn.invars)
+ invals = safe_map(read, eqn.invars)
# `bind` is how a primitive is called
outvals = eqn.primitive.bind(*invals, **eqn.params)
# Primitives may return multiple outputs or not
- if not eqn.primitive.multiple_results:
+ if not eqn.primitive.multiple_results:
outvals = [outvals]
# Write the results of the primitive into the environment
- safe_map(write, eqn.outvars, outvals)
+ safe_map(write, eqn.outvars, outvals)
# Read the final result of the Jaxpr from the environment
- return safe_map(read, jaxpr.outvars)
+ return safe_map(read, jaxpr.outvars)
```
```{code-cell} ipython3
@@ -225,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5))
Notice that `eval_jaxpr` will always return a flat list even if the original function does not.
-Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
+Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover.
+++ {"id": "0vb2ZoGrCMM4"}
@@ -233,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit
An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
-It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234).
+It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234).
```{code-cell} ipython3
:id: gSMIT2z1vUpO
@@ -279,7 +277,7 @@ Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr ba
def inverse_jaxpr(jaxpr, consts, *args):
env = {}
-
+
def read(var):
if type(var) is core.Literal:
return var.val
@@ -293,12 +291,12 @@ def inverse_jaxpr(jaxpr, consts, *args):
# Looping backward
for eqn in jaxpr.eqns[::-1]:
- # outvars are now invars
+ # outvars are now invars
invals = safe_map(read, eqn.outvars)
if eqn.primitive not in inverse_registry:
raise NotImplementedError(
f"{eqn.primitive} does not have registered inverse.")
- # Assuming a unary function
+ # Assuming a unary function
outval = inverse_registry[eqn.primitive](*invals)
safe_map(write, eqn.invars, [outval])
return safe_map(read, jaxpr.invars)
diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb
index edfd0d4535f8..5538b70dac93 100644
--- a/docs/notebooks/autodiff_cookbook.ipynb
+++ b/docs/notebooks/autodiff_cookbook.ipynb
@@ -10,9 +10,7 @@
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
- "\n",
- "*alexbw@, mattjj@* \n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n",
"\n",
"JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics."
]
@@ -257,7 +255,7 @@
"id": "cJ2NxiN58bfI"
},
"source": [
- "You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
+ "You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)."
]
},
{
@@ -487,7 +485,7 @@
"id": "iZDL-n_AvgBt"
},
"source": [
- "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`."
+ "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices (more inputs than outputs). For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`."
]
},
{
@@ -1017,7 +1015,7 @@
"source": [
"### Jacobian-Matrix and Matrix-Jacobian products\n",
"\n",
- "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
+ "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products."
]
},
{
@@ -1148,7 +1146,7 @@
" y, vjp_fun = vjp(f, x)\n",
" # Use vmap to do a matrix-Jacobian product.\n",
" # Here, the matrix is the Euclidean basis, so we get all\n",
- " # entries in the Jacobian at once. \n",
+ " # entries in the Jacobian at once.\n",
" J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n",
" return J\n",
" return jacfun\n",
@@ -1169,7 +1167,7 @@
"def our_jacfwd(f):\n",
" def jacfun(x):\n",
" _jvp = lambda s: jvp(f, (x,), (s,))[1]\n",
- " Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
+ " Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n",
" return jnp.transpose(Jt)\n",
" return jacfun\n",
"\n",
diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md
index c24d05c0e7c9..db6fde8051d1 100644
--- a/docs/notebooks/autodiff_cookbook.md
+++ b/docs/notebooks/autodiff_cookbook.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -18,9 +18,7 @@ kernelspec:
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
-
-*alexbw@, mattjj@*
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)
JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics.
@@ -153,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b}))
+++ {"id": "cJ2NxiN58bfI"}
-You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
+You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.).
+++ {"id": "PaCHzAtGruBz"}
@@ -276,7 +274,7 @@ print(J)
+++ {"id": "iZDL-n_AvgBt"}
-These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`.
+These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`.
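> Editor's note: a quick way to see the "tall" versus "wide" distinction added above is to compare the two on functions with very asymmetric input/output sizes. A small sketch (the matrix `W` and functions here are made up for illustration):

```python
import jax
import jax.numpy as jnp

W = jax.random.normal(jax.random.key(0), (1000, 5))

tall_f = lambda x: W @ x    # R^5    -> R^1000: "tall" Jacobian (many outputs)
wide_f = lambda y: W.T @ y  # R^1000 -> R^5:    "wide" Jacobian (many inputs)

x = jnp.ones(5)
y = jnp.ones(1000)

# Both give correct Jacobians; the efficient choice differs:
# jacfwd's cost scales with the number of inputs, jacrev's with the number of outputs.
print(jax.jacfwd(tall_f)(x).shape)  # (1000, 5) -- cheap with forward mode
print(jax.jacrev(wide_f)(y).shape)  # (5, 1000) -- cheap with reverse mode
```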
+++ {"id": "zeKlr7Xz8bfm"}
@@ -594,7 +592,7 @@ print("Naive full Hessian materialization")
### Jacobian-Matrix and Matrix-Jacobian products
-Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
+Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products.
```{code-cell} ipython3
:id: asAWvxVaCmsx
@@ -675,7 +673,7 @@ def our_jacrev(f):
y, vjp_fun = vjp(f, x)
# Use vmap to do a matrix-Jacobian product.
# Here, the matrix is the Euclidean basis, so we get all
- # entries in the Jacobian at once.
+ # entries in the Jacobian at once.
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))
return J
return jacfun
@@ -691,7 +689,7 @@ from jax import jacfwd as builtin_jacfwd
def our_jacfwd(f):
def jacfun(x):
_jvp = lambda s: jvp(f, (x,), (s,))[1]
- Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
+ Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))
return jnp.transpose(Jt)
return jacfun
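
> Editor's note: as a sanity check on the hand-rolled Jacobians in this hunk, assuming the `our_jacrev` and `our_jacfwd` definitions above are in scope, they can be compared against the built-ins on a small made-up test function:

```python
import jax
import jax.numpy as jnp

f = lambda x: jnp.sin(x) * jnp.sum(x)  # illustrative R^3 -> R^3 function
x = jnp.arange(3.0)

assert jnp.allclose(our_jacrev(f)(x), jax.jacrev(f)(x)), 'Outputs should match'
assert jnp.allclose(our_jacfwd(f)(x), jax.jacfwd(f)(x)), 'Outputs should match'
```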
diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb
index f0552e52688f..82381838a5aa 100644
--- a/docs/notebooks/autodiff_remat.ipynb
+++ b/docs/notebooks/autodiff_remat.ipynb
@@ -27,7 +27,7 @@
"id": "qaIsQSh1XoKF"
},
"source": [
- "### TL;DR\n",
+ "### Summary\n",
"\n",
"Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n",
"\n",
@@ -739,8 +739,6 @@
"metadata": {},
"outputs": [],
"source": [
- "from jax.ad_checkpoint import checkpoint_name\n",
- "\n",
"def predict(params, x):\n",
" *Ws, Wlast = params\n",
" for i, W in enumerate(Ws):\n",
diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md
index b31e093b6f91..0a6c84b2d88f 100644
--- a/docs/notebooks/autodiff_remat.md
+++ b/docs/notebooks/autodiff_remat.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
@@ -24,7 +24,7 @@ import jax.numpy as jnp
+++ {"id": "qaIsQSh1XoKF"}
-### TL;DR
+### Summary
Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.
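
> Editor's note: a minimal sketch of what that trade-off looks like in practice (the function `g` and the array shapes here are illustrative):

```python
import jax
import jax.numpy as jnp

def g(W1, W2, x):
  # Without checkpointing, the intermediate `h` is stored for the backward
  # pass; with jax.checkpoint it is recomputed there instead.
  h = jnp.sin(W1 @ x)
  return jnp.sum(jnp.sin(W2 @ h))

g_remat = jax.checkpoint(g)  # a.k.a. jax.remat

W1 = jnp.ones((4, 3))
W2 = jnp.ones((2, 4))
x = jnp.ones(3)

# Gradients agree; only the memory/FLOPs balance of the backward pass differs.
print(jax.grad(g)(W1, W2, x))
print(jax.grad(g_remat)(W1, W2, x))
```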
@@ -370,8 +370,6 @@ Notice also that by providing a policy, we didn't need to edit the code defining
Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`:
```{code-cell}
-from jax.ad_checkpoint import checkpoint_name
-
def predict(params, x):
*Ws, Wlast = params
for i, W in enumerate(Ws):
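
> Editor's note: the cell above is shown only partially in this hunk. A self-contained sketch of the pattern it relies on — tagging a value with `checkpoint_name` and then saving only that name via a policy — might look like the following (the layer shapes and the name `'hidden'` are made up for illustration):

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(W, x):
  # Tag the pre-activation so a rematerialization policy can refer to it by name.
  h = checkpoint_name(W @ x, 'hidden')
  return jnp.sin(h)

policy = jax.checkpoint_policies.save_only_these_names('hidden')
layer_remat = jax.checkpoint(layer, policy=policy)

W = jnp.ones((4, 3))
x = jnp.ones(3)
print(jax.grad(lambda W: jnp.sum(layer_remat(W, x)))(W))
```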
diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb
index 0a823353068b..9d91804b6021 100644
--- a/docs/notebooks/convolutions.ipynb
+++ b/docs/notebooks/convolutions.ipynb
@@ -6,11 +6,11 @@
"id": "TVT_MVvc02AA"
},
"source": [
- "# Generalized Convolutions in JAX\n",
+ "# Generalized convolutions in JAX\n",
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n",
"\n",
"JAX provides a number of interfaces to compute convolutions across data, including:\n",
"\n",
@@ -28,7 +28,7 @@
"id": "ewZEn2X12-Ng"
},
"source": [
- "## Basic One-dimensional Convolution\n",
+ "## Basic one-dimensional convolution\n",
"\n",
"Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:"
]
@@ -91,7 +91,7 @@
"id": "5ndvLDIH4rv6"
},
"source": [
- "## Basic N-dimensional Convolution\n",
+ "## Basic N-dimensional convolution\n",
"\n",
"For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.\n",
"\n",
@@ -160,7 +160,7 @@
"id": "bxuUjFVG-v1h"
},
"source": [
- "## General Convolutions"
+ "## General convolutions"
]
},
{
@@ -410,7 +410,7 @@
],
"source": [
"dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n",
- " kernel.shape, # only ndim matters, not shape \n",
+ " kernel.shape, # only ndim matters, not shape\n",
" ('NHWC', 'HWIO', 'NHWC')) # the important bit\n",
"print(dn)"
]
@@ -806,8 +806,8 @@
],
"source": [
"# 1D kernel - WIO layout\n",
- "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n",
- " [[1, 1, 1], [-1, -1, -1]]], \n",
+ "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],\n",
+ " [[1, 1, 1], [-1, -1, -1]]],\n",
" dtype=jnp.float32).transpose([2,1,0])\n",
"# 1D data - NWC layout\n",
"data = np.zeros((1, 200, 2), dtype=jnp.float32)\n",
@@ -895,8 +895,8 @@
"# Random 3D kernel - HWDIO layout\n",
"kernel = jnp.array([\n",
" [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n",
- " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n",
- " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n",
+ " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],\n",
+ " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],\n",
" dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n",
"\n",
"# 3D data - NHWDC layout\n",
@@ -919,7 +919,6 @@
"print(\"out shape: \", out.shape)\n",
"\n",
"# Make some simple 3d density plots:\n",
- "from mpl_toolkits.mplot3d import Axes3D\n",
"def make_alpha(cmap):\n",
" my_cmap = cmap(jnp.arange(cmap.N))\n",
" my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n",
diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md
index 3de8f261aa5b..b98099aa9571 100644
--- a/docs/notebooks/convolutions.md
+++ b/docs/notebooks/convolutions.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -14,11 +14,11 @@ kernelspec:
+++ {"id": "TVT_MVvc02AA"}
-# Generalized Convolutions in JAX
+# Generalized convolutions in JAX
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)
JAX provides a number of interfaces to compute convolutions across data, including:
@@ -31,7 +31,7 @@ For basic convolution operations, the `jax.numpy` and `jax.scipy` operations are
+++ {"id": "ewZEn2X12-Ng"}
-## Basic One-dimensional Convolution
+## Basic one-dimensional convolution
Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. Here is a simple example of 1D smoothing implemented via a convolution:
@@ -65,7 +65,7 @@ For more information, see the {func}`jax.numpy.convolve` documentation, or the d
+++ {"id": "5ndvLDIH4rv6"}
-## Basic N-dimensional Convolution
+## Basic N-dimensional convolution
For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.
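> Editor's note: the notebook's own example follows in the next cell; for orientation, a minimal sketch of the kind of call involved (a small random image and box kernel, made up here) is:

```python
import jax
import jax.numpy as jnp
import jax.scipy.signal

image = jax.random.normal(jax.random.key(0), (128, 128))

# 2D box filter; mode='same' keeps the output the same shape as the input.
kernel = jnp.ones((3, 3)) / 9.0
smoothed = jax.scipy.signal.convolve(image, kernel, mode='same')
print(smoothed.shape)  # (128, 128)
```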
@@ -105,7 +105,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l
+++ {"id": "bxuUjFVG-v1h"}
-## General Convolutions
+## General convolutions
+++ {"id": "0pcn2LeS-03b"}
@@ -210,7 +210,7 @@ The important argument is the 3-tuple of axis layout arguments:
:outputId: d5a569b3-febc-4832-f725-1d5e8fd31b9b
dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape
- kernel.shape, # only ndim matters, not shape
+ kernel.shape, # only ndim matters, not shape
('NHWC', 'HWIO', 'NHWC')) # the important bit
print(dn)
```
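> Editor's note: a rough sketch of how such a `dn` object is then consumed, assuming `img` (NHWC) and `kernel` (HWIO) from the cells above; the strides and padding below are placeholders, and the notebook's actual call may differ:

```python
from jax import lax

out = lax.conv_general_dilated(
    img,                     # lhs = image tensor
    kernel,                  # rhs = convolution kernel
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=dn)    # the layout spec built above
print("out shape:", out.shape)
```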
@@ -363,8 +363,8 @@ You aren't limited to 2D convolutions, a simple 1D demo is below:
:outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b
# 1D kernel - WIO layout
-kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
- [[1, 1, 1], [-1, -1, -1]]],
+kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],
+ [[1, 1, 1], [-1, -1, -1]]],
dtype=jnp.float32).transpose([2,1,0])
# 1D data - NWC layout
data = np.zeros((1, 200, 2), dtype=jnp.float32)
@@ -406,8 +406,8 @@ import matplotlib as mpl
# Random 3D kernel - HWDIO layout
kernel = jnp.array([
[[0, 0, 0], [0, 1, 0], [0, 0, 0]],
- [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
- [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
+ [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],
+ [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],
dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]
# 3D data - NHWDC layout
@@ -430,7 +430,6 @@ out = lax.conv_general_dilated(data, # lhs = image tensor
print("out shape: ", out.shape)
# Make some simple 3d density plots:
-from mpl_toolkits.mplot3d import Axes3D
def make_alpha(cmap):
my_cmap = cmap(jnp.arange(cmap.N))
my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3
diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb
deleted file mode 100644
index bdf71004c01b..000000000000
--- a/docs/notebooks/external_callbacks.ipynb
+++ /dev/null
@@ -1,1122 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "7XNMxdTwURqI"
- },
- "source": [
- "# External Callbacks in JAX\n",
- "\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "h6lXo6bSUYGq"
- },
- "source": [
- "This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "Xi_nhfpnlmbm"
- },
- "source": [
- "## Why callbacks?\n",
- "\n",
- "A callback routine is a way to perform **host-side** execution of code at runtime.\n",
- "As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.\n",
- "Using a simple Python `print` statement, it looks like this:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {
- "id": "lz8rEL1Amb4r",
- "outputId": "bbd37102-19f2-46d2-b794-3d4952c6fe97"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "intermediate value: Tracedwith\n"
- ]
- }
- ],
- "source": [
- "import jax\n",
- "\n",
- "@jax.jit\n",
- "def f(x):\n",
- " y = x + 1\n",
- " print(\"intermediate value: {}\".format(y))\n",
- " return y * 2\n",
- "\n",
- "result = f(2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "yEy41sFAmxOp"
- },
- "source": [
- "What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).\n",
- "\n",
- "To print the value at runtime we need a callback, for example `jax.debug.print`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {
- "id": "wFfHmoQxnKDF",
- "outputId": "6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "intermediate value: 3\n"
- ]
- }
- ],
- "source": [
- "@jax.jit\n",
- "def f(x):\n",
- " y = x + 1\n",
- " jax.debug.print(\"intermediate value: {}\", y)\n",
- " return y * 2\n",
- "\n",
- "result = f(2)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "CvWv3pudn9X5"
- },
- "source": [
- "This works by passing the runtime value represented by `y` back to the host process, where the host can print the value."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "X0vR078znuT-"
- },
- "source": [
- "## Flavors of Callback\n",
- "\n",
- "In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:\n",
- "\n",
- "- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.\n",
- "- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk.\n",
- "- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.\n",
- "\n",
- "(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).\n",
- "\n",
- "From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.\n",
- "\n",
- "|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |\n",
- "|-------------------------------------|----|----|----|----|----|----|\n",
- "|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |\n",
- "|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |\n",
- "|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |\n",
- "\n",
- "¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff\n",
- "\n",
- "² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.\n",
- "\n",
- "³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "hE_M8DaPvoym"
- },
- "source": [
- "### Exploring `jax.pure_callback`\n",
- "\n",
- "`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).\n",
- "\n",
- "The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {
- "id": "4lQDzXy6t_-k",
- "outputId": "279e4daf-0540-4eab-f535-d3bcbac74c44"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
- ]
- },
- "execution_count": 3,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "import numpy as np\n",
- "\n",
- "def f_host(x):\n",
- " # call a numpy (not jax.numpy) operation:\n",
- " return np.sin(x).astype(x.dtype)\n",
- "\n",
- "def f(x):\n",
- " result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
- " return jax.pure_callback(f_host, result_shape, x)\n",
- "\n",
- "x = jnp.arange(5.0)\n",
- "f(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "q7YCIr8qMrDs"
- },
- "source": [
- "Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {
- "id": "bgoZ0fxsuoWV",
- "outputId": "901443bd-5cb4-4923-ce53-6f832ac22ca9"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jax.jit(f)(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {
- "id": "ajBRGWGfupu2",
- "outputId": "b28e31ee-7457-4b92-872b-52d819f53ddf"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jax.vmap(f)(x)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {
- "id": "xe7AOGexvC13",
- "outputId": "8fa77977-1f2b-41c5-cc5e-11993ee5aa3e"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "def body_fun(_, x):\n",
- " return _, f(x)\n",
- "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "tMzAVs2VNj5G"
- },
- "source": [
- "However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {
- "id": "4QAF4VhUu5bb",
- "outputId": "f8a06d02-47e9-4240-8077-d7be81e5a480"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Exception reporting mode: Minimal\n"
- ]
- }
- ],
- "source": [
- "%xmode minimal"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {
- "id": "qUpKPxlOurfY",
- "outputId": "11a665e8-40eb-4b0e-dc2e-a544a25fc57e",
- "tags": [
- "raises-exception"
- ]
- },
- "outputs": [
- {
- "ename": "ValueError",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n"
- ]
- }
- ],
- "source": [
- "jax.grad(f)(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "y9DAibV4Nwpo"
- },
- "source": [
- "For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "LrvdAloMZbIe"
- },
- "source": [
- "By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {
- "id": "mmFc_zawZrBq",
- "outputId": "a4df7568-3f64-4b2f-9a2c-7adb2e0815e0"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "printing something\n"
- ]
- }
- ],
- "source": [
- "def print_something():\n",
- " print('printing something')\n",
- " return np.int32(0)\n",
- "\n",
- "@jax.jit\n",
- "def f1():\n",
- " return jax.pure_callback(print_something, np.int32(0))\n",
- "f1();"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "id": "tTwE4kpmaNei"
- },
- "outputs": [],
- "source": [
- "@jax.jit\n",
- "def f2():\n",
- " jax.pure_callback(print_something, np.int32(0))\n",
- " return 1.0\n",
- "f2();"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "qfyGYbw4Z5U3"
- },
- "source": [
- "In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.\n",
- "In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "JHcJybr7OEBM"
- },
- "source": [
- "### Exploring `jax.experimental.io_callback`\n",
- "\n",
- "In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n",
- "\n",
- "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {
- "id": "eAg5xIhrOiWV",
- "outputId": "e3cfec21-d843-4852-a49d-69a69fba9fc1"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "generating float32[5]\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "from jax.experimental import io_callback\n",
- "from functools import partial\n",
- "\n",
- "global_rng = np.random.default_rng(0)\n",
- "\n",
- "def host_side_random_like(x):\n",
- " \"\"\"Generate a random array like x using the global_rng state\"\"\"\n",
- " # We have two side-effects here:\n",
- " # - printing the shape and dtype\n",
- " # - calling global_rng, thus updating its state\n",
- " print(f'generating {x.dtype}{list(x.shape)}')\n",
- " return global_rng.uniform(size=x.shape).astype(x.dtype)\n",
- "\n",
- "@jax.jit\n",
- "def numpy_random_like(x):\n",
- " return io_callback(host_side_random_like, x, x)\n",
- "\n",
- "x = jnp.zeros(5)\n",
- "numpy_random_like(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "mAIF31MlXj33"
- },
- "source": [
- "The `io_callback` is compatible with `vmap` by default:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "id": "NY3o5dG6Vg6u",
- "outputId": "a67a8a98-214e-40ca-ad98-a930cd3db85e"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "generating float32[]\n",
- "generating float32[]\n",
- "generating float32[]\n",
- "generating float32[]\n",
- "generating float32[]\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jax.vmap(numpy_random_like)(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "XXvSeeOXXquZ"
- },
- "source": [
- "Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.\n",
- "\n",
- "If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {
- "id": "3aNmRsDrX3-2",
- "outputId": "a8ff4b77-f4cb-442f-8cfb-ea7251c66274",
- "tags": [
- "raises-exception"
- ]
- },
- "outputs": [
- {
- "ename": "ValueError",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: Cannot `vmap` ordered IO callback.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n",
- "\nThe above exception was the direct cause of the following exception:\n",
- "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot `vmap` ordered IO callback.\n"
- ]
- }
- ],
- "source": [
- "@jax.jit\n",
- "def numpy_random_like_ordered(x):\n",
- " return io_callback(host_side_random_like, x, x, ordered=True)\n",
- "\n",
- "jax.vmap(numpy_random_like_ordered)(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "fD2FTHlUYAZH"
- },
- "source": [
- "On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {
- "id": "lMVzZlIEWL7F",
- "outputId": "f9741c18-a30d-4d46-b706-8102849286b5"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "generating float32[]\n",
- "generating float32[]\n",
- "generating float32[]\n",
- "generating float32[]\n",
- "generating float32[]\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "def body_fun(_, x):\n",
- " return _, numpy_random_like_ordered(x)\n",
- "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "w_sf8mCbbo8K"
- },
- "source": [
- "Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 15,
- "metadata": {
- "id": "Cn6_RG4JcKZm",
- "outputId": "336ae5d2-e35b-4fe5-cbfb-14a7aef28c07",
- "tags": [
- "raises-exception"
- ]
- },
- "outputs": [
- {
- "ename": "ValueError",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: IO callbacks do not support JVP.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n",
- "\nThe above exception was the direct cause of the following exception:\n",
- "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m IO callbacks do not support JVP.\n"
- ]
- }
- ],
- "source": [
- "jax.grad(numpy_random_like)(x)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "plvfn9lWcKu4"
- },
- "source": [
- "However, if the callback is not dependent on a differentiated variable, it will execute:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {
- "id": "wxgfDmDfb5bx",
- "outputId": "d8c0285c-cd04-4b4d-d15a-1b07f778882d"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "hello\n"
- ]
- }
- ],
- "source": [
- "@jax.jit\n",
- "def f(x):\n",
- " io_callback(lambda: print('hello'), None)\n",
- " return x\n",
- "\n",
- "jax.grad(f)(1.0);"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "STLI40EZcVIY"
- },
- "source": [
- "Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "pkkM1ZmqclV-"
- },
- "source": [
- "### Exploring `debug.callback`\n",
- "\n",
- "Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {
- "id": "74TdWyu9eqBa",
- "outputId": "d8551dab-2e61-492e-9ac3-dc3db51b2c18"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "log: 1.0\n"
- ]
- }
- ],
- "source": [
- "from jax import debug\n",
- "\n",
- "def log_value(x):\n",
- " # This could be an actual logging call; we'll use\n",
- " # print() for demonstration\n",
- " print(\"log:\", x)\n",
- "\n",
- "@jax.jit\n",
- "def f(x):\n",
- " debug.callback(log_value, x)\n",
- " return x\n",
- "\n",
- "f(1.0);"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "P848STlsfzmW"
- },
- "source": [
- "The debug callback is compatible with `vmap`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "id": "2sSNsPB-fGVI",
- "outputId": "fff58575-d94c-48fb-b88a-c1c395595fd0"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "log: 0.0\n",
- "log: 1.0\n",
- "log: 2.0\n",
- "log: 3.0\n",
- "log: 4.0\n"
- ]
- }
- ],
- "source": [
- "x = jnp.arange(5.0)\n",
- "jax.vmap(f)(x);"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "VDMacqpXf3La"
- },
- "source": [
- "And is also compatible with `grad` and other autodiff transformations"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 19,
- "metadata": {
- "id": "wkFRle-tfTDe",
- "outputId": "4e8a81d0-5012-4c51-d843-3fbdc498df31"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "log: 1.0\n"
- ]
- }
- ],
- "source": [
- "jax.grad(f)(1.0);"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "w8t-SDZ3gRzE"
- },
- "source": [
- "This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "dF7hoWGQUneJ"
- },
- "source": [
- "## Example: `pure_callback` with `custom_jvp`\n",
- "\n",
- "One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).\n",
- "Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.\n",
- "\n",
- "Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.\n",
- "We can start by defining a straightforward `pure_callback`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "Ge4fNPZdVSJY"
- },
- "outputs": [],
- "source": [
- "import jax\n",
- "import jax.numpy as jnp\n",
- "import scipy.special\n",
- "\n",
- "def jv(v, z):\n",
- " v, z = jnp.asarray(v), jnp.asarray(z)\n",
- "\n",
- " # Require the order v to be integer type: this simplifies\n",
- " # the JVP rule below.\n",
- " assert jnp.issubdtype(v.dtype, jnp.integer)\n",
- "\n",
- " # Promote the input to inexact (float/complex).\n",
- " # Note that jnp.result_type() accounts for the enable_x64 flag.\n",
- " z = z.astype(jnp.result_type(float, z.dtype))\n",
- "\n",
- " # Wrap scipy function to return the expected dtype.\n",
- " _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)\n",
- "\n",
- " # Define the expected shape & dtype of output.\n",
- " result_shape_dtype = jax.ShapeDtypeStruct(\n",
- " shape=jnp.broadcast_shapes(v.shape, z.shape),\n",
- " dtype=z.dtype)\n",
- "\n",
- " # We use vectorize=True because scipy.special.jv handles broadcasted inputs.\n",
- " return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "vyjQj-0QVuoN"
- },
- "source": [
- "This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "f4e46670f4e4"
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "j1 = partial(jv, 1)\n",
- "z = jnp.arange(5.0)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "6svImqFHWBwj",
- "outputId": "bc8c778a-6c10-443b-9be2-c0f28e2ac1a9"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n"
- ]
- }
- ],
- "source": [
- "print(j1(z))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "d48eb4f2d48e"
- },
- "source": [
- "Here is the same result with `jit`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "txvRqR9DWGdC",
- "outputId": "d25f3476-23b1-48e4-dda1-3c06d32c3b87"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n"
- ]
- }
- ],
- "source": [
- "print(jax.jit(j1)(z))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "d861a472d861"
- },
- "source": [
- "And here is the same result again with `vmap`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "BS-Ve5u_WU0C",
- "outputId": "08cecd1f-6953-4853-e9db-25a03eb5b000"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n"
- ]
- }
- ],
- "source": [
- "print(jax.vmap(j1)(z))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "SCH2ii_dWXP6"
- },
- "source": [
- "However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "q3qh_4DrWxdQ",
- "outputId": "c46b0bfa-96f3-4629-b9af-a4d4f3ccb870",
- "tags": [
- "raises-exception"
- ]
- },
- "outputs": [
- {
- "ename": "ValueError",
- "evalue": "ignored",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1090\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1091\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1092\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_axes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m 2655\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2656\u001b[0;31m out_primal, out_vjp = ad.vjp(\n\u001b[0m\u001b[1;32m 2657\u001b[0m flat_fun, primals_flat, reduce_axes=reduce_axes)\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 767\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3425\u001b[0m \"\"\"\n\u001b[0;32m-> 3426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtypes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 328\u001b[0m all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfind_top_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m \u001b[0mprimal_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangent_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangents_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 311\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
- "\nThe above exception was the direct cause of the following exception:\n",
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 129\u001b[0m lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m result_avals=tuple(flat_result_avals), vectorized=vectorized)\n",
- "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpure_callback_jvp_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \"Please use `jax.custom_jvp` to use callbacks while taking gradients.\")\n",
- "\u001b[0;31mValueError\u001b[0m: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients."
- ]
- }
- ],
- "source": [
- "jax.grad(j1)(z)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "PtYeJ_xUW09v"
- },
- "source": [
- "Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:\n",
- "\n",
- "$$\n",
- "d J_\\nu(z) = \\left\\{\n",
- "\\begin{eqnarray}\n",
- "-J_1(z),\\ &\\nu=0\\\\\n",
- "[J_{\\nu - 1}(z) - J_{\\nu + 1}(z)]/2,\\ &\\nu\\ne 0\n",
- "\\end{eqnarray}\\right.\n",
- "$$\n",
- "\n",
- "The gradient with respect to $\\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.\n",
- "\n",
- "We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "BOVQnt05XvLs"
- },
- "outputs": [],
- "source": [
- "jv = jax.custom_jvp(jv)\n",
- "\n",
- "@jv.defjvp\n",
- "def _jv_jvp(primals, tangents):\n",
- " v, z = primals\n",
- " _, z_dot = tangents # Note: v_dot is always 0 because v is integer.\n",
- " jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)\n",
- " djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n",
- " return jv(v, z), z_dot * djv_dz"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "W1SxcvQSX44c"
- },
- "source": [
- "Now computing the gradient of our function will work correctly:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "sCGceBs-X8nL",
- "outputId": "71c5589f-f996-44a0-f09a-ca8bb40c167a"
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "-0.06447162\n"
- ]
- }
- ],
- "source": [
- "j1 = partial(jv, 1)\n",
- "print(jax.grad(j1)(2.0))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "gWQ4phN5YB26"
- },
- "source": [
- "Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "QTe5mRAvYQBh",
- "outputId": "d58ecff3-9419-422a-fd0e-14a7d9cf2cc3"
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Array(-0.4003078, dtype=float32, weak_type=True)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "jax.hessian(j1)(2.0)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "QEXGxU4uYZii"
- },
- "source": [
- "Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.\n",
- "When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.\n",
- "However, if you are running JAX on a single CPU (where the \"host\" and \"device\" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities."
- ]
- }
- ],
- "metadata": {
- "colab": {
- "provenance": []
- },
- "jupytext": {
- "formats": "ipynb,md:myst"
- },
- "kernelspec": {
- "display_name": "Python 3",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "name": "python",
- "version": "3.10.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md
deleted file mode 100644
index 857eef42e2b3..000000000000
--- a/docs/notebooks/external_callbacks.md
+++ /dev/null
@@ -1,516 +0,0 @@
----
-jupytext:
- formats: ipynb,md:myst
- text_representation:
- extension: .md
- format_name: myst
- format_version: 0.13
- jupytext_version: 1.16.1
-kernelspec:
- display_name: Python 3
- name: python3
----
-
-+++ {"id": "7XNMxdTwURqI"}
-
-# External Callbacks in JAX
-
-
-
-+++ {"id": "h6lXo6bSUYGq"}
-
-This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation.
-
-+++ {"id": "Xi_nhfpnlmbm"}
-
-## Why callbacks?
-
-A callback routine is a way to perform **host-side** execution of code at runtime.
-As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.
-Using a simple Python `print` statement, it looks like this:
-
-```{code-cell}
-:id: lz8rEL1Amb4r
-:outputId: bbd37102-19f2-46d2-b794-3d4952c6fe97
-
-import jax
-
-@jax.jit
-def f(x):
- y = x + 1
- print("intermediate value: {}".format(y))
- return y * 2
-
-result = f(2)
-```
-
-+++ {"id": "yEy41sFAmxOp"}
-
-What is printed is not the runtime value, but the trace-time abstract value (if you're not familiar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).
-
-To print the value at runtime we need a callback, for example `jax.debug.print`:
-
-```{code-cell}
-:id: wFfHmoQxnKDF
-:outputId: 6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a
-
-@jax.jit
-def f(x):
- y = x + 1
- jax.debug.print("intermediate value: {}", y)
- return y * 2
-
-result = f(2)
-```
-
-+++ {"id": "CvWv3pudn9X5"}
-
-This works by passing the runtime value represented by `y` back to the host process, where the host can print the value.
-
-+++ {"id": "X0vR078znuT-"}
-
-## Flavors of Callback
-
-In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:
-
-- {func}`jax.pure_callback`: appropriate for pure functions, i.e. functions with no side effects.
-- {func}`jax.experimental.io_callback`: appropriate for impure functions, e.g. functions that read or write data to disk.
-- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.
-
-(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`.)
-
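-For intuition, here is a minimal sketch of how a `debug.print`-style helper could be built on top of {func}`jax.debug.callback`; the `debug_print_sketch` helper is illustrative only, not JAX's actual implementation:
-
-```{code-cell}
-import jax
-
-def debug_print_sketch(fmt, *args):
-  # Hypothetical helper: format and print the runtime values on the host.
-  jax.debug.callback(lambda *vals: print(fmt.format(*vals)), *args)
-
-@jax.jit
-def g(x):
-  debug_print_sketch("intermediate value: {}", x + 1)
-  return (x + 1) * 2
-
-g(2);
-```
-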
-From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.
-
-|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |
-|-------------------------------------|----|----|----|----|----|----|
-|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |
-|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |
-|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |
-
-¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff (see the example below).
-
-² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.
-
-³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases.
-
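-To make these distinctions concrete, here is a minimal sketch that calls all three flavors from a single jitted function; the `host_fn` and `demo` functions are illustrative helpers, not part of JAX:
-
-```{code-cell}
-import jax
-import jax.numpy as jnp
-import numpy as np
-from jax.experimental import io_callback
-
-def host_fn(x):
-  # Runs on the host with a concrete NumPy value.
-  return np.cos(x).astype(np.float32)
-
-result_shape = jax.ShapeDtypeStruct((), jnp.float32)
-
-@jax.jit
-def demo(x):
-  a = jax.pure_callback(host_fn, result_shape, x)   # returns a value; may be elided
-  b = io_callback(host_fn, result_shape, x)         # returns a value; guaranteed to run
-  jax.debug.callback(lambda v: print("saw", v), x)  # returns nothing
-  return a + b
-
-demo(jnp.float32(1.0))
-```
-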
-+++ {"id": "hE_M8DaPvoym"}
-
-### Exploring `jax.pure_callback`
-
-`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function, i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).
-
-The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times.
-
-```{code-cell}
-:id: 4lQDzXy6t_-k
-:outputId: 279e4daf-0540-4eab-f535-d3bcbac74c44
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-def f_host(x):
- # call a numpy (not jax.numpy) operation:
- return np.sin(x).astype(x.dtype)
-
-def f(x):
- result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
- return jax.pure_callback(f_host, result_shape, x)
-
-x = jnp.arange(5.0)
-f(x)
-```
-
-+++ {"id": "q7YCIr8qMrDs"}
-
-Because `pure_callback` can be elided or duplicated, it is compatible out of the box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:
-
-```{code-cell}
-:id: bgoZ0fxsuoWV
-:outputId: 901443bd-5cb4-4923-ce53-6f832ac22ca9
-
-jax.jit(f)(x)
-```
-
-```{code-cell}
-:id: ajBRGWGfupu2
-:outputId: b28e31ee-7457-4b92-872b-52d819f53ddf
-
-jax.vmap(f)(x)
-```
-
-```{code-cell}
-:id: xe7AOGexvC13
-:outputId: 8fa77977-1f2b-41c5-cc5e-11993ee5aa3e
-
-def body_fun(_, x):
- return _, f(x)
-jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
-```
-
-+++ {"id": "tMzAVs2VNj5G"}
-
-However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:
-
-```{code-cell}
-:id: 4QAF4VhUu5bb
-:outputId: f8a06d02-47e9-4240-8077-d7be81e5a480
-
-%xmode minimal
-```
-
-```{code-cell}
-:id: qUpKPxlOurfY
-:outputId: 11a665e8-40eb-4b0e-dc2e-a544a25fc57e
-:tags: [raises-exception]
-
-jax.grad(f)(x)
-```
-
-+++ {"id": "y9DAibV4Nwpo"}
-
-For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below.
-
-+++ {"id": "LrvdAloMZbIe"}
-
-By design, functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:
-
-```{code-cell}
-:id: mmFc_zawZrBq
-:outputId: a4df7568-3f64-4b2f-9a2c-7adb2e0815e0
-
-def print_something():
- print('printing something')
- return np.int32(0)
-
-@jax.jit
-def f1():
- return jax.pure_callback(print_something, np.int32(0))
-f1();
-```
-
-```{code-cell}
-:id: tTwE4kpmaNei
-
-@jax.jit
-def f2():
- jax.pure_callback(print_something, np.int32(0))
- return 1.0
-f2();
-```
-
-+++ {"id": "qfyGYbw4Z5U3"}
-
-In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.
-In `f2`, on the other hand, the output of the callback is unused, so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects.
-
-+++ {"id": "JHcJybr7OEBM"}
-
-### Exploring `jax.experimental.io_callback`
-
-In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.
-
-As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated. (Note that this is meant as a toy example of `io_callback`, and not necessarily a recommended way of generating random numbers in JAX!)
-
-```{code-cell}
-:id: eAg5xIhrOiWV
-:outputId: e3cfec21-d843-4852-a49d-69a69fba9fc1
-
-from jax.experimental import io_callback
-from functools import partial
-
-global_rng = np.random.default_rng(0)
-
-def host_side_random_like(x):
- """Generate a random array like x using the global_rng state"""
- # We have two side-effects here:
- # - printing the shape and dtype
- # - calling global_rng, thus updating its state
- print(f'generating {x.dtype}{list(x.shape)}')
- return global_rng.uniform(size=x.shape).astype(x.dtype)
-
-@jax.jit
-def numpy_random_like(x):
- return io_callback(host_side_random_like, x, x)
-
-x = jnp.zeros(5)
-numpy_random_like(x)
-```
-
-+++ {"id": "mAIF31MlXj33"}
-
-The `io_callback` is compatible with `vmap` by default:
-
-```{code-cell}
-:id: NY3o5dG6Vg6u
-:outputId: a67a8a98-214e-40ca-ad98-a930cd3db85e
-
-jax.vmap(numpy_random_like)(x)
-```
-
-+++ {"id": "XXvSeeOXXquZ"}
-
-Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.
-
-If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:
-
-```{code-cell}
-:id: 3aNmRsDrX3-2
-:outputId: a8ff4b77-f4cb-442f-8cfb-ea7251c66274
-:tags: [raises-exception]
-
-@jax.jit
-def numpy_random_like_ordered(x):
- return io_callback(host_side_random_like, x, x, ordered=True)
-
-jax.vmap(numpy_random_like_ordered)(x)
-```
-
-+++ {"id": "fD2FTHlUYAZH"}
-
-On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:
-
-```{code-cell}
-:id: lMVzZlIEWL7F
-:outputId: f9741c18-a30d-4d46-b706-8102849286b5
-
-def body_fun(_, x):
- return _, numpy_random_like_ordered(x)
-jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
-```
-
-+++ {"id": "w_sf8mCbbo8K"}
-
-Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:
-
-```{code-cell}
-:id: Cn6_RG4JcKZm
-:outputId: 336ae5d2-e35b-4fe5-cbfb-14a7aef28c07
-:tags: [raises-exception]
-
-jax.grad(numpy_random_like)(x)
-```
-
-+++ {"id": "plvfn9lWcKu4"}
-
-However, if the callback is not dependent on a differentiated variable, it will execute:
-
-```{code-cell}
-:id: wxgfDmDfb5bx
-:outputId: d8c0285c-cd04-4b4d-d15a-1b07f778882d
-
-@jax.jit
-def f(x):
- io_callback(lambda: print('hello'), None)
- return x
-
-jax.grad(f)(1.0);
-```
-
-+++ {"id": "STLI40EZcVIY"}
-
-Unlike with `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation.
-
-+++ {"id": "pkkM1ZmqclV-"}
-
-### Exploring `debug.callback`
-
-Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program.
-
-```{code-cell}
-:id: 74TdWyu9eqBa
-:outputId: d8551dab-2e61-492e-9ac3-dc3db51b2c18
-
-from jax import debug
-
-def log_value(x):
- # This could be an actual logging call; we'll use
- # print() for demonstration
- print("log:", x)
-
-@jax.jit
-def f(x):
- debug.callback(log_value, x)
- return x
-
-f(1.0);
-```
-
-+++ {"id": "P848STlsfzmW"}
-
-The debug callback is compatible with `vmap`:
-
-```{code-cell}
-:id: 2sSNsPB-fGVI
-:outputId: fff58575-d94c-48fb-b88a-c1c395595fd0
-
-x = jnp.arange(5.0)
-jax.vmap(f)(x);
-```
-
-+++ {"id": "VDMacqpXf3La"}
-
-It is also compatible with `grad` and other autodiff transformations:
-
-```{code-cell}
-:id: wkFRle-tfTDe
-:outputId: 4e8a81d0-5012-4c51-d843-3fbdc498df31
-
-jax.grad(f)(1.0);
-```
-
-+++ {"id": "w8t-SDZ3gRzE"}
-
-This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`.
-
-+++ {"id": "dF7hoWGQUneJ"}
-
-## Example: `pure_callback` with `custom_jvp`
-
-One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).
-Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.
-
-Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.
-We can start by defining a straightforward `pure_callback`:
-
-```{code-cell}
-:id: Ge4fNPZdVSJY
-
-import jax
-import jax.numpy as jnp
-import scipy.special
-
-def jv(v, z):
- v, z = jnp.asarray(v), jnp.asarray(z)
-
- # Require the order v to be integer type: this simplifies
- # the JVP rule below.
- assert jnp.issubdtype(v.dtype, jnp.integer)
-
- # Promote the input to inexact (float/complex).
- # Note that jnp.result_type() accounts for the enable_x64 flag.
- z = z.astype(jnp.result_type(float, z.dtype))
-
- # Wrap scipy function to return the expected dtype.
- _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)
-
- # Define the expected shape & dtype of output.
- result_shape_dtype = jax.ShapeDtypeStruct(
- shape=jnp.broadcast_shapes(v.shape, z.shape),
- dtype=z.dtype)
-
- # We use vectorized=True because scipy.special.jv handles broadcasted inputs.
- return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)
-```
-
-+++ {"id": "vyjQj-0QVuoN"}
-
-This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:
-
-```{code-cell}
-:id: f4e46670f4e4
-
-from functools import partial
-j1 = partial(jv, 1)
-z = jnp.arange(5.0)
-```
-
-```{code-cell}
-:id: 6svImqFHWBwj
-:outputId: bc8c778a-6c10-443b-9be2-c0f28e2ac1a9
-
-print(j1(z))
-```
-
-+++ {"id": "d48eb4f2d48e"}
-
-Here is the same result with `jit`:
-
-```{code-cell}
-:id: txvRqR9DWGdC
-:outputId: d25f3476-23b1-48e4-dda1-3c06d32c3b87
-
-print(jax.jit(j1)(z))
-```
-
-+++ {"id": "d861a472d861"}
-
-And here is the same result again with `vmap`:
-
-```{code-cell}
-:id: BS-Ve5u_WU0C
-:outputId: 08cecd1f-6953-4853-e9db-25a03eb5b000
-
-print(jax.vmap(j1)(z))
-```
-
-+++ {"id": "SCH2ii_dWXP6"}
-
-However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:
-
-```{code-cell}
-:id: q3qh_4DrWxdQ
-:outputId: c46b0bfa-96f3-4629-b9af-a4d4f3ccb870
-:tags: [raises-exception]
-
-jax.grad(j1)(z)
-```
-
-+++ {"id": "PtYeJ_xUW09v"}
-
-Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:
-
-$$
-d J_\nu(z) = \left\{
-\begin{eqnarray}
--J_1(z),\ &\nu=0\\
-[J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0
-\end{eqnarray}\right.
-$$
-
-The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.
-
-We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:
-
-```{code-cell}
-:id: BOVQnt05XvLs
-
-jv = jax.custom_jvp(jv)
-
-@jv.defjvp
-def _jv_jvp(primals, tangents):
- v, z = primals
- _, z_dot = tangents # Note: v_dot is always 0 because v is integer.
- jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)
- djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))
- return jv(v, z), z_dot * djv_dz
-```
-
-+++ {"id": "W1SxcvQSX44c"}
-
-Now computing the gradient of our function will work correctly:
-
-```{code-cell}
-:id: sCGceBs-X8nL
-:outputId: 71c5589f-f996-44a0-f09a-ca8bb40c167a
-
-j1 = partial(jv, 1)
-print(jax.grad(j1)(2.0))
-```
-
-+++ {"id": "gWQ4phN5YB26"}
-
-Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:
-
-```{code-cell}
-:id: QTe5mRAvYQBh
-:outputId: d58ecff3-9419-422a-fd0e-14a7d9cf2cc3
-
-jax.hessian(j1)(2.0)
-```
-
-+++ {"id": "QEXGxU4uYZii"}
-
-Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.
-When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.
-However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities.
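-
-If this transfer cost matters for your workload, a rough way to gauge it (illustrative only; the numbers will vary by backend and problem size) is to time the jitted, callback-based function directly, reusing the `j1` and `z` defined above:
-
-```{code-cell}
-import time
-
-jit_j1 = jax.jit(j1)
-jit_j1(z).block_until_ready()  # compile and warm up first
-
-start = time.perf_counter()
-for _ in range(100):
-  jit_j1(z).block_until_ready()
-print("mean seconds per call:", (time.perf_counter() - start) / 100)
-```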
diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb
index 95c00bf1e689..c31a99746866 100644
--- a/docs/notebooks/neural_network_with_tfds_data.ipynb
+++ b/docs/notebooks/neural_network_with_tfds_data.ipynb
@@ -36,15 +36,15 @@
"id": "B_XlLLpcWjkA"
},
"source": [
- "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n",
+ "# Training a simple neural network, with tensorflow/datasets data loading\n",
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n",
"\n",
"_Forked from_ `neural_network_and_data_loading.ipynb`\n",
"\n",
- "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n",
+ "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n",
"\n",
"Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"\n",
@@ -132,7 +132,7 @@
" for w, b in params[:-1]:\n",
" outputs = jnp.dot(w, activations) + b\n",
" activations = relu(outputs)\n",
- " \n",
+ "\n",
" final_w, final_b = params[-1]\n",
" logits = jnp.dot(final_w, activations) + final_b\n",
" return logits - logsumexp(logits)"
@@ -251,7 +251,7 @@
"def one_hot(x, k, dtype=jnp.float32):\n",
" \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n",
" return jnp.array(x[:, None] == jnp.arange(k), dtype)\n",
- " \n",
+ "\n",
"def accuracy(params, images, targets):\n",
" target_class = jnp.argmax(targets, axis=1)\n",
" predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n",
@@ -274,7 +274,7 @@
"id": "umJJGZCC2oKl"
},
"source": [
- "## Data Loading with `tensorflow/datasets`\n",
+ "## Data loading with `tensorflow/datasets`\n",
"\n",
"JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader."
]
@@ -344,7 +344,7 @@
"id": "xxPd6Qw3Z98v"
},
"source": [
- "## Training Loop"
+ "## Training loop"
]
},
{
diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md
index 8f795484d5b9..53b7d47358c2 100644
--- a/docs/notebooks/neural_network_with_tfds_data.md
+++ b/docs/notebooks/neural_network_with_tfds_data.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -34,15 +34,15 @@ limitations under the License.
+++ {"id": "B_XlLLpcWjkA"}
-# Training a Simple Neural Network, with tensorflow/datasets Data Loading
+# Training a simple neural network, with tensorflow/datasets data loading
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
_Forked from_ `neural_network_and_data_loading.ipynb`
-![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)
+![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
@@ -104,7 +104,7 @@ def predict(params, image):
for w, b in params[:-1]:
outputs = jnp.dot(w, activations) + b
activations = relu(outputs)
-
+
final_w, final_b = params[-1]
logits = jnp.dot(final_w, activations) + final_b
return logits - logsumexp(logits)
@@ -164,7 +164,7 @@ At this point, we have all the ingredients we need to define our neural network
def one_hot(x, k, dtype=jnp.float32):
"""Create a one-hot encoding of x of size k."""
return jnp.array(x[:, None] == jnp.arange(k), dtype)
-
+
def accuracy(params, images, targets):
target_class = jnp.argmax(targets, axis=1)
predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
@@ -183,7 +183,7 @@ def update(params, x, y):
+++ {"id": "umJJGZCC2oKl"}
-## Data Loading with `tensorflow/datasets`
+## Data loading with `tensorflow/datasets`
JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader.
@@ -229,7 +229,7 @@ print('Test:', test_images.shape, test_labels.shape)
+++ {"id": "xxPd6Qw3Z98v"}
-## Training Loop
+## Training loop
```{code-cell} ipython3
:id: X2DnZo3iYj18
diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb
index 919690d230ab..1315783c340c 100644
--- a/docs/notebooks/shard_map.ipynb
+++ b/docs/notebooks/shard_map.ipynb
@@ -5,10 +5,12 @@
"id": "41a7e222",
"metadata": {},
"source": [
- "# SPMD multi-device parallelism with `shard_map`\n",
+ "# Manual parallelism with `shard_map`\n",
"\n",
"\n",
"\n",
+ "## Overview\n",
+ "\n",
"`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n",
"\n",
"`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n",
@@ -36,7 +38,7 @@
"id": "97c57a94",
"metadata": {},
"source": [
- "## So, let's see a `shard_map`!\n",
+ "### So, let's see a `shard_map`!\n",
"\n",
"Without further ado, here's a toy example:"
]
@@ -189,9 +191,9 @@
"id": "532fe5f6",
"metadata": {},
"source": [
- "## Slow down, start with the basics!\n",
+ "### Slow down, start with the basics!\n",
"\n",
- "### Rank-reducing vs rank-preserving maps\n",
+ "#### Rank-reducing vs rank-preserving maps\n",
"\n",
"We can think of `vmap` and `pmap` as unstacking each array input along an axis\n",
"(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to\n",
@@ -274,7 +276,7 @@
"over 4 devices) then semantically we get 4 logical applications of the\n",
"function, corresponding to the 4 devices physically computing them.\n",
"\n",
- "### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n",
+ "#### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n",
"\n",
"Each of the `in_specs` identifies some of the corresponding input array's axes\n",
"with mesh axes by name using `PartitionSpec`s, representing how to split (or\n",
@@ -354,7 +356,7 @@
"Physical data movement is possible on inputs, as each device needs to have a\n",
"copy of the appropriate data.\n",
"\n",
- "### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n",
+ "#### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n",
"\n",
"Analogously to the input side, each of the `out_specs` identifies some of the\n",
"corresponding output array's axes with mesh axes by name, representing how the\n",
@@ -482,7 +484,7 @@
"`Array`s, or physically how to interpret the buffers across devices as the\n",
"physical layout of a single logical `Array`.\n",
"\n",
- "# API Specification\n",
+ "## API Specification\n",
"\n",
"```python\n",
"from jax.sharding import Mesh\n",
@@ -508,7 +510,7 @@
"the corresponding `PartitionSpec` `spec` as roughly\n",
"`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n",
"\n",
- "# Collectives tutorial\n",
+ "## Collectives tutorial\n",
"\n",
"A `shard_map` need not be a pure map: function applications can communicate\n",
"with each other via _collectives_, using axis names defined in the `mesh`\n",
@@ -529,7 +531,7 @@
"\n",
"```python\n",
"def f_shmapped_ref(x):\n",
- " x_blocks = jnp.array_split(x, mesh.shape[0])\n",
+ " x_blocks = jnp.array_split(x, mesh.shape['i'])\n",
" y_blocks = [f(x_blk) for x_blk in x_blocks]\n",
" return jnp.concatenate(y_blocks)\n",
"```\n",
@@ -572,7 +574,7 @@
"means communication across devices. Exactly what communication happens, and\n",
"what values are computed, depend on the collective.\n",
"\n",
- "## `psum`\n",
+ "### `psum`\n",
"\n",
"The simplest collective may be `jax.lax.psum`, which computes an\n",
"all-reduce-sum along a device mesh axis (or multiple axes).\n",
@@ -714,7 +716,7 @@
"In the sequel, we'll see how `psum` can be implemented in terms of other\n",
"primitives, which gives some intuition about its communication cost.\n",
"\n",
- "## `all_gather`\n",
+ "### `all_gather`\n",
"\n",
"Another fundamental operation is gathering array shards along an axis, so that\n",
"each function application has a full copy of the data along that axis:\n",
@@ -796,7 +798,7 @@
"In deep learning, we might use `all_gather`s on parameters in fully sharded\n",
"data parallelism (FSDP).\n",
"\n",
- "## `psum_scatter`\n",
+ "### `psum_scatter`\n",
"\n",
"The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n",
"`psum` except each function instance gets only one shard of the result:\n",
@@ -871,7 +873,7 @@
"multiplies or fully-sharded data parallel gradient accumulation, as shown in\n",
"the examples to follow.\n",
"\n",
- "## `ppermute`\n",
+ "### `ppermute`\n",
"\n",
"The `jax.lax.ppermute` collective provides the most direct way for\n",
"function instances to send data to one another. Given a mesh axis and a\n",
@@ -998,7 +1000,7 @@
"spatial axes and thus devices must communicate \"halos\" to each other. Or it\n",
"may be used under-the-hood in tensor-parallel matrix multiplies.\n",
"\n",
- "## `all_to_all`\n",
+ "### `all_to_all`\n",
"\n",
"A final collective is `all_to_all`, which is essentially a block matrix\n",
"transpose operating along one positional axis and one cross-device axis:\n",
@@ -1059,12 +1061,12 @@
"where we first sort our local batch of examples according to which expert they\n",
"should go to, then apply an `all_to_all` to redistribute examples to experts.\n",
"\n",
- "# Toy examples\n",
+ "## Toy examples\n",
"\n",
"How might we use `shard_map` and collective communication in practice? These\n",
"examples, while simple, give some idea.\n",
"\n",
- "## Matrix multiplies\n",
+ "### Matrix multiplies\n",
"\n",
"Parallelizing matrix multiplication is central in scaling up deep learning\n",
"models, both for training and for inference. When `jax.jit` automatically\n",
@@ -1107,7 +1109,7 @@
"id": "2e2b33b9",
"metadata": {},
"source": [
- "### Example 1: `all-gather` on one side\n",
+ "#### Example 1: `all-gather` on one side\n",
"\n",
"Consider performing a matrix multiplication where we shard the left-hand side\n",
"argument (can think: parameters) on its leading (non-contracting) dimension:"
@@ -1301,7 +1303,7 @@
"`jax.lax.fori_loop`. We might also have additional axes of parallelism\n",
"involved.\n",
"\n",
- "### Example 2: `psum_scatter` the result\n",
+ "#### Example 2: `psum_scatter` the result\n",
"\n",
"Another sharding we might start with has both `lhs` and `rhs` sharded along\n",
"their contracting dimensions, with the output sharded like `rhs` again:"
@@ -1446,7 +1448,7 @@
"id": "60c2d2bc",
"metadata": {},
"source": [
- "## Neural networks\n",
+ "### Neural networks\n",
"\n",
"We can use `shard_map` to parallelize computation in neural networks, either by\n",
"itself or in combination with the automatic partitioning in `jax.jit`. This\n",
@@ -1483,20 +1485,20 @@
"outputs": [],
"source": [
"def init_layer(key, n_in, n_out):\n",
- " k1, k2 = jax.random.split(key)\n",
- " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n",
- " b = jax.random.normal(k2, (n_out,))\n",
- " return W, b\n",
+ " k1, k2 = jax.random.split(key)\n",
+ " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n",
+ " b = jax.random.normal(k2, (n_out,))\n",
+ " return W, b\n",
"\n",
"def init(key, layer_sizes, batch_size):\n",
- " key, *keys = jax.random.split(key, len(layer_sizes))\n",
- " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
+ " key, *keys = jax.random.split(key, len(layer_sizes))\n",
+ " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n",
"\n",
- " key, *keys = jax.random.split(key, 3)\n",
- " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n",
- " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n",
+ " key, *keys = jax.random.split(key, 3)\n",
+ " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n",
+ " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n",
"\n",
- " return params, (inputs, targets)"
+ " return params, (inputs, targets)"
]
},
{
@@ -1509,7 +1511,7 @@
"layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n",
"batch_size = 32\n",
"\n",
- "params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)"
+ "params, batch = init(jax.random.key(0), layer_sizes, batch_size)"
]
},
{
@@ -1524,7 +1526,7 @@
"functions to use different parallelization strategies, with `shard_map` we\n",
"often do.\n",
"\n",
- "### 8-way batch data parallelism\n",
+ "#### 8-way batch data parallelism\n",
"\n",
"The simplest multi-device parallelism strategy is to shard the batch of inputs\n",
"and targets over multiple devices, replicate the parameters over those devices,\n",
@@ -1608,7 +1610,7 @@
"end of the forward pass to compute the loss value, and in the backward pass to\n",
"compute the total parameter gradients.\n",
"\n",
- "### 8-way fully sharded data parallelism (FSDP)\n",
+ "#### 8-way fully sharded data parallelism (FSDP)\n",
"\n",
"Another strategy is to additionally shard the parameters over the devices,\n",
"all-gathering each one when the full value is needed for the `jnp.dot` or bias\n",
@@ -1697,7 +1699,7 @@
"id": "f88ddefe",
"metadata": {},
"source": [
- "### 8-way tensor parallelism (TP)\n",
+ "#### 8-way tensor parallelism (TP)\n",
"\n",
"Usually we don't use tensor model parallelism by itself, but seeing it in\n",
"isolation is a good warmup on parallel matrix multiplication. It's also a good\n",
@@ -1750,7 +1752,7 @@
"id": "cf59d537",
"metadata": {},
"source": [
- "### FSDP + TP, with `shard_map` at the top level\n",
+ "#### FSDP + TP, with `shard_map` at the top level\n",
"\n",
"We can compose these strategies together, using multiple axes of parallelism."
]
@@ -1821,7 +1823,7 @@
"id": "94a352ca",
"metadata": {},
"source": [
- "### SPMD pipeline parallelism (PP)\n",
+ "#### SPMD pipeline parallelism (PP)\n",
"\n",
"With pipeline parallelism we aim to parallelize the evaluation of layers at\n",
"different depths in our network. For example, one device might compute the\n",
diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md
index 6f9dfbb659e1..96667e709ac6 100644
--- a/docs/notebooks/shard_map.md
+++ b/docs/notebooks/shard_map.md
@@ -7,17 +7,19 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
name: python3
---
-# SPMD multi-device parallelism with `shard_map`
+# Manual parallelism with `shard_map`
+## Overview
+
`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.
`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.
@@ -33,7 +35,7 @@ import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices
```
-## So, let's see a `shard_map`!
+### So, let's see a `shard_map`!
Without further ado, here's a toy example:
@@ -120,9 +122,9 @@ print('b blocks:'); jax.debug.visualize_array_sharding(b)
print('c blocks:'); jax.debug.visualize_array_sharding(c)
```
-## Slow down, start with the basics!
+### Slow down, start with the basics!
-### Rank-reducing vs rank-preserving maps
+#### Rank-reducing vs rank-preserving maps
We can think of `vmap` and `pmap` as unstacking each array input along an axis
(e.g. unpacking a 2D matrix into its 1D rows), applying its body function to
@@ -181,7 +183,7 @@ by any input axis size: for example, if we have a mesh of total size 4 (i.e.
over 4 devices) then semantically we get 4 logical applications of the
function, corresponding to the 4 devices physically computing them.
-### Controlling how each input is split (unconcatenated) and tiled with `in_specs`
+#### Controlling how each input is split (unconcatenated) and tiled with `in_specs`
Each of the `in_specs` identifies some of the corresponding input array's axes
with mesh axes by name using `PartitionSpec`s, representing how to split (or
@@ -237,7 +239,7 @@ along the first axis, and used the pspec `P(('j', 'i'), None)`.
Physical data movement is possible on inputs, as each device needs to have a
copy of the appropriate data.
-### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`
+#### Controlling how each output is assembled by concatenation, block transposition, and untiling using `out_specs`
Analogously to the input side, each of the `out_specs` identifies some of the
corresponding output array's axes with mesh axes by name, representing how the
@@ -329,7 +331,7 @@ Instead, `out_specs` just encodes how to assemble the block outputs into
`Array`s, or physically how to interpret the buffers across devices as the
physical layout of a single logical `Array`.
-# API Specification
+## API Specification
```python
from jax.sharding import Mesh
@@ -355,7 +357,7 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and
the corresponding `PartitionSpec` `spec` as roughly
`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.
-# Collectives tutorial
+## Collectives tutorial
A `shard_map` need not be a pure map: function applications can communicate
with each other via _collectives_, using axis names defined in the `mesh`
@@ -376,7 +378,7 @@ values, as this reference function:
```python
def f_shmapped_ref(x):
- x_blocks = jnp.array_split(x, mesh.shape[0])
+ x_blocks = jnp.array_split(x, mesh.shape['i'])
y_blocks = [f(x_blk) for x_blk in x_blocks]
return jnp.concatenate(y_blocks)
```
@@ -419,7 +421,7 @@ collective introduces some amount of cross-block dependence. Physically, that
means communication across devices. Exactly what communication happens, and
what values are computed, depend on the collective.
-## `psum`
+### `psum`
The simplest collective may be `jax.lax.psum`, which computes an
all-reduce-sum along a device mesh axis (or multiple axes).
@@ -513,7 +515,7 @@ have a `grad` inside the `shard_map`ped function body, total gradients.
In the sequel, we'll see how `psum` can be implemented in terms of other
primitives, which gives some intuition about its communication cost.
-## `all_gather`
+### `all_gather`
Another fundamental operation is gathering array shards along an axis, so that
each function application has a full copy of the data along that axis:
@@ -571,7 +573,7 @@ def all_gather_ref(_, x_blocks, *, tiled=False):
In deep learning, we might use `all_gather`s on parameters in fully sharded
data parallelism (FSDP).
-## `psum_scatter`
+### `psum_scatter`
The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like
`psum` except each function instance gets only one shard of the result:
@@ -634,7 +636,7 @@ machine learning, `psum_scatter` can be used in tensor-parallel matrix
multiplies or fully-sharded data parallel gradient accumulation, as shown in
the examples to follow.
-## `ppermute`
+### `ppermute`
The `jax.lax.ppermute` collective provides the most direct way for
function instances to send data to one another. Given a mesh axis and a
@@ -731,7 +733,7 @@ parallelizing the evaluation of convolutional layers, where we shard over
spatial axes and thus devices must communicate "halos" to each other. Or it
may be used under-the-hood in tensor-parallel matrix multiplies.
-## `all_to_all`
+### `all_to_all`
A final collective is `all_to_all`, which is essentially a block matrix
transpose operating along one positional axis and one cross-device axis:
@@ -780,12 +782,12 @@ In deep learning, we might use `all_to_all` in mixture-of-expert routing,
where we first sort our local batch of examples according to which expert they
should go to, then apply an `all_to_all` to redistribute examples to experts.
-# Toy examples
+## Toy examples
How might we use `shard_map` and collective communication in practice? These
examples, while simple, give some idea.
-## Matrix multiplies
+### Matrix multiplies
Parallelizing matrix multiplication is central in scaling up deep learning
models, both for training and for inference. When `jax.jit` automatically
@@ -810,7 +812,7 @@ def device_put(x, pspec):
return jax.device_put(x, NamedSharding(mesh, pspec))
```
-### Example 1: `all-gather` on one side
+#### Example 1: `all-gather` on one side
Consider performing a matrix multiplication where we shard the left-hand side
argument (can think: parameters) on its leading (non-contracting) dimension:
@@ -926,7 +928,7 @@ In practice, to reduce compile times we would probably roll this into a
`jax.lax.fori_loop`. We might also have additional axes of parallelism
involved.
-### Example 2: `psum_scatter` the result
+#### Example 2: `psum_scatter` the result
Another sharding we might start with has both `lhs` and `rhs` sharded along
their contracting dimensions, with the output sharded like `rhs` again:
@@ -1011,7 +1013,7 @@ out = matmul_psumscatter_overlapped_bidi(lhs, rhs)
print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3))
```
-## Neural networks
+### Neural networks
We can use `shard_map` to parallelize computation in neural networks, either by
itself or in combination with the automatic partitioning in `jax.jit`. This
@@ -1035,27 +1037,27 @@ def loss(params, batch):
```{code-cell}
def init_layer(key, n_in, n_out):
- k1, k2 = jax.random.split(key)
- W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
- b = jax.random.normal(k2, (n_out,))
- return W, b
+ k1, k2 = jax.random.split(key)
+ W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)
+ b = jax.random.normal(k2, (n_out,))
+ return W, b
def init(key, layer_sizes, batch_size):
- key, *keys = jax.random.split(key, len(layer_sizes))
- params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
+ key, *keys = jax.random.split(key, len(layer_sizes))
+ params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))
- key, *keys = jax.random.split(key, 3)
- inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
- targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
+ key, *keys = jax.random.split(key, 3)
+ inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))
+ targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))
- return params, (inputs, targets)
+ return params, (inputs, targets)
```
```{code-cell}
layer_sizes = [784, 128, 128, 128, 128, 128, 8]
batch_size = 32
-params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)
+params, batch = init(jax.random.key(0), layer_sizes, batch_size)
```
Compare these examples with the purely [automatic partitioning examples in the
@@ -1065,7 +1067,7 @@ While in those automatic partitioning examples we don't need to edit the model
functions to use different parallelization strategies, with `shard_map` we
often do.
-### 8-way batch data parallelism
+#### 8-way batch data parallelism
The simplest multi-device parallelism strategy is to shard the batch of inputs
and targets over multiple devices, replicate the parameters over those devices,
@@ -1119,7 +1121,7 @@ that the collective all-reduce-sum operations happen where we'd expect: at the
end of the forward pass to compute the loss value, and in the backward pass to
compute the total parameter gradients.
-### 8-way fully sharded data parallelism (FSDP)
+#### 8-way fully sharded data parallelism (FSDP)
Another strategy is to additionally shard the parameters over the devices,
all-gathering each one when the full value is needed for the `jnp.dot` or bias
@@ -1184,7 +1186,7 @@ print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp))(params, batch)))
```
-### 8-way tensor parallelism (TP)
+#### 8-way tensor parallelism (TP)
Usually we don't use tensor model parallelism by itself, but seeing it in
isolation is a good warmup on parallel matrix multiplication. It's also a good
@@ -1225,7 +1227,7 @@ def loss_tp(params, batch):
return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum!
```
-### FSDP + TP, with `shard_map` at the top level
+#### FSDP + TP, with `shard_map` at the top level
We can compose these strategies together, using multiple axes of parallelism.
@@ -1272,7 +1274,7 @@ print(allclose(jax.jit(jax.grad(loss))(params, batch),
jax.jit(jax.grad(loss_fsdp_tp))(params, batch)))
```
-### SPMD pipeline parallelism (PP)
+#### SPMD pipeline parallelism (PP)
With pipeline parallelism we aim to parallelize the evaluation of layers at
different depths in our network. For example, one device might compute the
diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb
index 1c1c9729b654..b5f8074c0f3e 100644
--- a/docs/notebooks/thinking_in_jax.ipynb
+++ b/docs/notebooks/thinking_in_jax.ipynb
@@ -6,11 +6,11 @@
"id": "LQHmwePqryRU"
},
"source": [
- "# How to Think in JAX\n",
+ "# How to think in JAX\n",
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n",
"\n",
"JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively."
]
@@ -23,7 +23,7 @@
"source": [
"## JAX vs. NumPy\n",
"\n",
- "**Key Concepts:**\n",
+ "**Key concepts:**\n",
"\n",
"- JAX provides a NumPy-inspired interface for convenience.\n",
"- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n",
@@ -282,7 +282,7 @@
"source": [
"## NumPy, lax & XLA: JAX API layering\n",
"\n",
- "**Key Concepts:**\n",
+ "**Key concepts:**\n",
"\n",
"- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n",
"- `jax.lax` is a lower-level API that is stricter and often more powerful.\n",
@@ -475,7 +475,7 @@
"source": [
"## To JIT or not to JIT\n",
"\n",
- "**Key Concepts:**\n",
+ "**Key concepts:**\n",
"\n",
"- By default JAX executes operations one at a time, in sequence.\n",
"- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n",
@@ -675,7 +675,7 @@
"source": [
"## JIT mechanics: tracing and static variables\n",
"\n",
- "**Key Concepts:**\n",
+ "**Key concepts:**\n",
"\n",
"- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n",
"\n",
@@ -932,9 +932,9 @@
"id": "r-RCl_wD5lI7"
},
"source": [
- "## Static vs Traced Operations\n",
+ "## Static vs traced operations\n",
"\n",
- "**Key Concepts:**\n",
+ "**Key concepts:**\n",
"\n",
"- Just as values can be either static or traced, operations can be static or traced.\n",
"\n",
diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md
index 14089fa36e32..b3672b90e653 100644
--- a/docs/notebooks/thinking_in_jax.md
+++ b/docs/notebooks/thinking_in_jax.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
@@ -13,11 +13,11 @@ kernelspec:
+++ {"id": "LQHmwePqryRU"}
-# How to Think in JAX
+# How to think in JAX
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)
JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively.
@@ -25,7 +25,7 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b
## JAX vs. NumPy
-**Key Concepts:**
+**Key concepts:**
- JAX provides a NumPy-inspired interface for convenience.
- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.
@@ -132,7 +132,7 @@ print(y)
## NumPy, lax & XLA: JAX API layering
-**Key Concepts:**
+**Key concepts:**
- `jax.numpy` is a high-level wrapper that provides a familiar interface.
- `jax.lax` is a lower-level API that is stricter and often more powerful.
@@ -215,7 +215,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op
## To JIT or not to JIT
-**Key Concepts:**
+**Key concepts:**
- By default JAX executes operations one at a time, in sequence.
- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.
@@ -308,7 +308,7 @@ This is because the function generates an array whose shape is not known at comp
## JIT mechanics: tracing and static variables
-**Key Concepts:**
+**Key concepts:**
- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.
@@ -417,9 +417,9 @@ Understanding which values and operations will be static and which will be trace
+++ {"id": "r-RCl_wD5lI7"}
-## Static vs Traced Operations
+## Static vs traced operations
-**Key Concepts:**
+**Key concepts:**
- Just as values can be either static or traced, operations can be static or traced.
diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb
index 96b334296667..dccc83168ac0 100644
--- a/docs/notebooks/vmapped_log_probs.ipynb
+++ b/docs/notebooks/vmapped_log_probs.ipynb
@@ -6,11 +6,11 @@
"id": "6umP1IKf4Dg6"
},
"source": [
- "# Autobatching for Bayesian Inference\n",
+ "# Autobatching for Bayesian inference\n",
"\n",
"\n",
"\n",
- "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
+ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n",
"\n",
"This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n",
"\n",
@@ -25,17 +25,10 @@
},
"outputs": [],
"source": [
- "import functools\n",
- "import itertools\n",
- "import re\n",
- "import sys\n",
- "import time\n",
- "\n",
- "from matplotlib.pyplot import *\n",
+ "import matplotlib.pyplot as plt\n",
"\n",
"import jax\n",
"\n",
- "from jax import lax\n",
"import jax.numpy as jnp\n",
"import jax.scipy as jsp\n",
"from jax import random\n",
@@ -348,7 +341,7 @@
"def elbo(beta_loc, beta_log_scale, epsilon):\n",
" beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n",
" return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n",
- " \n",
+ "\n",
"elbo = jax.jit(elbo)\n",
"elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))"
]
@@ -548,25 +541,16 @@
}
],
"source": [
- "figure(figsize=(7, 7))\n",
- "plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
- "plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n",
- "plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
+ "plt.figure(figsize=(7, 7))\n",
+ "plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n",
+ "plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\\sigma$ Error Bars')\n",
+ "plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n",
"plot_scale = 3\n",
- "plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
- "xlabel('True beta')\n",
- "ylabel('Estimated beta')\n",
- "legend(loc='best')"
+ "plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n",
+ "plt.xlabel('True beta')\n",
+ "plt.ylabel('Estimated beta')\n",
+ "plt.legend(loc='best')"
]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "id": "_bXdOlvUEJl0"
- },
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md
index ea8b4fce2f70..3f836e680e88 100644
--- a/docs/notebooks/vmapped_log_probs.md
+++ b/docs/notebooks/vmapped_log_probs.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -14,11 +14,11 @@ kernelspec:
+++ {"id": "6umP1IKf4Dg6"}
-# Autobatching for Bayesian Inference
+# Autobatching for Bayesian inference
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)
This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.
@@ -27,17 +27,10 @@ Inspired by a notebook by @davmre.
```{code-cell} ipython3
:id: 8RZDkfbV3zdR
-import functools
-import itertools
-import re
-import sys
-import time
-
-from matplotlib.pyplot import *
+import matplotlib.pyplot as plt
import jax
-from jax import lax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import random
@@ -192,7 +185,7 @@ batched_log_joint = jax.jit(jax.vmap(log_joint))
def elbo(beta_loc, beta_log_scale, epsilon):
beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon
return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))
-
+
elbo = jax.jit(elbo)
elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))
```
@@ -240,19 +233,13 @@ Coverage isn't quite as good as we might like, but it's not bad, and nobody said
:id: zt1NBLoVHtOG
:outputId: fb159795-e6e7-497c-e501-9933ec761af4
-figure(figsize=(7, 7))
-plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
-plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars')
-plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
+plt.figure(figsize=(7, 7))
+plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')
+plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars')
+plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')
plot_scale = 3
-plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
-xlabel('True beta')
-ylabel('Estimated beta')
-legend(loc='best')
-```
-
-```{code-cell} ipython3
-:id: _bXdOlvUEJl0
-
-
+plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')
+plt.xlabel('True beta')
+plt.ylabel('Estimated beta')
+plt.legend(loc='best')
```
diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md
index e43a178db50e..43ba3ebd6afb 100644
--- a/docs/pallas/CHANGELOG.md
+++ b/docs/pallas/CHANGELOG.md
@@ -11,15 +11,31 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c
Remember to align the itemized text with the first line of an item within a list.
-->
-## Released with jax 0.4.32
+## Released with jax 0.4.34
* Changes
- * The kernel function is not allowed to close over constants. Instead, all the needed arrays
- must be passed as inputs, with proper block specs ({jax-issue}`#22746`).
+
+ * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments
+ to be scalars. The restrictions on the arguments are backend-specific:
+ Non-scalar arguments are currently only supported on GPU, when using Triton.
* Deprecations
-* New functionality:
+* New functionality
+
+ * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`,
+ a PyTree specifying backend-specific temporary objects needed by the
+    kernel, for example buffers, synchronization primitives, etc.
+
+## Released with jax 0.4.33 (September 16, 2024)
+
+## Released with jax 0.4.32 (September 11, 2024)
+
+* Changes
+ * The kernel function is not allowed to close over constants. Instead, all the needed arrays
+ must be passed as inputs, with proper block specs ({jax-issue}`#22746`).
+
+* New functionality
* Improved error messages for mistakes in the signature of the index map functions,
to include the name and source location of the index map.
@@ -36,18 +52,14 @@ Remember to align the itemized text with the first line of an item within a list
* The method `compute_index` of {class}`jax.experimental.pallas.GridSpec` has
been removed because it is private. Similarly, the `get_grid_mapping` and
`unzip_dynamic_bounds` have been removed from `BlockSpec` ({jax-issue}`#22593`).
- * Fixed the interpreter mode to work with BlockSpec that involve padding
+ * Fixed the interpret mode to work with BlockSpec that involve padding
({jax-issue}`#22275`).
- Padding in interpreter mode will be with NaN, to help debug out-of-bounds
+ Padding in interpret mode will be with NaN, to help debug out-of-bounds
errors, but this behavior is not present when running in custom kernel mode,
and should not be depended on.
* Previously it was possible to import many APIs that are meant to be
private, as `jax.experimental.pallas.pallas`. This is not possible anymore.
-
-* Deprecations
-
-
* New Functionality
* Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`.
* Improved error messages for the {func}`jax.experimental.pallas.pallas_call`
@@ -73,7 +85,3 @@ Remember to align the itemized text with the first line of an item within a list
* Added checkify support for {func}`jax.experimental.pallas.pallas_call` in
interpret mode ({jax-issue}`#21862`).
* Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`).
-
-
-
-
diff --git a/docs/pallas/async_note.md b/docs/pallas/async_note.md
new file mode 100644
index 000000000000..42e32a074fd7
--- /dev/null
+++ b/docs/pallas/async_note.md
@@ -0,0 +1,675 @@
+# Pallas Async Operations
+
+## Background \+ Motivation
+
+We’d like to expose APIs in Pallas to explicitly overlap computation and communication *across multiple kernels*.
+
+### XLA Async Decomposition
+
+As motivation, consider the following JAX pseudocode:
+
+```py
+def f(x):
+ y = ppermute(x)
+ z = x + 1
+ return y, z
+```
+
+In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by:
+
+1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future.
+2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`,
+
+resulting in the following program:
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1 # happens at the same time as ppermute
+ y = ppermute_done(fut)
+ return y, z
+```
+
+### Async ops inside kernels
+
+Now imagine we aren’t using XLA’s `ppermute` but have our own custom Pallas `ppermute`.
+
+```py
+def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem):
+ right_neighbor = ...
+ descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
+ descriptor.start()
+ descriptor.wait_send()
+ descriptor.wait_recv()
+
+def ppermute(x):
+ return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x)
+```
+
+Currently, we cannot decompose `ppermute` into a `start/done` pair as XLA does, so instead we explicitly **fuse** the `x + 1` into the kernel.
+
+```py
+def add_one(x_ref, z_ref):
+ z_ref[...] = x_ref[...] + 1
+
+def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem):
+ right_neighbor = ...
+ descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor)
+ descriptor.start()
+
+ # Explicitly schedule inner kernel between start/wait
+ pltpu.emit_pipeline(add_one)(x_ref, z_ref)
+
+ descriptor.wait_send()
+ descriptor.wait_recv()
+
+def ppermute_and_add_one(x):
+ return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x)
+
+```
+
+The goal is to enable writing separate kernels for starting the `ppermute` and waiting on it to complete, so that we can use a regular old `x + 1` in between (or whatever compute we want). This makes the code more readable, maintainable, and less bug-prone.
+
+## How do we implement decomposed Pallas async operations (on TPU)?
+
+The main thing to figure out when implementing decomposed async operations in Pallas is what the `future` that is passed between them contains. Specifically, it must contain some important state about the operation happening in the background.
+
+If we look at the Pallas code, we can see that we need a “descriptor” to both start and wait on a remote copy. Can we plumb this descriptor out of the Pallas kernel, and then pass it into another one? Well, kind of. The underlying TPU hardware tracks async op progress via a pair of semaphores: `send_sem` lets us wait until a device is done sending data to its neighbor, and `recv_sem` tracks the data a device receives from its neighbor. If we imagine writing a start kernel and a done kernel, all we’d need to pass from the start to the done would be the semaphores and some information about how much to wait on those semaphores.
+
+We can do this via extending Pallas to support returning semaphores from kernels.
+
+```py
+def ppermute_start_kernel(
+ in_ref, send_sem, recv_sem, out_ref, *, axis_name,
+):
+ axis_size = jax.lax.psum(1, axis_name)
+ left_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
+ )
+ right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
+ pltpu.semaphore_wait(barrier_sem, 1)
+ pltpu.make_async_remote_copy(
+ in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
+ ).start()
+
+def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]:
+ send_sem, recv_sem, out = pl.pallas_call(
+ functools.partial(ppermute_start_kernel, axis_name=axis_name),
+ out_shape=(
+ pltpu.SemaphoreType.DMA(()),
+ pltpu.SemaphoreType.DMA(()),
+ jax.ShapeDtypeStruct(
+ x.shape,
+ dtype=x.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ),
+ )(x)
+ return send_sem, recv_sem, out
+```
+
+Note that something subtle is happening here. Pallas is telling XLA that it would like some outputs to be semaphores (a.k.a. sync flags) and XLA will treat them as “reserved” (e.g. while they are alive in the XLA program, those sync flags cannot be allocated by other kernels). They behave similarly to barrier semaphores, which are reserved semaphores managed by XLA.
+
+Another thing to notice is that we return the output buffer `out` from the start kernel *while it’s being actively copied into*.
+
+Now we write the `done` kernel that performs the blocking operation. We pass `out` into the kernel to compute the shape needed to block on the semaphore.
+
+```py
+def ppermute_done_kernel(ref, send_sem, recv_sem, _):
+ pltpu.make_async_copy(ref, ref, send_sem).wait()
+ pltpu.make_async_copy(ref, ref, recv_sem).wait()
+
+def ppermute_done(send_sem, recv_sem, out) -> Array:
+ out = pl.pallas_call(
+ ppermute_done_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(
+ out.shape,
+ dtype=out.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={0:0}
+ )(out, send_sem, recv_sem)
+ return out
+```
+
+Note: we i/o alias the output buffer here to guarantee that the consumers are downstream of the `ppermute_done`.
+
+We now can implement the decomposed collective permute.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1 # happens at the same time as ppermute
+ y = ppermute_done(fut)
+ return y, z
+```
+
+***OR CAN WE?***
+
+## Why *doesn’t* this work?
+
+There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level.
+
+1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling: when we write JAX programs, we are setting up data dependencies that XLA will respect, but XLA will not respect the specific order in which the operations are written in JAX.
+2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory.
+3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness.
+
+We will go over these issues one by one and suggest fixes.
+
+### Scheduling
+
+How do we explicitly force ops to happen in a particular order in JAX? Note that this is not a Pallas specific problem, and if we had async ops implemented using an alternative method, we’d still run into this.
+
+One way is to introduce an optimization barrier into the XLA program. The optimization barrier will prevent XLA from moving ops around it.
+
+Here’s our original code:
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+```
+
+XLA could choose to execute `x + 1` in any of three places:
+
+```py
+def f(x):
+ z = x + 1
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ return y, z
+
+# OR
+
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+
+# OR
+
+def f(x):
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ z = x + 1
+ return y, z
+```
+
+To force the `x + 1` to happen between the `ppermute` ops, we can use `optimization_barrier`, which is semantically the identity function (i.e. `lambda x: x`) but introduces an explicit data dependency between values. Specifically, if we make the `x` that is used in `x + 1` dependent on the `fut` returned by `ppermute_start`, it must happen after `ppermute_start`.
+
+We also introduce a dependency that forces the output value `y` to depend on `z`.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ x, fut = optimization_barrier((x, fut)) # x now depends on fut
+ z = x + 1
+ z, fut = optimization_barrier((z, fut)) # fut now depends on z
+ y = ppermute_done(fut)
+ return y, z
+```
+
+`optimization_barrier` is a good enough hammer for us to explicitly write out schedules.
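+
+As an aside, a barrier like this is already exposed in stock JAX as `jax.lax.optimization_barrier` (assuming a recent JAX version). Here is a minimal sketch, using plain JAX values with no Pallas involved, of tying two otherwise independent values together so XLA cannot schedule one far away from the other:
+
+```py
+import jax
+import jax.numpy as jnp
+
+def g(x):
+  a = x * 2.0
+  b = x + 1.0
+  # The barrier is semantically the identity, but afterwards `a` and `b`
+  # depend on each other as far as XLA's scheduler is concerned.
+  a, b = jax.lax.optimization_barrier((a, b))
+  return a, b
+
+print(jax.jit(g)(jnp.arange(4.0)))
+```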
+
+### Lifetimes
+
+Let’s look at our original code again and assume the ops are happening in the correct order.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+```
+
+Let’s look at which point in the program XLA believes it is okay to free the buffer for `x`. It would be the point after which `x` is no longer used, specifically after `z = x + 1`.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ # XLA can free x here!
+ y = ppermute_done(fut)
+ return y, z
+```
+
+If XLA frees `x` after `z = x + 1` has completed, we run into a very bad problem. The `ppermute` could still be actively copying `x` to the neighbor after `z = x + 1` which means if `x` is freed, the `ppermute` will be reading from garbage memory\!
+
+How do we extend `x`’s lifetime to the `ppermute_done`? Well we can introduce a data dependency\! We need to modify our kernels a little bit to make this happen.
+
+First, we rewrite `ppermute_start` to return `x`, aliasing it through the kernel.
+
+```py
+def ppermute_start_kernel(
+ in_ref, send_sem, recv_sem, out_ref, _, *, axis_name,
+):
+ axis_size = jax.lax.psum(1, axis_name)
+ left_neighbor = jax.lax.rem(
+ jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size
+ )
+ right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size)
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor)
+ pltpu.semaphore_wait(barrier_sem, 1)
+ pltpu.make_async_remote_copy(
+ in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor
+ ).start()
+
+def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]:
+ send_sem, recv_sem, x, out = pl.pallas_call(
+ functools.partial(ppermute_start_kernel, axis_name=axis_name),
+ out_shape=(
+ pltpu.SemaphoreType.DMA(()),
+ pltpu.SemaphoreType.DMA(()),
+ jax.ShapeDtypeStruct(
+ x.shape,
+ dtype=x.dtype,
+ ),
+ jax.ShapeDtypeStruct(
+ x.shape,
+ dtype=x.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ],
+ out_specs=(
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ ),
+ input_output_aliases={0:2}
+ )(x)
+ return send_sem, recv_sem, x, out
+```
+
+We then have `ppermute_done` take in `x` and do nothing with it.
+
+```py
+def ppermute_done_kernel(_x_ref, ref, send_sem, recv_sem, _):
+ pltpu.make_async_copy(ref, ref, send_sem).wait()
+ pltpu.make_async_copy(ref, ref, recv_sem).wait()
+
+def ppermute_done(send_sem, recv_sem, x, out) -> Array:
+ out = pl.pallas_call(
+ ppermute_done_kernel,
+ out_shape=(
+ jax.ShapeDtypeStruct(
+ out.shape,
+ dtype=out.dtype,
+ ),
+ ),
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.ANY),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+ input_output_aliases={1:0}
+ )(x, out, send_sem, recv_sem)
+ return out
+
+```
+
+Now when we write
+
+```py
+def f(x):
+  *sems, x, out = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(*sems, x, out)
+ return y, z
+```
+
+XLA can no longer free `x` because it is an input to `ppermute_done`\! This means that `x`’s lifetime is tied to the `ppermute` and this code is now correct.
+
+### Defensive copies
+
+XLA, in its buffer assignment pass, analyzes which buffers are aliased to each other and inserts copies whenever an operation that aliases one of its inputs is not the final consumer of that input.
+
+#### Background
+
+Here’s a simple example. Let’s say we have an op `add_one_inplace` which takes in an array and adds one, but promises to do it in-place.
+
+The following code would be legal.
+
+```py
+def f():
+ x = jnp.arange(...)
+  y = add_one_inplace(x)
+  return y
+```
+
+However, if `x` had a separate consumer as well, the program may not execute correctly.
+
+```py
+def f():
+ x = jnp.arange(...)
+ y = add_one_inplace(x)
+ return y, x * 2 # another x consumer!
+```
+
+This is because `x * 2` operates on the original `x` but `add_one_inplace` clobbers the value in `x`. `x * 2` needs to make sure to read the original values of `x`, not the ones after we’ve incremented it by 1\. XLA notices this and inserts a `copy` op (which is semantically the identity but the input and output buffers will be different).
+
+```py
+def f(x):
+ x2 = copy(x)
+ y = add_one_inplace(x2)
+ return y, x * 2
+```
+
+This pass in XLA ensures correctness in the presence of ops that perform in-place updates by forcing them to effectively be out-of-place with `copy` ops.
+
+#### Copies with downstream ops
+
+Let’s revisit our example where we add 1 while `ppermute`ing.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done(fut)
+ return y, z
+```
+
+If we unpack the future into its components, we’ll see the aliasing patterns:
+
+```py
+def f(x):
+ *sems, x2, y = ppermute_start(x)
+ z = x + 1
+ y = ppermute_done((*sems, x2, y))
+ return y, z
+```
+
+We know that `x` is left unchanged by `ppermute_start` (that is, `x` is identical to `x2`), but XLA does not. In fact, to XLA it looks just like our `add_one_inplace` example, so it conservatively assumes that `ppermute_start` mutated `x` and that `x2` is the new aliased result. Therefore, when we do `z = x + 1`, we run into a consumer of the original buffer, and XLA introduces a copy\!
+
+```py
+def f(x):
+ x2 = copy(x)
+ *sems, x2, y = ppermute_start(x2)
+ z = x + 1
+ y = ppermute_done((*sems, x2, y))
+ return y, z
+```
+
+This copy is unnecessary because we know that `x2` is unchanged from `x`. In order to remove this copy, we’d need some mechanism to inform XLA we are just forwarding a value. However, in the absence of that we can rewrite our program a bit to explicitly use `x2` instead of `x`.
+
+```py
+def f(x):
+ *sems, x2, y = ppermute_start(x)
+ z = x2 + 1
+ y = ppermute_done((*sems, x2, y))
+ return y, z
+```
+
+Now, XLA doesn’t see a separate consumer of `x`, so no copy is introduced. However, this comes with a major downside: it forces us to unpack the future coming from `ppermute_start`, which couples the lifetime problem to the copying problem.
+
+#### Loop aliasing
+
+Let’s consider a slightly more advanced example. Let’s implement a function that uses a `while_loop` with `ppermute` to send values around a ring.
+
+```py
+def f(x):
+ def body(i, x):
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+One implementation detail of `fori_loop` is that the input and output buffers are automatically aliased to each other. Note that we are setting up some additional aliasing in the `ppermute_start` and `ppermute_done` ops. Let’s run our own “buffer assignment” by coloring each of the values in the program to determine how many unique buffers we need.
+
+First, we’ll unpack the `fut` tuple that has the aliased `x` and `out` buffers.
+
+```py
+def f(x):
+ def body(i, x):
+ *sems, x, y = ppermute_start(x)
+ y = ppermute_done(*sems, x, y)
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+Let’s now color each of the values according to the unique buffer they are assigned. We have the input/output aliasing coming from `fori_loop`, the `x` aliasing coming from `ppermute_start` and the `y` aliasing coming from `ppermute_done`.
+
+```py
+def f(x):
+ def body(i, x):
+ *sems, x, y = ppermute_start(x)
+ y = ppermute_done((*sems, x, y))
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t receive into the same buffer we are sending from. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. In practice, XLA will observe the buffer re-use and defensively insert a copy.
+
+```py
+def f(x):
+ def body(i, x):
+ x = copy(x)
+ *sems, x, y = ppermute_start(x)
+ y = ppermute_done((*sems, x, y))
+ return y
+ return fori_loop(0, 8, body, x)
+```
+
+This copy means `x` and `y` are no longer aliased to each other and the program will be correct. However, do we need this copy? How do we introduce a double buffer to avoid expensive copies each iteration? The answer is unrolling\!
+
+We’ll manually unroll our code.
+
+```py
+def f(x):
+ def body(i, x):
+ *sems, x, x2 = ppermute_start(x)
+ x2 = ppermute_done((*sems, x, x2))
+
+ *sems, x2, y = ppermute_start(x2)
+ y = ppermute_done((*sems, x2, y))
+ return y
+ return fori_loop(0, 4, body, x)
+```
+
+Now if we run the same alias analysis, we’ll find that the buffers no longer alias to each other and that XLA won’t need to insert defensive copies to be correct.
+
+Therefore, the simple solution to removing these copies is to use `fori_loop` with `unroll >= 2`.
+
+```py
+def f(x):
+ def body(i, x):
+ fut = ppermute_start(x)
+ y = ppermute_done(fut)
+ return y
+ return fori_loop(0, 8, body, x, unroll=2)
+```
+
+That’s sufficient to implement this loop without extra copies\!
+
+#### Passing futures across loop boundaries
+
+Let’s now look at an even more advanced example. We’ll implement the same program as before but stagger the loop, where we begin the `ppermute` in a prologue before the loop, and wait on the `ppermute` at the beginning of the loop.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ x = ppermute_done(fut)
+ fut = ppermute_start(x)
+ return fut
+ fut = fori_loop(0, 7, body, fut)
+ return ppermute_done(fut)
+```
+
+In this example, rather than passing a value `x` from one loop iteration to the next, we are passing a future value.
+
+Let’s unpack the future again to see what’s happening.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ *sems, x, out = fut
+ x = ppermute_done((*sems, x, out))
+ (*sems, x, out) = ppermute_start(x)
+ return (*sems, x, out)
+  (*sems, x, out) = fori_loop(0, 7, body, fut)
+ return ppermute_done((*sems, x, out))
+```
+
+So we’re explicitly threading the semaphores, the input buffer, and the target output buffer as a loop carry. What happens if we run alias analysis now? Well, we’ll run into the same aliasing issue as in the previous section where `x` and `out` will be aliased to each other. XLA will introduce a copy.
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ *sems, x, out = fut
+ out = copy(out)
+ x = ppermute_done((*sems, x, out))
+ (*sems, x, out) = ppermute_start(x)
+ return (*sems, x, out)
+  (*sems, x, out) = fori_loop(0, 7, body, fut)
+ return ppermute_done((*sems, x, out))
+```
+
+In this case, XLA inserted a copy on `out`. However, this is a really bad scenario because `out` is being actively copied into\! Even if we insert a copy on `x`, we will also run into issues because then `x`’s lifetime will not extend to the `ppermute_done`. This is very, very bad\! We will not only get copies, but we will also get incorrect results\!
+
+The solution, as we observed before, is to avoid the copies by preventing the buffers from aliasing, via unrolling. So, if we do:
+
+```py
+def f(x):
+ fut = ppermute_start(x)
+ def body(i, fut):
+ x = ppermute_done(fut)
+ fut = ppermute_start(x)
+ return fut
+  fut = fori_loop(0, 7, body, fut, unroll=2)
+ return ppermute_done(fut)
+```
+
+our program should now be correct.
+
+### Putting it all together
+
+So we’ve come up with some rules of thumb:
+
+1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value.
+2. Use `unroll >= 2` when doing `ppermute`s in a loop body.
+
+Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result.
+
+```py
+def f(x):
+ out = jnp.zeros_like(x)
+ fut = (*sems, x, out) = ppermute_start(x)
+ out = out + x
+ def body(i, carry):
+ out, fut = carry
+ x = ppermute_done(fut)
+ fut = (*sems, x, out) = ppermute_start(x)
+ out = out + x
+ return out, fut
+ out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
+ return out, ppermute_done(fut)
+```
+
+Note that in this example, we don’t need `optimization_barrier`s because the loop boundary acts as a scheduling barrier, splitting up the `start`s and `done`s.
+
+That’s it, we are done\! This will be the official API for doing async ops in Pallas. Thank you everyone\! Mission accomplished\!
+
+***OR IS IT?***
+
+## Revenge of the State
+
+While it seems we have worked around copies and incorrectness issues by using some clever tricks, we are still in an awkward position. This API is powerful, but it has many, many footguns and caveats. There are likely many more edge cases we will need to deal with, some of which require deep knowledge of XLA to predict or understand. Should we release an API like this? Or is there an alternative?
+
+Well, the answer may have been in front of us this whole time.
+
+Let’s run through this whole exercise one more time, *except*, let’s write the stateful version. This means each of our custom async ops now operate on `Ref`s instead of values.
+
+```py
+def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]:
+ ...
+
+def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None:
+ ...
+```
+
+Let’s assume we can implement these in Pallas and see what our new programs will look like. Let’s start with a basic collective permute:
+
+```py
+def f(x):
+ x_ref = make_ref(x)
+ y_ref = make_ref(zeros_like(x))
+ fut = ppermute_start_stateful(x_ref, y_ref)
+ ppermute_done_stateful(*fut, x_ref, y_ref)
+ return y_ref[...]
+```
+
+It’s a little bit more verbose than our original value-based version, but it has a few key differences. The first is that we create an “empty” `Ref` to receive the result of the `ppermute`, unlike the value-based version, which creates a value for us. One neat thing is that the lifetime of `x_ref` is clear here: it lives until `ppermute_done_stateful`. We don’t need to “sneak” the `x` value into the op like we did before.
+
+Another difference becomes clearer when we try adding an op between the `start`/`done`.
+
+```py
+def f(x):
+ x_ref = make_ref(x)
+ y_ref = make_ref(zeros_like(x))
+ fut = ppermute_start_stateful(x_ref, y_ref)
+ x_ref[...] += 1
+ ppermute_done_stateful(*fut, x_ref, y_ref)
+ return y_ref[...]
+```
+
+Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved w.r.t. `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO.
+
+The final key difference is evident when we try our loop examples.
+
+```py
+def f(x):
+ x_ref = make_ref(x)
+ y_ref = make_ref(zeros_like(x))
+ def body(i, _):
+ fut = ppermute_start_stateful(x_ref, y_ref)
+ ppermute_done_stateful(*fut, x_ref, y_ref)
+ # Now switch to y_ref -> x_ref
+ fut = ppermute_start_stateful(y_ref, x_ref)
+ ppermute_done_stateful(*fut, y_ref, x_ref)
+ fori_loop(0, 8 // 2, body, None)
+ return x_ref[...]
+```
+
+Because of the requirement that we have a separate buffer ready to receive the `ppermute`, we were forced to write our code in a way that unrolls it\! There is no way to write the copying version in XLA, because that would involve a `ppermute` that sends from a `Ref` into itself, which doesn’t really make sense.
+
+To handle this without the manual unrolling, we’d create a scratch buffer with a leading `2` dimension that acts as the send/recv target across iterations, switching slots each iteration. This is the same pattern we use internally in Pallas kernels when writing manually overlapped kernels.
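+
+As a rough sketch of just the indexing pattern (plain JAX arrays standing in for `Ref`s, and a hypothetical `ring_pass` standing in for the real remote copy), the double buffer can be cycled like this:
+
+```py
+import jax
+import jax.numpy as jnp
+
+def ring_pass(x):
+  # Hypothetical stand-in for the real remote copy.
+  return jnp.roll(x, 1, axis=0)
+
+def f(x, num_steps=8):
+  # scratch[i % 2] holds the current value; scratch[(i + 1) % 2] receives.
+  scratch = jnp.stack([x, jnp.zeros_like(x)])
+  def body(i, scratch):
+    recv = ring_pass(scratch[i % 2])
+    return scratch.at[(i + 1) % 2].set(recv)
+  scratch = jax.lax.fori_loop(0, num_steps, body, scratch)
+  return scratch[num_steps % 2]
+
+print(f(jnp.arange(8.0)))
+```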
+
+The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\!
+
+1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` w.r.t. each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints.
+2. Lifetimes \- `Ref` lifetimes can be scoped via `run_state`, or the `Ref`s can be inputs to stateful ops.
+3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies.
+
+Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO.
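+
+For instance, a minimal sketch of that check (using a trivial stand-in function `toy`; the same recipe applies to a jitted program built from these kernels) is to dump the compiled HLO and count the `copy` ops XLA inserted:
+
+```py
+import jax
+import jax.numpy as jnp
+
+@jax.jit
+def toy(x):
+  return x + 1.0
+
+# Compiled (post-optimization) HLO as text; look for inserted copy ops.
+hlo_text = toy.lower(jnp.ones(8)).compile().as_text()
+print(hlo_text.count(" copy("))
+```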
+
+## Conclusion
+
+We’ve gone over some tricky challenges when it comes to async ops in Pallas and JAX. `Ref`s seem like a promising way of representing these ops that circumvents some of the issues that come up with value semantics. However, a downside is that it puts stateful JAX front and center, which we haven’t done yet outside of Pallas. It’s worth thinking about whether we should educate users about stateful ops, or provide a more dangerous API. We also don’t know whether everything we want to do is expressible via `Ref`s. We should also brainstorm alternatives to state to flesh out the design space. For example, what if XLA offered a first-class futures API that respected lifetimes, and could automatically do things like double-buffer loops with futures in them? That might be a viable alternative, but the tradeoff would be giving more control to the compiler versus explicit control from the user.
diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst
index 403a8ce9c620..5969349c962a 100644
--- a/docs/pallas/index.rst
+++ b/docs/pallas/index.rst
@@ -3,6 +3,9 @@
Pallas: a JAX kernel language
=============================
Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU.
+It aims to provide fine-grained control over the generated code, combined with
+the high-level ergonomics of JAX tracing and the `jax.numpy` API.
+
This section contains tutorials, guides and examples for using Pallas.
See also the :class:`jax.experimental.pallas` module API documentation.
@@ -10,6 +13,10 @@ See also the :class:`jax.experimental.pallas` module API documentation.
Pallas is experimental and is changing frequently.
See the :ref:`pallas-changelog` for the recent changes.
+ You can expect to encounter errors and unimplemented cases, e.g., when
+   lowering high-level JAX concepts that would require emulation,
+ or simply because Pallas is still under development.
+
.. toctree::
:caption: Guides
:maxdepth: 2
@@ -26,6 +33,13 @@ See also the :class:`jax.experimental.pallas` module API documentation.
tpu/index
.. toctree::
+ :caption: Design Notes
+ :maxdepth: 1
+
+ async_note
+
+.. toctree::
+ :caption: Other
:maxdepth: 1
CHANGELOG
diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb
index 5a8608f494c3..0e759a493a61 100644
--- a/docs/pallas/quickstart.ipynb
+++ b/docs/pallas/quickstart.ipynb
@@ -282,6 +282,35 @@
"On TPUs, programs are executed in a combination of parallel and sequential\n",
"(depending on the architecture) so there are slightly different considerations.\n",
"\n",
+ "To call the above kernel on TPU, run:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "796f928c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from jax.experimental.pallas import tpu as pltpu\n",
+ "\n",
+ "def iota(size: int):\n",
+ " return pl.pallas_call(iota_kernel,\n",
+ " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n",
+ " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n",
+ " grid=(size,))()\n",
+ "iota(8)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "68f97b4e",
+ "metadata": {},
+ "source": [
+ "TPUs distinguish between vector and scalar memory spaces and in this case the\n",
+ "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n",
+ "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n",
+ "\n",
"You can read more details at {ref}`pallas_grid`."
]
},
diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md
index 36cc14bf5c34..a8b13ea38eaf 100644
--- a/docs/pallas/quickstart.md
+++ b/docs/pallas/quickstart.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -188,6 +188,23 @@ operations like matrix multiplications really quickly.
On TPUs, programs are executed in a combination of parallel and sequential
(depending on the architecture) so there are slightly different considerations.
+To call the above kernel on TPU, run:
+
+```{code-cell} ipython3
+from jax.experimental.pallas import tpu as pltpu
+
+def iota(size: int):
+ return pl.pallas_call(iota_kernel,
+ out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
+ out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),
+ grid=(size,))()
+iota(8)
+```
+
+TPUs distinguish between vector and scalar memory spaces and in this case the
+output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is
+a scalar. For more details read {ref}`pallas_tpu_pipelining`.
+
You can read more details at {ref}`pallas_grid`.
+++
diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst
index ae9505c4eb8b..4a2d4daa637f 100644
--- a/docs/pallas/tpu/details.rst
+++ b/docs/pallas/tpu/details.rst
@@ -21,7 +21,7 @@ software emulation, and can slow down the computation.
If you see unexpected outputs, please compare them against a kernel run with
``interpret=True`` passed in to ``pallas_call``. If the results diverge,
- please file a `bug report `_.
+ please file a `bug report `_.
What is a TPU?
--------------
@@ -148,10 +148,8 @@ grid axes over cores. This is an opt-in procedure. To allow that,
..
pallas_call(
...,
- compiler_params=dict(
- mosaic=dict(
- dimension_semantics=["parallel", "parallel", "arbitrary"]
- )
+ compiler_params=pltpu.TPUCompilerParams(
+ dimension_semantics=["parallel", "parallel", "arbitrary"]
),
)
diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb
new file mode 100644
index 000000000000..8552e10d8552
--- /dev/null
+++ b/docs/pallas/tpu/distributed.ipynb
@@ -0,0 +1,1743 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zSNjLhGQJMgq"
+ },
+ "source": [
+ "# Distributed Computing in Pallas for TPUs\n",
+ "\n",
+ "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n",
+ "\n",
+ "Some recommended readings beforehand:\n",
+ " - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n",
+ " - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 1978,
+ "status": "ok",
+ "timestamp": 1722904801801,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "PyAGnWc9yI8T",
+ "outputId": "1d8229bd-cab5-495f-93e9-fff2e41db480"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running with 4 TPU v5 lite devices.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import jax\n",
+ "from jax import lax\n",
+ "from jax import numpy as jnp\n",
+ "from jax.experimental import mesh_utils\n",
+ "from jax.experimental import pallas as pl\n",
+ "from jax.experimental import shard_map\n",
+ "from jax.experimental.pallas import tpu as pltpu\n",
+ "\n",
+ "P = jax.sharding.PartitionSpec\n",
+ "\n",
+ "num_devices = jax.local_device_count()\n",
+ "assert num_devices > 1, \"Please run this notebook with more than one device.\"\n",
+ "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n",
+ "print(f\"Running with {num_devices} {jax.devices()[0].device_kind} devices.\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DySMGNByclMi"
+ },
+ "source": [
+ "## TPU Topologies\n",
+ "\n",
+ "TPUs are typically deployed in pods of multiple devices connected via a high-bandwidth interchip interconnect (ICI) for communication within the pod that is much faster than a typical network connection. For example, the specifications sheet for a [TPU v5p](https://cloud.google.com/tpu/docs/v5p) states an ICI bandwidth of 4.8Tb/s per chip (for reference, TPU v5p also has 21Tb/s of *local* HBM bandwidth). The ICI allows us to implement fast and performant distributed kernels that require high-bandwidth communication within a pod, and use the datacenter network for parallelization over less bandwidth-intensive operations, such as data-parallelism over a batch dimension.\n",
+ "\n",
+ "TPUs pods are typically arranged in an ND torus topology. The following graphic gives several examples of configurations of different sizes.\n",
+ "\n",
+ "![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)\n",
+ "\n",
+ "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n",
+ "\n",
+ "![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "1Oc_WD1hChfN"
+ },
+ "source": [
+ "## Remote Direct Memory Access (RDMA) Model\n",
+ "\n",
+ "TPUs communicate via a push-only model known as a remote direct memory access (RDMA). A TPU is allowed to issue copy instruction to push from a local buffer to any buffer on another device within the same pod that executes asynchronously from the main program thread. However, a TPU can only read data that is stored locally. This is in contrast to more traditional multi-core programming where it is possible to both read from and write to values to a shared memory.\n",
+ "\n",
+ "### Async Remote Copy Operation\n",
+ "The `pltpu.make_async_remote_copy` function is used to create a remote DMA descriptor object which parameterizes both a \"send\" operation and a \"receive\" operation. Here's its signature:\n",
+ "\n",
+ "```python\n",
+ " def make_async_remote_copy(\n",
+ " src_ref: Ref,\n",
+ " dst_ref: Ref,\n",
+ " send_sem: Ref[SemaphoreType],\n",
+ " recv_sem: Ref[SemaphoreType],\n",
+ " device_id: int | tuple[int, ...],\n",
+ " device_id_type: DeviceIdType\n",
+ " ) -> AsyncCopyDescriptor:\n",
+ "```\n",
+ "\n",
+ "- `src_ref` is the local `Ref` (in any memory space) containing the data you wish to send to `dst_ref` on another device.\n",
+ "- `dst_ref` is the remote `Ref` (in any memory space) at which data will be copied to on the target device.\n",
+ "- `send_sem` is a DMA semaphore used to block until all data has been sent from `src_ref`.\n",
+ "- `recv_sem` is a DMA semaphore used to block until the expected number of bytes have been received at `dst_ref`. The sender of the DMA will write to the receiver's `recv_sem`.\n",
+ "- `device_id` is the device ID of the target device to send to.\n",
+ "- `device_id_type` specifies the format of `device_id`, which can either be in LOGICAL format (integer device ID), or in MESH format (an ND-tuple index into the logical device mesh). The default mode is MESH.\n",
+ "\n",
+ "`make_async_remote_copy` returns a descriptor object on which you use the `.start()` method to initiate the DMA, and the `.wait_send()` to block on `send_sem` and `.wait_recv()` to block on `recv_sem` (or `.wait()` to block on both). If a device is only expected to send data, it is sufficient to only call `.start()` and `.wait_send()`, and likewise if a device is only receiving it is sufficient to only call `.wait_recv()`. If using a SPMD pattern where all devices execute the DMA, each device will generally call both `.start()` and `.wait()`.\n",
+ "```python\n",
+ "dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)\n",
+ "dma_descriptor.start() # Initiate the DMA (non-blocking).\n",
+ "# ... do other work\n",
+ "dma_descriptor.wait_send() # Block until all data has been sent.\n",
+ "dma_descriptor.wait_recv() # Block until all data has been received.\n",
+ "```\n",
+ "\n",
+ "As an example, let's visualize a DMA where we consider 4 devices (indexed 0, 1, 2, 3). We consider a scheme where device 0 copies to device 1, and device 2 & 3 copy to each other. In practice, we can create such an asymmetric communication pattern by using `@pl.when` to branch on the device ID.\n",
+ "\n",
+ "(1) Each device creates the DMA descriptor. Devices 0, 2, and 3 call `.start()` to initiate the DMA from `src_ref`. Device 1 is skips the `.start()` and does nothing, e.g. by using `pl.when`.\n",
+ "\n",
+ "![rdma_start](../../_static/pallas/distributed/rdma_start.svg)\n",
+ "\n",
+ "(2) As `.start()` is non-blocking, each device is free to do other computation while the DMA is in flight. Devices 0, 2, and 3 call `.wait_send()` to wait on `send_sem` which blocks until all data has been sent.\n",
+ "\n",
+ "![rdma_send](../../_static/pallas/distributed/rdma_send.svg)\n",
+ "\n",
+ "(3) Finally, devices 1, 2, and 3 will call `.wait_recv()` to wait on `recv_sem` until all data has arrived at `dst_ref`.\n",
+ "\n",
+ "![rdma_recv](../../_static/pallas/distributed/rdma_recv.svg)\n",
+ "\n",
+ "The above communication pattern can be written as follows:\n",
+ "```python\n",
+ "def example_kernel(input_ref, output_ref, send_sem, recv_sem):\n",
+ " device_id = lax.axis_index('x')\n",
+ " copy_0_to_1 = pltpu.make_async_remote_copy(\n",
+ " src_ref=input_ref,\n",
+ " dst_ref=output_ref,\n",
+ " send_sem=send_sem,\n",
+ " recv_sem=recv_sem,\n",
+ " device_id=1,\n",
+ " )\n",
+ " copy_2_to_3 = pltpu.make_async_remote_copy(\n",
+ " src_ref=input_ref,\n",
+ " dst_ref=output_ref,\n",
+ " send_sem=send_sem,\n",
+ " recv_sem=recv_sem,\n",
+ " device_id=3,\n",
+ " )\n",
+ " copy_3_to_2 = pltpu.make_async_remote_copy(\n",
+ " src_ref=input_ref,\n",
+ " dst_ref=output_ref,\n",
+ " send_sem=send_sem,\n",
+ " recv_sem=recv_sem,\n",
+ " device_id=2,\n",
+ " )\n",
+ " @pl.when(device_id == 0)\n",
+ " def _():\n",
+ " copy_0_to_1.start()\n",
+ " copy_0_to_1.wait_send()\n",
+ " @pl.when(device_id == 1)\n",
+ " def _():\n",
+ " copy_0_to_1.wait_recv()\n",
+ " @pl.when(device_id == 2)\n",
+ " def _():\n",
+ " copy_2_to_3.start()\n",
+ " copy_2_to_3.wait_send()\n",
+ " copy_3_to_2.wait_recv()\n",
+ " @pl.when(device_id == 3)\n",
+ " def _():\n",
+ " copy_3_to_2.start()\n",
+ " copy_3_to_2.wait_send()\n",
+ " copy_2_to_3.wait_recv()\n",
+ "```\n",
+ "\n",
+ "### DMA Semaphores\n",
+ "\n",
+ "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n",
+ "\n",
+ "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n",
+ "\n",
+ "### Routing\n",
+ "\n",
+ "A sender is allowed to send data to any receiver within the same pod, even if they do not share a direct connection (the exception to this rule is for TPU v5e, where devices can only route to a power of 2 offset from themselves). TPUs have an internal routing mechanism which can pass data along to the next device on the path to the destination. However, communicating in this way is not recommended as you have no control over network contention as a kernel writer. The examples we will cover in this tutorial minimize inefficient communication by only transferring data to neighboring devices.\n",
+ "\n",
+ "### Failure modes\n",
+ "\n",
+ "If using remote DMAs incorrectly, you may encounter several failure modes which can be difficult to debug. The general symptoms of buggy DMA usage are crashes, hanging, or silent data corruption:\n",
+ "- If semaphores exit the program with an invalid non-zero value, Pallas will crash and exit the program.\n",
+ "- If semaphores are waited on but an insufficient number of bytes are received (i.e. there is no sender, or if the sent data is less than the size of `dst_ref` on the receiving device), the program may hang indefinitely waiting for bytes that are never sent. In this case the program would need to be restarted.\n",
+ "- If encountering a race condition, there could be silent data corruption if two simultaneous writes or a simultaneous read and write occur.\n",
+ "\n",
+ "Some common causes of the above include:\n",
+ "- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n",
+ "- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n",
+ "- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n",
+ "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "vpGSN1Sui0Bu"
+ },
+ "source": [
+ "### Example: Right Permute (`lax.ppermute`)\n",
+ "\n",
+ "Let's dive into a very basic example. We will implement a kernel that performs a right permutation, where each device sends its slice of the data to its right neighbor.\n",
+ "\n",
+ "Suppose we had an array with 512 elements, which we shard into slices of size 128 across 4 devices. Each device will pass its slice to the next device, and the output will consist of the same data, but with the slices rotated by 1. This is identical to the `lax.ppermute` operation where the permutation is set to `(n, (n+1) % 4)`.\n",
+ "\n",
+ "In order to call the kernel in distributed mode, we wrap the `pallas_call` in a `shard_map` transformation. From there, we can write the kernel the same way as you would write a normal single-device Pallas kernel, except we now have access to remote DMA instructions. JAX collective primitives such as `lax.axis_index` can be used to obtain a `device_id` that can be used to compute which target devices to copy to, by referencing the same named axes names passed into `shard_map`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 1606,
+ "status": "ok",
+ "timestamp": 1722904803566,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "YkyIKN2thZ-V",
+ "outputId": "9b7ed142-d161-4237-fed8-cbce41adc5f0"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n",
+ "Pallas Result = [0.775211 0.9858954 0.11763906 0.9955574 ]\n",
+ "lax.ppermute Result = [0.775211 0.9858954 0.11763906 0.9955574 ]\n",
+ "Difference |Pallas - lax.ppermute| = 0.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "partition = P(None, 'x')\n",
+ "devices = mesh_utils.create_device_mesh((1, num_devices))\n",
+ "mesh = jax.sharding.Mesh(devices, partition)\n",
+ "sharding = jax.sharding.NamedSharding(mesh, partition)\n",
+ "\n",
+ "# Create an input array that shards the last dimension across\n",
+ "# all devices.\n",
+ "input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))\n",
+ "input_arr = jax.device_put(input_arr, sharding)\n",
+ "\n",
+ "\n",
+ "def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):\n",
+ " my_id = lax.axis_index('x')\n",
+ " right_neighbor = lax.rem(my_id + 1, num_devices)\n",
+ " remote_copy_op = pltpu.make_async_remote_copy(\n",
+ " src_ref=input_ref,\n",
+ " dst_ref=output_ref,\n",
+ " send_sem=send_sem,\n",
+ " recv_sem=recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " remote_copy_op.start()\n",
+ " remote_copy_op.wait()\n",
+ "\n",
+ "\n",
+ "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=0,\n",
+ " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n",
+ " in_specs=[\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " ],\n",
+ " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " scratch_shapes=(\n",
+ " # We allocate DMA semaphores in scratch memory.\n",
+ " [pltpu.SemaphoreType.DMA] * 2\n",
+ " ),\n",
+ ")\n",
+ "right_permute = pl.pallas_call(\n",
+ " right_permute_kernel,\n",
+ " out_shape=out_shape,\n",
+ " grid_spec=grid_spec,\n",
+ ")\n",
+ "# Wrap the kernel within a shard_map to call.\n",
+ "pallas_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " right_permute,\n",
+ " mesh=mesh,\n",
+ " in_specs=partition,\n",
+ " out_specs=partition,\n",
+ " check_rep=False,\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "# Compare Pallas result to XLA shard_map result.\n",
+ "perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))\n",
+ "\n",
+ "xla_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " lambda x: lax.ppermute(x, 'x', perm),\n",
+ " mesh=mesh, in_specs=partition, out_specs=partition)\n",
+ ")(input_arr)\n",
+ "\n",
+ "print('Input = ', input_arr[0, ::128])\n",
+ "print('Pallas Result = ', pallas_result[0, ::128])\n",
+ "print('lax.ppermute Result = ', xla_result[0, ::128])\n",
+ "print(\n",
+ " 'Difference |Pallas - lax.ppermute| = ',\n",
+ " jnp.mean(jnp.abs(pallas_result - xla_result)),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "iyfhdGXuUnq2"
+ },
+ "source": [
+ "### Example: All-gather (`lax.all_gather`)\n",
+ "\n",
+ "In this next example we will implement the all-gather collective operation, which has a JAX equivalent in `lax.all_gather`. In contrast with the right-permute example from above which only involves a pair of source and destination neighbors, an all-gather operation requires communication between all devices and therefore we must think about how data is routed between them. The specifics of how we implement this are dictated by the device topology, for which we assume is a ring.\n",
+ "\n",
+ "#### Ring Communication Pattern\n",
+ "\n",
+ "We will write our kernel assuming a ring topology. Rings are a natural fit for TPUs as slicing along any dimension of a torus produces a ring. When writing collectives, we often only need to think about 1D slices of our torus at a time because the different dimensions of the torus are reserved for different types of parallelism (data vs. model, for example).\n",
+ "\n",
+ "The strategy we will use is to write a looped kernel, where on each iteration a device receives one slice of the sharded array from its left neighbor, and copies the previously received slice to its right neighbor. After `num_devices` iterations, each device will have a copy of the entire array in its local HBM.\n",
+ "\n",
+ "![all_gather](../../_static/pallas/distributed/all_gather.svg)\n",
+ "\n",
+ "We can re-purpose Pallas's `grid` argument to implement the loop. Rather than iterating over tiles of an array as we have done in previous tutorials, we instead set the grid to `(num_devices,)` to indicate that we want to loop over the number of devices and use `pl.program_id` to obtain the loop iteration inside of the Pallas kernel. The following code snippet demonstrates how to implement this:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 812,
+ "status": "ok",
+ "timestamp": 1722904804531,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "ojQEZB5mBRqM",
+ "outputId": "e1648f54-737c-4921-ca3b-b4c639a38d2b"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input: (32, 128) [0.9858954 0.54248166 0.9547038 0.954962 ]\n",
+ "Pallas Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166\n",
+ " 0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962\n",
+ " 0.9858954 0.54248166 0.9547038 0.954962 ]\n",
+ "lax.all_gather Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166\n",
+ " 0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962\n",
+ " 0.9858954 0.54248166 0.9547038 0.954962 ]\n",
+ "Difference |Pallas - lax.all_gather| = 0.0\n"
+ ]
+ }
+ ],
+ "source": [
+ "partition = P('x', None)\n",
+ "devices = mesh_utils.create_device_mesh((num_devices, 1))\n",
+ "mesh = jax.sharding.Mesh(devices, partition)\n",
+ "sharding = jax.sharding.NamedSharding(mesh, partition)\n",
+ "\n",
+ "# Create an input array that shards the first dimension across\n",
+ "# all devices.\n",
+ "input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))\n",
+ "input_arr = jax.device_put(input_arr, sharding)\n",
+ "\n",
+ "\n",
+ "def all_gather_kernel(input_ref,\n",
+ " output_ref,\n",
+ " local_copy_sem,\n",
+ " send_sem,\n",
+ " recv_sems):\n",
+ " outer_step = pl.program_id(0)\n",
+ " my_id = lax.axis_index('x')\n",
+ " right_neighbor = lax.rem(my_id + 1, num_devices)\n",
+ " copy_slot = my_id - outer_step\n",
+ " copy_slot = lax.rem(copy_slot + num_devices, num_devices)\n",
+ "\n",
+ " @pl.when(outer_step == 0)\n",
+ " def _():\n",
+ " local_copy_op = pltpu.make_async_copy(\n",
+ " src_ref=input_ref,\n",
+ " dst_ref=output_ref.at[my_id],\n",
+ " sem=local_copy_sem,\n",
+ " )\n",
+ " local_copy_op.start()\n",
+ " local_copy_op.wait()\n",
+ "\n",
+ " # Copy to our right neighbor.\n",
+ " # Note that we will also be receiving data from our left neighbor,\n",
+ " # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact\n",
+ " # that the indices do not need to be symmetric between remote DMAs.\n",
+ " remote_copy_op = pltpu.make_async_remote_copy(\n",
+ " src_ref=output_ref.at[copy_slot],\n",
+ " dst_ref=output_ref.at[copy_slot],\n",
+ " send_sem=send_sem,\n",
+ " recv_sem=recv_sems.at[outer_step],\n",
+ " device_id=(right_neighbor, 0),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " remote_copy_op.start()\n",
+ " remote_copy_op.wait()\n",
+ "\n",
+ "out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=0,\n",
+ " in_specs=[\n",
+ " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " ],\n",
+ " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " scratch_shapes=(\n",
+ " # DMA semaphores are allocated in scratch memory.\n",
+ " # We allocated one semaphore for a local HBM-VMEM copy,\n",
+ " # and one for the remote send semaphore.\n",
+ " [pltpu.SemaphoreType.DMA] * 2\n",
+ " # We additionally allocate one receive semaphore per device.\n",
+ " # This is to avoid situations where we have multiple\n",
+ " # DMAs in flight, as we do not want to share a receive\n",
+ " # semaphore between the DMAs.\n",
+ " + [pltpu.SemaphoreType.DMA((num_devices-1,))]\n",
+ "\n",
+ " ),\n",
+ " grid=(num_devices-1,)\n",
+ " )\n",
+ "\n",
+ "all_gather = pl.pallas_call(\n",
+ " all_gather_kernel,\n",
+ " out_shape=out_shape,\n",
+ " grid_spec=grid_spec,\n",
+ " )\n",
+ "\n",
+ "# Wrap the kernel within a shard_map to call.\n",
+ "pallas_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " all_gather,\n",
+ " mesh=mesh,\n",
+ " in_specs=partition,\n",
+ " out_specs=partition,\n",
+ " check_rep=False\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "# Compare Pallas result to XLA shard_map result.\n",
+ "xla_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " lambda x: lax.all_gather(x, 'x'),\n",
+ " mesh=mesh, in_specs=partition, out_specs=partition\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "print('Input: ', input_arr.shape, input_arr[::8, 0])\n",
+ "print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])\n",
+ "print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])\n",
+ "print('Difference |Pallas - lax.all_gather| = ',\n",
+ " jnp.mean(jnp.abs(pallas_result - xla_result)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KgU7HI2pS4om"
+ },
+ "source": [
+ "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "KgU7HI2pS4om"
+ },
+ "source": [
+ "## Advanced Techniques\n",
+ "\n",
+ "Now that we have seen how to write several basic kernels using remote DMA operations, we will go over more advanced techniques for synchronization and writing efficient kernels."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "8M_kdl0FCtrL"
+ },
+ "source": [
+ "### Synchronization: Regular and Barrier Semaphores\n",
+ "\n",
+ "The examples we implemented in the basic tutorial do not require special handling of synchronization as all necessary communication writes to disjoint buffers. However, other operations may require more complex communication patterns that need additional synchronization primitives to avoid race conditions. Pallas provides two additional primitives to help with this: regular and barrier semaphores.\n",
+ "\n",
+ "#### Regular Semaphores\n",
+ "\n",
+ "Regular semaphores are the standard tool used to synchronize across multiple devices. Semaphores are fundamentally counters - they can be incremented by any device after which a device can block until the value of the semaphore reaches a specific value (and then decrement the value).\n",
+ "\n",
+ "The three main operations that can be used on regular semaphores are signal, wait, and read:\n",
+ "```python\n",
+ "def semaphore_signal(\n",
+ " sem: Ref[SemaphoreType],\n",
+ " inc: int,\n",
+ " device_id: int | tuple[int, ...],\n",
+ " device_id_type: DeviceIdType\n",
+ ") -> None:\n",
+ " ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.\n",
+ " \n",
+ "def semaphore_wait(\n",
+ " semaphore: Ref[SemaphoreType],\n",
+ " value: int,\n",
+ ") -> None:\n",
+ " ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.\n",
+ " \n",
+ "def semaphore_read(\n",
+ " sem: Ref[SemaphoreType],\n",
+ ") -> jax.Array:\n",
+ " ... # Returns the current value of `sem` as an `int32[]`.\n",
+ "```\n",
+ "\n",
+ "In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.\n",
+ "\n",
+ "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n",
+ " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n",
+ " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted.\n",
+ "\n",
+ "#### Barrier Semaphores\n",
+ "\n",
+ "Barrier semaphores are globally-allocated semaphores used to synchronize devices across an entire program and ensure that all devices have entered the Pallas kernel.\n",
+ "\n",
+ "If a Pallas kernel is executed within the context of a larger XLA program, we need to ensure that all devices that communicate have entered the kernel. However, DMA and regular semaphores are both locally scoped - they are only understood by other devices that have entered the kernel. Barrier semaphores serve as a globally understood semaphore that can be used for synchronization no matter where in the XLA program the device is currently executing.\n",
+ "\n",
+ "By default, if you do not specify a barrier semaphore, Pallas will automatically insert a barrier semaphore at the beginning of your program. However, it can be more efficient to write your own. Barrier semaphores are similar to regular semaphores in that they are counters that can be incremented via `semaphore_signal` and can be decremented via `semaphore_wait`. They are created by calling `get_barrier_semaphore()` within a kernel. Typically, we use barriers once at the beginning of a kernel to synchronize with all devices we are communicating with.\n",
+ "\n",
+ "```python\n",
+ "from jax.experimental.pallas import tpu as pltpu\n",
+ "\n",
+ "def example_kernel(...):\n",
+ " # Use barrier semaphores at the beginning of a kernel.\n",
+ " # is_start_of_kernel = ...\n",
+ " # right_neighbor = ...\n",
+ " # ...\n",
+ " @pl.when(is_start_of_kernel)\n",
+ " def _():\n",
+ " barrier_sem = pltpu.get_barrier_semaphore()\n",
+ " # Increment the semaphore of your right neighbor.\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " device_id=right_neighbor,\n",
+ " device_id_type=pltpu.DeviceIdType.LOGICAL,\n",
+ " )\n",
+ " # Wait until your left neighbor has incremented your semaphore\n",
+ " pltpu.semaphore_wait(barrier_sem, 1)\n",
+ " # ...\n",
+ "```\n",
+ "\n",
+ "When using barrier semaphores, the `collective_id` compiler parameter must be passed to `pallas_call` to specify which barrier semaphore is being used. A TPU has a small, fixed number of barrier semaphores available (typically on the order of 20-30) and therefore they should be used sparingly. In order to ensure correctness, only kernels that share the same communication pattern should use the same `collective_id`. For example, if two kernels synchronize only with neighbors on the same mesh axis, they are allowed to share the same `collective_id`. However, if two kernels synchronize along different axes, they must have different `collective_id`s. Failure to do so may result in race conditions that are difficult to debug.\n",
+ "\n",
+ "```python\n",
+ "kernel = pl.pallas_call(\n",
+ " example_kernel,\n",
+ " ...,\n",
+ " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zy20AxN5TSLA"
+ },
+ "source": [
+ "### Double-buffering\n",
+ "\n",
+ "In order to avoid reading from a local `Ref` that is also being written into by another device and creating a race condition, a useful technique is the \"double-buffered\" strategy where we allocate a two `Ref`s for each destination value. On each iteration, one `Ref` will be designated as a \"working\" slot, and the other will be designated as a \"receiving\" slot. The device is free to use the working slot for computation, but will only copy data into its neighbor's receiving slot. The working and receiving slots alternate every iteration, so that once a copy is finished, the old receiving slot becomes the new working slot, and vice versa. Using this scheme properly, data is never read from and written to the same buffer.\n",
+ "\n",
+ "The following code skeleton demonstrates how double-buffering can be used. We keep a running iteration counter in the variable `iteration`, and the `working_slot` and `receiving_slot` alternate between 0 and 1 every iteration. `dst_ref` is allocated as a double-buffer and has the size `[2, ...]`. On each iteration, we read from the working slot using `dst_ref.at[working_slot, ...]` and use the value to perform computation. Simultaneously, we copy to our neighbor's `dst_ref.at[receiving_slot]` to avoid overwriting their `working_slot` value. By structuring our communication in this fashion it is possible to overlap the communication latency of the remote DMA with local computation while minimizing the risk of race conditions.\n",
+ "```python\n",
+ "def kernel(...):\n",
+ " # ...\n",
+ " iteration = pl.program_id(0)\n",
+ " working_slot = lax.rem(iteration, 2)\n",
+ " receiving_slot = 1 - working_slot\n",
+ " # ...\n",
+ "\n",
+ " local_copy_op = pltpu.make_async_copy(\n",
+ " src_ref=dst_ref.at[working_slot, ...],\n",
+ " dst_ref=local_scratch_ref,\n",
+ " sem=local_copy_sem,\n",
+ " )\n",
+ " local_copy_op.start()\n",
+ " remote_copy_op = pltpu.make_async_remote_copy(\n",
+ " src_ref=src_ref,\n",
+ " dst_ref=dst_ref.at[receiving_slot, ...],\n",
+ " send_sem=send_sem,\n",
+ " recv_sem=recv_sem,\n",
+ " device_id=target_device,\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " remote_copy_op.start()\n",
+ " \n",
+ " local_copy_op.wait()\n",
+ " # ... do work on local_scratch while waiting for async_copy_op to finish.\n",
+ " remote_copy_op.wait()\n",
+ "\n",
+ "```\n",
+ "\n",
+ "In terms of synchronization, the double-buffered construction works if all devices are executing on the same iteration. If a sender manages to get one iteration ahead of its receiver, it's `working_slot` and `receiving_slot` indices will be flipped compared to the receiver, meaning that it could be writing into the `working_slot` at the same time the receiver is reading from it. In order to avoid this, it may be necessary to use a semaphore to synchronize the sender with the receiver, or add additional buffering slots (\"triple\", \"quadruple\", or N-buffered) to allow additional run-ahead at the cost of more memory. In our previous `all_gather` example, note that the kernel contained a receiving buffer with N slots, which avoids race conditions altogether. In our next kernel, we will instead go through an example which uses a double-buffer with explicit synchronization."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Or0Itv72No5d"
+ },
+ "source": [
+ "### Example: All-Reduce Sum (`lax.psum`)\n",
+ "\n",
+ "We will now implement an all-reduce sum kernel using double-buffering and semaphores for synchronization. For those familiar with collective operations in JAX, the equivalent operation is `lax.psum`. All-reduce is a standard collective operation where the objective is to reduce along an axis of an array, but the array is sharded across multiple devices.\n",
+ "\n",
+ "![reduce_sum_1](../../_static/pallas/distributed/reduce_sum_1.svg)\n",
+ "\n",
+ "In the above example, we have the array [5, 2, 1, 3] sharded across 4 devices. An all-reduce sum operation would sum all values and replicate the result on each device, leading to the result [11, 11, 11, 11] sharded across all 4 devices.\n",
+ "\n",
+ "The naive implementation of all-reduce would be to gather all required values onto each device, and then reduce. However, we can improve the performance of this implementation by interleaving communication with computation. An interleaved, single-direction all-reduce can be visualized as follows. On each iteration, we receive an input value from our left neighbor, and concurrently pass input along to our next neighbor while incrementing it with our local accumulator. After N-1 iterations, each device will have a copy of the full sum in it's memory.\n",
+ "\n",
+ "![reduce_sum_2](../../_static/pallas/distributed/reduce_sum_2.svg)\n",
+ "\n",
+ "#### Putting it all together\n",
+ "\n",
+ "The following kernel demonstrates how to combine these principles into a functional kernel.\n",
+ "\n",
+ "The prologue (executed when `outer_step==0`) first initiates a barrier with both neighbors to ensure that they have also entered the kernel. It also handles initialization for all `Ref`s and handles the first remote copy to the right neighbor's \"working\" slot.\n",
+ "\n",
+ "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n",
+ "\n",
+ "A subtle race condition can occur if one device runs one loop ahead of it's right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref` until it has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but can it can be explicitly triggered if for example using a `pltpu.delay` instruction to artifically hang a device.\n",
+ "\n",
+ "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 254,
+ "status": "ok",
+ "timestamp": 1722904804952,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "XrY5bMlvBroQ",
+ "outputId": "77497000-4496-462e-cc3c-73fb640cc14c"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n",
+ "Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]\n",
+ "lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]\n",
+ "Difference |Pallas - lax.psum| = 1.4959369e-08\n"
+ ]
+ }
+ ],
+ "source": [
+ "partition = P(None, 'x')\n",
+ "devices = mesh_utils.create_device_mesh((1, num_devices))\n",
+ "mesh = jax.sharding.Mesh(devices, partition)\n",
+ "sharding = jax.sharding.NamedSharding(mesh, partition)\n",
+ "\n",
+ "input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n",
+ "input_arr = jax.device_put(input_arr, sharding)\n",
+ "\n",
+ "\n",
+ "def all_reduce_kernel(\n",
+ " x_ref,\n",
+ " o_ref,\n",
+ " hbm_scratch,\n",
+ " copy_sem,\n",
+ " remote_recv_sem,\n",
+ " remote_send_sem,\n",
+ " capacity_sem,\n",
+ " receive_scratch,\n",
+ "):\n",
+ " outer_step = pl.program_id(0)\n",
+ " working_slot = lax.rem(outer_step, 2)\n",
+ " receiving_slot = 1 - working_slot\n",
+ "\n",
+ " my_id = lax.axis_index('x')\n",
+ " right_neighbor = lax.rem(my_id + 1, num_devices)\n",
+ " left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)\n",
+ "\n",
+ " @pl.when(outer_step == 0)\n",
+ " def _():\n",
+ " # Barrier with both neighbors at the start, since we will be\n",
+ " # communicating with both.\n",
+ " barrier_sem = pltpu.get_barrier_semaphore()\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " pltpu.semaphore_wait(barrier_sem, 2)\n",
+ "\n",
+ " # Initialize o_ref, acc_scratch, and hbm_scratch.\n",
+ " o_ref[...] = jnp.zeros_like(o_ref)\n",
+ " receive_scratch[...] = jnp.zeros_like(receive_scratch)\n",
+ " initial_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=x_ref,\n",
+ " dst_ref=hbm_scratch.at[working_slot],\n",
+ " send_sem=remote_send_sem,\n",
+ " recv_sem=remote_recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " initial_copy.start()\n",
+ " initial_copy.wait()\n",
+ "\n",
+ " # Signal to our left neighbor that we are ready to receive.\n",
+ " # Without this signal, our left neighbor can be >=1 iteration ahead,\n",
+ " # meaning it could write into our working slot.\n",
+ " pltpu.semaphore_signal(\n",
+ " capacity_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " # Copy the partial result our left neighbor sent to us into VMEM for\n",
+ " # computation.\n",
+ " local_copy = pltpu.make_async_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot],\n",
+ " dst_ref=receive_scratch,\n",
+ " sem=copy_sem,\n",
+ " )\n",
+ " local_copy.start()\n",
+ "\n",
+ " # Block until our right neighbor is ready to receive.\n",
+ " pltpu.semaphore_wait(capacity_sem, 1)\n",
+ " # Pass the value to our right neighbor.\n",
+ " remote_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot],\n",
+ " dst_ref=hbm_scratch.at[receiving_slot],\n",
+ " send_sem=remote_send_sem,\n",
+ " recv_sem=remote_recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " remote_copy.start()\n",
+ " # Finish local copy and accumulate while remote_copy is happening.\n",
+ " local_copy.wait()\n",
+ " o_ref[...] += receive_scratch[...]\n",
+ " # Block until remote copy finishes.\n",
+ " remote_copy.wait()\n",
+ "\n",
+ "\n",
+ "out_shape = (\n",
+ " jax.ShapeDtypeStruct((8, 128), jnp.float32),\n",
+ " # We allocate the double-buffer as a Pallas output so that it is\n",
+ " # resident in HBM.\n",
+ " jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch\n",
+ ")\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=0,\n",
+ " in_specs=[\n",
+ " # Our input lives in VMEM\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+ " ],\n",
+ " out_specs=[\n",
+ " # Our output lives in VMEM\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+ " # Our double-buffer lives in HBM\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " ],\n",
+ " grid=(num_devices,),\n",
+ " scratch_shapes=(\n",
+ " [pltpu.SemaphoreType.DMA] * 3\n",
+ " + [pltpu.SemaphoreType.REGULAR] # capacity_sem\n",
+ " + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "kernel = pl.pallas_call(\n",
+ " all_reduce_kernel,\n",
+ " out_shape=out_shape,\n",
+ " grid_spec=grid_spec,\n",
+ " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+ ")\n",
+ "\n",
+ "pallas_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " kernel,\n",
+ " mesh=mesh,\n",
+ " in_specs=partition,\n",
+ " out_specs=partition,\n",
+ " check_rep=False,\n",
+ " )\n",
+ ")(input_arr)\n",
+ "pallas_result = jax.block_until_ready(pallas_result)[0]\n",
+ "\n",
+ "\n",
+ "def lax_sum(x):\n",
+ " return lax.psum(x, 'x')\n",
+ "\n",
+ "\n",
+ "xla_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "print('Input = ', input_arr[0, ::128])\n",
+ "print('Pallas result = ', pallas_result[0, ::128])\n",
+ "print('lax.psum result = ', xla_result[0, ::128])\n",
+ "difference = jnp.mean(jnp.abs(pallas_result - xla_result))\n",
+ "print('Difference |Pallas - lax.psum| = ', difference)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "d8bsZAzQreC_"
+ },
+ "source": [
+ "### Run-ahead and Race Conditions\n",
+ "\n",
+ "As a general rule of thumb, to maximize performance we want to allow a device to run-ahead of other devices without synchronization as much as possible without sacrificing correctness of the program. While we could enforce a barrier across all devices at the beginning of each iteration, this bottlenecks the performance of the program to the slowest device on each loop. By relaxing synchronization and allowing a moderate amount of run-ahead, we can better accommodate variance in latency between iterations and devices because a device that is slow on one iteration could catch up on the next iteration.\n",
+ "\n",
+ "In the all-reduce kernel we wrote previously, we allow devices to run ahead but by less than one iteration compared to its neighbors (however, non-neighboring devices could be more than 1 iteration apart). To see why the semaphore synchronization is necessary, consider the case when one device (say device 2) hangs and falls behind the other devices. An RDMA has no \"handshake\" — only the receiver is blocked while waiting for the data to arrive. Therefore, each device can run up to one iteration ahead before it becomes blocked waiting for the next RDMA to arrive. If we have N devices, this means that the final device can be up to N iterations ahead of the first device.\n",
+ "\n",
+ "![race_condition](../../_static/pallas/distributed/race_condition.svg)\n",
+ "\n",
+ "Without adding synchronization in the other direction (forcing senders to block), device 1 could potentially run up to `N` iterations (`N = num_devices`) ahead of device 2, sending multiple writes and overwriting values in the process. To solve this in the `all_reduce` kernel we wrote previously we implemented a \"handshake\" protocol where the receiver signals back to the sender that it is ready to receive, and only then does the sender begin issuing the next RDMA."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "UD8lNrqsUeXy"
+ },
+ "source": [
+ "### Bi-directional Communication\n",
+ "\n",
+ "In our previous kernels, we communicated in a single direction around a ring from left-to-right. However, as ICI connections are bi-directional, we are effectively wasting half of the total bandwidth by not sending values in the opposite direction from right-to-left. In this next kernel we will demonstrate an example which communicates in both directions to maximize ICI bandwidth."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "4KjakLhbBk73"
+ },
+ "source": [
+ "### Example: Bi-directional Reduce-Scatter (`lax.psum_scatter`)\n",
+ "\n",
+ "A reduce-scatter operation is the combination of an all-reduce followed by a scatter. Or alternatively, an all-reduce is the combination of a reduce-scatter followed by all-gather.\n",
+ "\n",
+ "The following graphic depicts the semantics of this operation. We assume that each device starts with a collection of partial sums (denoted by a letter + number, such as `A0`). The goal is to reduce along one axis (numbers), while sharding along the other axis (letters).\n",
+ "\n",
+ "![reduce_scatter_1](../../_static/pallas/distributed/reduce_scatter_1.svg)\n",
+ "\n",
+ "In order to implement a bi-directional communication strategy, we slice each input block in half, and designate a direction for each half. The top half of each block will be passed from right-to-left, and the bottom half will be passed from left-to-right. A second deviation from the communication patterns of our previous all-reduce and all-gather kernels is that we will also pass around accumulators or partial sums and keep the inputs local to each device. This is in contrast to the previous examples where we passed around inputs but kept the accumulator local to the device. Passing around the accumulator is a more natural fit for this problem as in contrast to all-reduce, most of the data in the inputs are not part of the output that will be stored locally on the device. (e.g. `B0`, `C0`, and `D0` in the above graphic will not be stored on the device holding `A` at the end).\n",
+ "\n",
+ "The following diagram illustrates this communication pattern, where the colored boxes represent accumulators (not inputs!). Initially, the accumulator is simply the value that was contained in the input. At each iteration of the algorithm, we will receive a partial sum from our neighbors in each direction. We then compute the correct slice of our input to accumulate into the partial buffer, then pass the new partial sum along to our next neighbor. After N iterations, the accumulator will have passed through each device, meaning that it will hold the full sum in the end.\n",
+ "\n",
+ "![reduce_scatter_2](../../_static/pallas/distributed/reduce_scatter_2.svg)\n",
+ "\n",
+ "In terms of construction of the kernel, we introduce an additional `phase` dimension to the Pallas grid, which denotes which accumulator (left or right) we are currently computing on. We let `phase=0` denote the accumulator moving to the left, and `phase=1` denote the accumulator moving to the right. We then pipeline the two phases, such that while computing the result for one phase we are transferring our previously computed values in the opposite direction in preparation for the next phase. For example, when we are on `phase=0` (left), we first begin a DMA to transfer results we computed in the previous iteration to our right neighbor (right-DMA). Then, we accumulate into the left-buffer and save the result to HBM. We then wait for the right-DMA to complete so that it is ready for `phase=1` (right)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 544,
+ "status": "ok",
+ "timestamp": 1722904805699,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "nRauUAxNHg28"
+ },
+ "outputs": [],
+ "source": [
+ "partition = P(None, 'x')\n",
+ "devices = mesh_utils.create_device_mesh((1, num_devices))\n",
+ "mesh = jax.sharding.Mesh(devices, partition)\n",
+ "sharding = jax.sharding.NamedSharding(mesh, partition)\n",
+ "\n",
+ "# We need a block size of (16, 128) to ensure that a half-slice is at least\n",
+ "# of size (8, 128), which is the size of a VREG. This makes tiling easier\n",
+ "# for the compiler.\n",
+ "block_size = (16, 128)\n",
+ "input_arr = jax.random.uniform(\n",
+ " jax.random.key(0),\n",
+ " shape=(block_size[0] * num_devices, block_size[1] * num_devices),\n",
+ ")\n",
+ "input_arr = jax.device_put(input_arr, sharding)\n",
+ "\n",
+ "LEFT = 0\n",
+ "RIGHT = 1\n",
+ "\n",
+ "\n",
+ "def mod(x, n):\n",
+ " return lax.rem(x + n, n)\n",
+ "\n",
+ "\n",
+ "def signal(left_or_right, semaphore):\n",
+ " my_id = lax.axis_index('x')\n",
+ " if left_or_right == LEFT:\n",
+ " neighbor = mod(my_id - 1, num_devices)\n",
+ " else:\n",
+ " neighbor = mod(my_id + 1, num_devices)\n",
+ " pltpu.semaphore_signal(\n",
+ " semaphore,\n",
+ " inc=1,\n",
+ " device_id=(0, neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ "\n",
+ "def reduce_scatter_kernel(\n",
+ " x_ref,\n",
+ " o_ref,\n",
+ " hbm_scratch,\n",
+ " local_copy_sem,\n",
+ " left_recv_sem,\n",
+ " left_send_sem,\n",
+ " right_recv_sem,\n",
+ " right_send_sem,\n",
+ " left_capacity_sem,\n",
+ " right_capacity_sem,\n",
+ " accum_scratch,\n",
+ "):\n",
+ " outer_step = pl.program_id(0)\n",
+ " phase = pl.program_id(1)\n",
+ " is_start = jnp.logical_and(outer_step == 0, phase == 0)\n",
+ " last_iteration = outer_step == pl.num_programs(0) - 1\n",
+ "\n",
+ " working_slot = lax.rem(outer_step, 2)\n",
+ " receiving_slot = 1 - working_slot\n",
+ " my_id = lax.axis_index('x')\n",
+ " right_neighbor = mod(my_id + 1, num_devices)\n",
+ " left_neighbor = mod(my_id - 1, num_devices)\n",
+ "\n",
+ " left_copy_device = mod(my_id + outer_step + 1, num_devices)\n",
+ " right_copy_device = mod(my_id - outer_step - 1, num_devices)\n",
+ " # Slices can be specified using pl.ds(start, size)\n",
+ " left_copy_slice = pl.ds(0, block_size[0] // 2)\n",
+ " right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)\n",
+ " current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)\n",
+ "\n",
+ " initial_left_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=x_ref.at[my_id, left_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
+ " send_sem=left_send_sem,\n",
+ " recv_sem=left_recv_sem,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " initial_right_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=x_ref.at[my_id, right_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
+ " send_sem=right_send_sem,\n",
+ " recv_sem=right_recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " left_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n",
+ " send_sem=left_send_sem,\n",
+ " recv_sem=left_recv_sem,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " right_copy = pltpu.make_async_remote_copy(\n",
+ " # Note: Right copy is flipped with regards to slots since we are copying\n",
+ " # to the next outer_step iteration.\n",
+ " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
+ " send_sem=right_send_sem,\n",
+ " recv_sem=right_recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " # --- Prologue ---\n",
+ " @pl.when(is_start)\n",
+ " def _():\n",
+ " # Barrier with both neighbors at the start, since we will be\n",
+ " # communicating with both.\n",
+ " barrier_sem = pltpu.get_barrier_semaphore()\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " pltpu.semaphore_wait(barrier_sem, 2)\n",
+ "\n",
+ " # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.\n",
+ " o_ref[...] = jnp.zeros_like(o_ref[...])\n",
+ " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n",
+ "\n",
+ " initial_left_copy.start()\n",
+ " initial_left_copy.wait()\n",
+ " initial_right_copy.start()\n",
+ "\n",
+ " # We tell our left neighbor that it is allowed to send to the right.\n",
+ " # (and vice versa for right neighbor)\n",
+ " signal(LEFT, right_capacity_sem)\n",
+ " signal(RIGHT, left_capacity_sem)\n",
+ "\n",
+ " # --- Body ---\n",
+ " # At the beginning of our kernel body, we start a DMA which copies\n",
+ " # the result we computed in the previous phase to our neighbor.\n",
+ " # This allows us to overlap the communication of sending our previous phase\n",
+ " # with the computation for the current phase.\n",
+ " @pl.when(~is_start)\n",
+ " def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " # We block here until our right neighbor tells use we can send to\n",
+ " # the right.\n",
+ " pltpu.semaphore_wait(right_capacity_sem, 1)\n",
+ " right_copy.start()\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " # We block here until our left neighbor tells use we can send to\n",
+ " # the left.\n",
+ " pltpu.semaphore_wait(left_capacity_sem, 1)\n",
+ " left_copy.start()\n",
+ "\n",
+ " local_copy = pltpu.make_async_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n",
+ " dst_ref=accum_scratch,\n",
+ " sem=local_copy_sem,\n",
+ " )\n",
+ " local_copy.start()\n",
+ " local_copy.wait()\n",
+ "\n",
+ " @pl.when(~last_iteration)\n",
+ " def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]\n",
+ "\n",
+ " local_copy = pltpu.make_async_copy(\n",
+ " src_ref=accum_scratch,\n",
+ " dst_ref=hbm_scratch.at[working_slot, current_phase_slice],\n",
+ " sem=local_copy_sem,\n",
+ " )\n",
+ " local_copy.start()\n",
+ " local_copy.wait()\n",
+ "\n",
+ " @pl.when(is_start)\n",
+ " def _():\n",
+ " initial_right_copy.wait()\n",
+ "\n",
+ " # At the end of our kernel body, we wait on the DMA of the previous phase\n",
+ " # to make sure the results are ready for the next phase.\n",
+ " @pl.when(~is_start)\n",
+ " def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " right_copy.wait()\n",
+ " signal(LEFT, right_capacity_sem)\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " left_copy.wait()\n",
+ " signal(RIGHT, left_capacity_sem)\n",
+ "\n",
+ " # --- Epilogue ---\n",
+ " # Store result on last iteration.\n",
+ " @pl.when(last_iteration)\n",
+ " def _():\n",
+ " # Clean up semaphores so that they exit with a value of 0.\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " o_ref[left_copy_slice, ...] = accum_scratch[...]\n",
+ " pltpu.semaphore_wait(right_capacity_sem, 1)\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " o_ref[right_copy_slice, ...] = accum_scratch[...]\n",
+ " pltpu.semaphore_wait(left_capacity_sem, 1)\n",
+ "\n",
+ "\n",
+ "out_shape = (\n",
+ " jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output\n",
+ " # Shape: [working/recv, block[0], block[1]]\n",
+ " jax.ShapeDtypeStruct(\n",
+ " (2, block_size[0], block_size[1]), jnp.float32\n",
+ " ), # hbm_scratch\n",
+ ")\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=0,\n",
+ " in_specs=[\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+ " ],\n",
+ " out_specs=[\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " ],\n",
+ " grid=(num_devices, 2),\n",
+ " scratch_shapes=(\n",
+ " [pltpu.SemaphoreType.DMA] * 5\n",
+ " + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores\n",
+ " + [\n",
+ " pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)\n",
+ " ] # accum_scratch\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "def pallas_reduce_scatter(input_arr):\n",
+ " input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])\n",
+ " return pl.pallas_call(\n",
+ " reduce_scatter_kernel,\n",
+ " out_shape=out_shape,\n",
+ " grid_spec=grid_spec,\n",
+ " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+ " )(input_arr)[0]\n",
+ "\n",
+ "\n",
+ "pallas_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " pallas_reduce_scatter,\n",
+ " mesh=mesh,\n",
+ " in_specs=P(None, 'x'),\n",
+ " out_specs=P('x', None),\n",
+ " check_rep=False,\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "pallas_result = jax.block_until_ready(pallas_result)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 596,
+ "status": "ok",
+ "timestamp": 1722904806442,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "E-NMh-_teoi4",
+ "outputId": "24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input: (64, 512) [0.78051674 0.3524047 0.59993696 0.9714314 0.24692321 0.01347649\n",
+ " 0.01857424 0.24841607 0.86097646 0.8261659 0.9753758 0.6902338\n",
+ " 0.4431417 0.963323 0.3158517 0.535548 ]\n",
+ "Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303\n",
+ " 1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733\n",
+ " 2.4662397 1.9542246]\n",
+ "lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303\n",
+ " 1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733\n",
+ " 2.4662397 1.9542246]\n",
+ "Difference |Pallas - lax.psum_scatter|: 2.3841858e-07\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Compare our result to XLA.\n",
+ "def lax_reduce_sum_scatter(x):\n",
+ " x = x.reshape(num_devices, block_size[0], block_size[1])\n",
+ " return lax.psum_scatter(x, 'x')\n",
+ "\n",
+ "\n",
+ "xla_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " lax_reduce_sum_scatter,\n",
+ " mesh=mesh,\n",
+ " in_specs=P(None, 'x'),\n",
+ " out_specs=P('x', None),\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "print('Input:', input_arr.shape, input_arr[::4, 0])\n",
+ "print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])\n",
+ "print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])\n",
+ "print(\n",
+ " 'Difference |Pallas - lax.psum_scatter|:',\n",
+ " jnp.max(jnp.abs(pallas_result - xla_result)),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ThKas40r40Ji"
+ },
+ "source": [
+ "### Nested Remote and Local DMA Pipelines\n",
+ "\n",
+ "A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB.\n",
+ "\n",
+ "To address this problem, we need to be able to write an \"inner kernel\" that handles local HBM-VMEM pipelining inside of the \"outer kernel\" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs:\n",
+ "\n",
+ "```python\n",
+ "def emit_pipeline(\n",
+ " kernel: Callable,\n",
+ " grid: tuple[int],\n",
+ " in_specs: PyTree[BlockSpec] = None,\n",
+ " out_specs: PyTree[BlockSpec] = None,\n",
+ " should_accumulate_out: bool = False,\n",
+ " dimension_semantics: tuple[GridDimensionSemantics] = None,\n",
+ ") -> Callable:\n",
+ " ... # Returns a custom pipeline given an inner kernel and BlockSpecs.\n",
+ "```\n",
+ "\n",
+ "Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern:\n",
+ "\n",
+ "```python\n",
+ "\n",
+ "def outer_kernel(...):\n",
+ " # ... do work to pipeline remote HBM-HBM transfers (outer kernel)\n",
+ "\n",
+ " def inner_kernel(...):\n",
+ " # ... do work (inner kernel)\n",
+ " pltpu.emit_pipeline(\n",
+ " inner_kernel,\n",
+ " grid=inner_grid,\n",
+ " in_specs=...,\n",
+ " out_specs=...,\n",
+ " )(inner_kernel_args)\n",
+ " # ... do more work (outer kernel)\n",
+ "\n",
+ "pl.pallas_call(\n",
+ " outer_kernel,\n",
+ " grid=outer_grid,\n",
+ " in_specs=...\n",
+ " out_specs=...\n",
+ " scratch=inner_kernel_allocs\n",
+ ")\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "DzFeQjYaasX5"
+ },
+ "source": [
+ "### Example: Reduce-Scatter with large HBM blocks\n",
+ "\n",
+ "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter.\n",
+ "\n",
+ "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n",
+ "\n",
+ "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n",
+ "\n",
+ "```python\n",
+ "local_copy = pltpu.make_async_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n",
+ " dst_ref=accum_scratch,\n",
+ " sem=local_copy_sem,\n",
+ ")\n",
+ "local_copy.start()\n",
+ "local_copy.wait()\n",
+ "@pl.when(~last_iteration)\n",
+ "def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]\n",
+ "local_copy = pltpu.make_async_copy(\n",
+ " src_ref=accum_scratch,\n",
+ " dst_ref=hbm_scratch.at[working_slot, current_phase_slice],\n",
+ " sem=local_copy_sem,\n",
+ ")\n",
+ "local_copy.start()\n",
+ "local_copy.wait()\n",
+ "```\n",
+ "\n",
+ "Our new kernel replaces it with the following `emit_pipeline` call:\n",
+ "\n",
+ "```python\n",
+ "def inner_kernel(input_ref, accum_ref):\n",
+ " accum_ref[...] = input_ref[...]\n",
+ "accum_pipeline = pltpu.emit_pipeline(inner_kernel,\n",
+ " in_specs=[inner_block_spec],\n",
+ " out_specs=inner_block_spec,\n",
+ " should_accumulate_out=True,\n",
+ " grid=inner_grid)\n",
+ "@pl.when(~last_iteration)\n",
+ "def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],\n",
+ " hbm_scratch.at[working_slot, left_copy_slice],\n",
+ " )\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],\n",
+ " hbm_scratch.at[working_slot, right_copy_slice],\n",
+ " )\n",
+ "```\n",
+ "\n",
+ "The full kernel is as follows:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 1341,
+ "status": "ok",
+ "timestamp": 1722904807930,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "27jni-pSartL"
+ },
+ "outputs": [],
+ "source": [
+ "partition = P(None, 'x')\n",
+ "devices = mesh_utils.create_device_mesh((1, num_devices))\n",
+ "mesh = jax.sharding.Mesh(devices, partition)\n",
+ "sharding = jax.sharding.NamedSharding(mesh, partition)\n",
+ "\n",
+ "# We pick a large outer kernel block size that we do not want to place\n",
+ "# in VMEM. For pedagogical purposes we use (4096, 4096), although in\n",
+ "# principle this can be much larger.\n",
+ "outer_block_size = (4096, 4096)\n",
+ "# We pick a smaller VMEM block size for the inner kernel.\n",
+ "inner_block_size = (128, 128)\n",
+ "input_arr = jax.random.uniform(\n",
+ " jax.random.key(0),\n",
+ " shape=(\n",
+ " outer_block_size[0] * num_devices,\n",
+ " outer_block_size[1] * num_devices,\n",
+ " ),\n",
+ ")\n",
+ "input_arr = jax.device_put(input_arr, sharding)\n",
+ "\n",
+ "\n",
+ "inner_grid = (\n",
+ " outer_block_size[0] // inner_block_size[0] // 2,\n",
+ " outer_block_size[1] // inner_block_size[1],\n",
+ ")\n",
+ "inner_block_spec = pl.BlockSpec(\n",
+ " index_map=lambda i, j: (i, j),\n",
+ " block_shape=inner_block_size,\n",
+ " memory_space=pltpu.TPUMemorySpace.ANY,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "def reduce_scatter_kernel(\n",
+ " x_ref,\n",
+ " o_ref,\n",
+ " hbm_scratch,\n",
+ " left_recv_sem,\n",
+ " left_send_sem,\n",
+ " copy_sem,\n",
+ " right_recv_sem,\n",
+ " right_send_sem,\n",
+ " left_capacity_sem,\n",
+ " right_capacity_sem,\n",
+ "):\n",
+ " outer_step = pl.program_id(0)\n",
+ " phase = pl.program_id(1)\n",
+ " is_start = jnp.logical_and(outer_step == 0, phase == 0)\n",
+ " last_iteration = outer_step == pl.num_programs(0) - 1\n",
+ "\n",
+ " working_slot = lax.rem(outer_step, 2)\n",
+ " receiving_slot = 1 - working_slot\n",
+ " my_id = lax.axis_index('x')\n",
+ " right_neighbor = mod(my_id + 1, num_devices)\n",
+ " left_neighbor = mod(my_id - 1, num_devices)\n",
+ "\n",
+ " left_copy_device = mod(my_id + outer_step + 1, num_devices)\n",
+ " right_copy_device = mod(my_id - outer_step - 1, num_devices)\n",
+ " left_copy_slice = pl.ds(0, outer_block_size[0] // 2)\n",
+ " right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)\n",
+ " current_phase_slice = pl.ds(\n",
+ " phase * (outer_block_size[0] // 2), outer_block_size[0] // 2\n",
+ " )\n",
+ "\n",
+ " initial_left_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=x_ref.at[my_id, left_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
+ " send_sem=left_send_sem,\n",
+ " recv_sem=left_recv_sem,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " initial_right_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=x_ref.at[my_id, right_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
+ " send_sem=right_send_sem,\n",
+ " recv_sem=right_recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " left_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot, left_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n",
+ " send_sem=left_send_sem,\n",
+ " recv_sem=left_recv_sem,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " right_copy = pltpu.make_async_remote_copy(\n",
+ " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n",
+ " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n",
+ " send_sem=right_send_sem,\n",
+ " recv_sem=right_recv_sem,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ "\n",
+ " # --- Prologue ---\n",
+ " @pl.when(is_start)\n",
+ " def _():\n",
+ " # Barrier with both neighbors at the start, since we will be\n",
+ " # communicating with both.\n",
+ " barrier_sem = pltpu.get_barrier_semaphore()\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, left_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " pltpu.semaphore_signal(\n",
+ " barrier_sem,\n",
+ " inc=1,\n",
+ " device_id=(0, right_neighbor),\n",
+ " device_id_type=pltpu.DeviceIdType.MESH,\n",
+ " )\n",
+ " pltpu.semaphore_wait(barrier_sem, 2)\n",
+ "\n",
+ " initial_left_copy.start()\n",
+ " initial_left_copy.wait()\n",
+ " initial_right_copy.start()\n",
+ "\n",
+ " # We tell our left neighbor that it is allowed to send to the right.\n",
+ " # (and vice versa for right neighbor)\n",
+ " signal(LEFT, right_capacity_sem)\n",
+ " signal(RIGHT, left_capacity_sem)\n",
+ "\n",
+ " @pl.when(~is_start)\n",
+ " def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " # We block here until our right neighbor tells use we can send to\n",
+ " # the right.\n",
+ " pltpu.semaphore_wait(right_capacity_sem, 1)\n",
+ " right_copy.start()\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " # We block here until our left neighbor tells use we can send to\n",
+ " # the left.\n",
+ " pltpu.semaphore_wait(left_capacity_sem, 1)\n",
+ " left_copy.start()\n",
+ "\n",
+ " # --- Body ---\n",
+ " def inner_kernel(input_ref, accum_ref):\n",
+ " # We do not explicitly use += because we set should_accumulate_out=True.\n",
+ " accum_ref[...] = input_ref[...]\n",
+ "\n",
+ " accum_pipeline = pltpu.emit_pipeline(\n",
+ " inner_kernel,\n",
+ " in_specs=[inner_block_spec],\n",
+ " out_specs=inner_block_spec,\n",
+ " should_accumulate_out=True,\n",
+ " grid=inner_grid,\n",
+ " )\n",
+ "\n",
+ " @pl.when(~last_iteration)\n",
+ " def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " accum_pipeline(\n",
+ " x_ref.at[left_copy_device, left_copy_slice],\n",
+ " hbm_scratch.at[working_slot, left_copy_slice],\n",
+ " )\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " accum_pipeline(\n",
+ " x_ref.at[right_copy_device, right_copy_slice],\n",
+ " hbm_scratch.at[working_slot, right_copy_slice],\n",
+ " )\n",
+ "\n",
+ " # --- Epilogue ---\n",
+ " @pl.when(is_start)\n",
+ " def _():\n",
+ " initial_right_copy.wait()\n",
+ "\n",
+ " @pl.when(~is_start)\n",
+ " def _():\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " right_copy.wait()\n",
+ " signal(LEFT, right_capacity_sem)\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " left_copy.wait()\n",
+ " signal(RIGHT, left_capacity_sem)\n",
+ "\n",
+ " # Store result on last iteration.\n",
+ " @pl.when(last_iteration)\n",
+ " def _():\n",
+ " output_copy = pltpu.make_async_copy(\n",
+ " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n",
+ " dst_ref=o_ref.at[current_phase_slice],\n",
+ " sem=copy_sem,\n",
+ " )\n",
+ " output_copy.start()\n",
+ " output_copy.wait()\n",
+ "\n",
+ " # Clean up semaphores so that they exit with a value of 0.\n",
+ " @pl.when(phase == LEFT)\n",
+ " def _():\n",
+ " pltpu.semaphore_wait(right_capacity_sem, 1)\n",
+ "\n",
+ " @pl.when(phase == RIGHT)\n",
+ " def _():\n",
+ " pltpu.semaphore_wait(left_capacity_sem, 1)\n",
+ "\n",
+ "\n",
+ "out_shape = (\n",
+ " jax.ShapeDtypeStruct(\n",
+ " (outer_block_size[0], outer_block_size[1]), jnp.float32\n",
+ " ),\n",
+ " # Shape: [working/recv, block[0], block[1]]\n",
+ " jax.ShapeDtypeStruct(\n",
+ " (2, outer_block_size[0], outer_block_size[1]), jnp.float32\n",
+ " ), # hbm_scratch\n",
+ ")\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=0,\n",
+ " in_specs=[\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " ],\n",
+ " out_specs=[\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n",
+ " ],\n",
+ " grid=(num_devices, 2),\n",
+ " scratch_shapes=(\n",
+ " [pltpu.SemaphoreType.DMA] * 5\n",
+ " + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores\n",
+ " ),\n",
+ ")\n",
+ "\n",
+ "\n",
+ "def pallas_reduce_scatter(input_arr):\n",
+ " input_arr = input_arr.reshape(\n",
+ " num_devices, outer_block_size[0], outer_block_size[1]\n",
+ " )\n",
+ " return pl.pallas_call(\n",
+ " reduce_scatter_kernel,\n",
+ " out_shape=out_shape,\n",
+ " grid_spec=grid_spec,\n",
+ " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n",
+ " )(input_arr)[0]\n",
+ "\n",
+ "\n",
+ "pallas_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " pallas_reduce_scatter,\n",
+ " mesh=mesh,\n",
+ " in_specs=P(None, 'x'),\n",
+ " out_specs=P('x', None),\n",
+ " check_rep=False,\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "pallas_result = jax.block_until_ready(pallas_result)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 768,
+ "status": "ok",
+ "timestamp": 1722904808851,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "cTEyiMDyx9Y0",
+ "outputId": "1de26695-3713-430e-9ab4-4ea646691680"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Input: (16384, 16384) [0.74162567 0.0242182 0.27751946 ... 0.05213022 0.36088037 0.04494429]\n",
+ "Pallas Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]\n",
+ "lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]\n",
+ "Difference |Pallas - lax.psum_scatter|: 2.3841858e-07\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Now we compare our result to XLA.\n",
+ "def lax_reduce_sum_scatter(x):\n",
+ " x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])\n",
+ " return lax.psum_scatter(x, 'x')\n",
+ "\n",
+ "\n",
+ "xla_result = jax.jit(\n",
+ " shard_map.shard_map(\n",
+ " lax_reduce_sum_scatter,\n",
+ " mesh=mesh,\n",
+ " in_specs=P(None, 'x'),\n",
+ " out_specs=P('x', None),\n",
+ " )\n",
+ ")(input_arr)\n",
+ "\n",
+ "print('Input:', input_arr.shape, input_arr[::4, 0])\n",
+ "print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])\n",
+ "print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])\n",
+ "print(\n",
+ " 'Difference |Pallas - lax.psum_scatter|:',\n",
+ " jnp.max(jnp.abs(pallas_result - xla_result)),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "zz5AFbriliyv"
+ },
+ "source": [
+ "## Final Notes\n",
+ "\n",
+ "### Megacore\n",
+ "\n",
+ "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
+ "\n",
+ "### Interaction with XLA\n",
+ "\n",
+ "In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n",
+ "\n",
+ "### Next Steps\n",
+ "\n",
+ "Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead."
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "formats": "ipynb,md:myst",
+ "main_language": "python"
+ },
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md
new file mode 100644
index 000000000000..fc3f929866bd
--- /dev/null
+++ b/docs/pallas/tpu/distributed.md
@@ -0,0 +1,1527 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ main_language: python
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3 (ipykernel)
+ language: python
+ name: python3
+---
+
++++ {"id": "zSNjLhGQJMgq"}
+
+# Distributed Computing in Pallas for TPUs
+
+In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.
+
+Some recommended readings beforehand:
+ - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)
+ - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 1978
+ status: ok
+ timestamp: 1722904801801
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: PyAGnWc9yI8T
+outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480
+---
+import jax
+from jax import lax
+from jax import numpy as jnp
+from jax.experimental import mesh_utils
+from jax.experimental import pallas as pl
+from jax.experimental import shard_map
+from jax.experimental.pallas import tpu as pltpu
+
+P = jax.sharding.PartitionSpec
+
+num_devices = jax.local_device_count()
+assert num_devices > 1, "Please run this notebook with more than one device."
+assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
+print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.")
+```
+
++++ {"id": "DySMGNByclMi"}
+
+## TPU Topologies
+
+TPUs are typically deployed in pods of multiple devices connected via a high-bandwidth interchip interconnect (ICI) that is used for communication within the pod and is much faster than a typical network connection. For example, the specifications sheet for a [TPU v5p](https://cloud.google.com/tpu/docs/v5p) states an ICI bandwidth of 4.8Tb/s per chip (for reference, TPU v5p also has 21Tb/s of *local* HBM bandwidth). The ICI allows us to implement fast and performant distributed kernels that require high-bandwidth communication within a pod, while using the datacenter network for parallelization over less bandwidth-intensive operations, such as data-parallelism over a batch dimension.
+
+TPU pods are typically arranged in an ND torus topology. The following graphic gives several examples of configurations of different sizes.
+
+![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)
+
+Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in discussions of device topologies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.
+
+![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)
+
++++ {"id": "1Oc_WD1hChfN"}
+
+## Remote Direct Memory Access (RDMA) Model
+
+TPUs communicate via a push-only model known as remote direct memory access (RDMA). A TPU may issue a copy instruction that pushes data from a local buffer to any buffer on another device within the same pod; the copy executes asynchronously from the main program thread. However, a TPU can only read data that is stored locally. This is in contrast to more traditional multi-core programming, where it is possible to both read from and write to shared memory.
+
+### Async Remote Copy Operation
+The `pltpu.make_async_remote_copy` function is used to create a remote DMA descriptor object which parameterizes both a "send" operation and a "receive" operation. Here's its signature:
+
+```python
+ def make_async_remote_copy(
+ src_ref: Ref,
+ dst_ref: Ref,
+ send_sem: Ref[SemaphoreType],
+ recv_sem: Ref[SemaphoreType],
+ device_id: int | tuple[int, ...],
+ device_id_type: DeviceIdType
+ ) -> AsyncCopyDescriptor:
+```
+
+- `src_ref` is the local `Ref` (in any memory space) containing the data you wish to send to `dst_ref` on another device.
+- `dst_ref` is the remote `Ref` (in any memory space) at which data will be copied to on the target device.
+- `send_sem` is a DMA semaphore used to block until all data has been sent from `src_ref`.
+- `recv_sem` is a DMA semaphore used to block until the expected number of bytes have been received at `dst_ref`. The sender of the DMA will write to the receiver's `recv_sem`.
+- `device_id` is the device ID of the target device to send to.
+- `device_id_type` specifies the format of `device_id`, which can either be in LOGICAL format (integer device ID), or in MESH format (an ND-tuple index into the logical device mesh). The default mode is MESH.
+
+`make_async_remote_copy` returns a descriptor object: call `.start()` to initiate the DMA, `.wait_send()` to block on `send_sem`, `.wait_recv()` to block on `recv_sem`, or `.wait()` to block on both. If a device is only expected to send data, it is sufficient to call only `.start()` and `.wait_send()`; likewise, if a device is only receiving, it is sufficient to call only `.wait_recv()`. If using an SPMD pattern where all devices execute the DMA, each device will generally call both `.start()` and `.wait()`.
+```python
+dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)
+dma_descriptor.start() # Initiate the DMA (non-blocking).
+# ... do other work
+dma_descriptor.wait_send() # Block until all data has been sent.
+dma_descriptor.wait_recv() # Block until all data has been received.
+```
+
+As an example, let's visualize a DMA involving 4 devices (indexed 0, 1, 2, 3). We consider a scheme where device 0 copies to device 1, and devices 2 and 3 copy to each other. In practice, we can create such an asymmetric communication pattern by using `@pl.when` to branch on the device ID.
+
+(1) Each device creates the DMA descriptor. Devices 0, 2, and 3 call `.start()` to initiate the DMA from `src_ref`. Device 1 skips the `.start()` and does nothing, e.g. by branching on the device ID with `pl.when`.
+
+![rdma_start](../../_static/pallas/distributed/rdma_start.svg)
+
+(2) As `.start()` is non-blocking, each device is free to do other computation while the DMA is in flight. Devices 0, 2, and 3 call `.wait_send()` to wait on `send_sem` which blocks until all data has been sent.
+
+![rdma_send](../../_static/pallas/distributed/rdma_send.svg)
+
+(3) Finally, devices 1, 2, and 3 will call `.wait_recv()` to wait on `recv_sem` until all data has arrived at `dst_ref`.
+
+![rdma_recv](../../_static/pallas/distributed/rdma_recv.svg)
+
+The above communication pattern can be written as follows:
+```python
+def example_kernel(input_ref, output_ref, send_sem, recv_sem):
+ device_id = lax.axis_index('x')
+ copy_0_to_1 = pltpu.make_async_remote_copy(
+ src_ref=input_ref,
+ dst_ref=output_ref,
+ send_sem=send_sem,
+ recv_sem=recv_sem,
+ device_id=1,
+ )
+ copy_2_to_3 = pltpu.make_async_remote_copy(
+ src_ref=input_ref,
+ dst_ref=output_ref,
+ send_sem=send_sem,
+ recv_sem=recv_sem,
+ device_id=3,
+ )
+ copy_3_to_2 = pltpu.make_async_remote_copy(
+ src_ref=input_ref,
+ dst_ref=output_ref,
+ send_sem=send_sem,
+ recv_sem=recv_sem,
+ device_id=2,
+ )
+ @pl.when(device_id == 0)
+ def _():
+ copy_0_to_1.start()
+ copy_0_to_1.wait_send()
+ @pl.when(device_id == 1)
+ def _():
+ copy_0_to_1.wait_recv()
+ @pl.when(device_id == 2)
+ def _():
+ copy_2_to_3.start()
+ copy_2_to_3.wait_send()
+ copy_3_to_2.wait_recv()
+ @pl.when(device_id == 3)
+ def _():
+ copy_3_to_2.start()
+ copy_3_to_2.wait_send()
+ copy_2_to_3.wait_recv()
+```
+
+### DMA Semaphores
+
+`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.
+
+Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).
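+
+For example, inside a kernel the current value can be inspected as a debugging aid (an illustrative snippet only; `recv_sem` is assumed to be a DMA semaphore passed into the kernel):
+
+```python
+# Read the current value of the DMA semaphore as an int32 scalar.
+# The exact meaning of the value is hardware-dependent, so use it only
+# for debugging rather than for program logic.
+sem_value = pl.semaphore_read(recv_sem)
+```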
+
+### Routing
+
+A sender is allowed to send data to any receiver within the same pod, even if they do not share a direct connection (the exception to this rule is for TPU v5e, where devices can only route to a power of 2 offset from themselves). TPUs have an internal routing mechanism which can pass data along to the next device on the path to the destination. However, communicating in this way is not recommended because, as a kernel writer, you have no control over network contention. The examples we will cover in this tutorial minimize inefficient communication by only transferring data to neighboring devices.
+
+### Failure modes
+
+If using remote DMAs incorrectly, you may encounter several failure modes which can be difficult to debug. The general symptoms of buggy DMA usage are crashes, hanging, or silent data corruption:
+- If semaphores exit the program with an invalid non-zero value, Pallas will crash and exit the program.
+- If semaphores are waited on but an insufficient number of bytes are received (i.e. there is no sender, or if the sent data is less than the size of `dst_ref` on the receiving device), the program may hang indefinitely waiting for bytes that are never sent. In this case the program would need to be restarted.
+- If encountering a race condition, there could be silent data corruption if two simultaneous writes or a simultaneous read and write occur.
+
+Some common causes of the above include:
+- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.
+- If a device is sent more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If it is sent fewer, it may hang indefinitely.
+- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.
+- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states.
+
++++ {"id": "vpGSN1Sui0Bu"}
+
+### Example: Right Permute (`lax.ppermute`)
+
+Let's dive into a very basic example. We will implement a kernel that performs a right permutation, where each device sends its slice of the data to its right neighbor.
+
+Suppose we had an array with 512 elements, which we shard into slices of size 128 across 4 devices. Each device will pass its slice to the next device, and the output will consist of the same data, but with the slices rotated by 1. This is identical to the `lax.ppermute` operation where the permutation is set to `(n, (n+1) % 4)`.
+
+In order to call the kernel in distributed mode, we wrap the `pallas_call` in a `shard_map` transformation. From there, we can write the kernel the same way we would write a normal single-device Pallas kernel, except we now have access to remote DMA instructions. JAX collective primitives such as `lax.axis_index` can be used to obtain a `device_id` for computing which target devices to copy to, by referencing the same axis names that were passed into `shard_map`.
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 1606
+ status: ok
+ timestamp: 1722904803566
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: YkyIKN2thZ-V
+outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0
+---
+partition = P(None, 'x')
+devices = mesh_utils.create_device_mesh((1, num_devices))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+# Create an input array that shards the last dimension across
+# all devices.
+input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
+input_arr = jax.device_put(input_arr, sharding)
+
+
+def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
+ my_id = lax.axis_index('x')
+ right_neighbor = lax.rem(my_id + 1, num_devices)
+ remote_copy_op = pltpu.make_async_remote_copy(
+ src_ref=input_ref,
+ dst_ref=output_ref,
+ send_sem=send_sem,
+ recv_sem=recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ remote_copy_op.start()
+ remote_copy_op.wait()
+
+
+out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=0,
+ # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ scratch_shapes=(
+ # We allocate DMA semaphores in scratch memory.
+ [pltpu.SemaphoreType.DMA] * 2
+ ),
+)
+right_permute = pl.pallas_call(
+ right_permute_kernel,
+ out_shape=out_shape,
+ grid_spec=grid_spec,
+)
+# Wrap the kernel within a shard_map to call.
+pallas_result = jax.jit(
+ shard_map.shard_map(
+ right_permute,
+ mesh=mesh,
+ in_specs=partition,
+ out_specs=partition,
+ check_rep=False,
+ )
+)(input_arr)
+
+# Compare Pallas result to XLA shard_map result.
+perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))
+
+xla_result = jax.jit(
+ shard_map.shard_map(
+ lambda x: lax.ppermute(x, 'x', perm),
+ mesh=mesh, in_specs=partition, out_specs=partition)
+)(input_arr)
+
+print('Input = ', input_arr[0, ::128])
+print('Pallas Result = ', pallas_result[0, ::128])
+print('lax.ppermute Result = ', xla_result[0, ::128])
+print(
+ 'Difference |Pallas - lax.ppermute| = ',
+ jnp.mean(jnp.abs(pallas_result - xla_result)),
+)
+```
+
++++ {"id": "iyfhdGXuUnq2"}
+
+### Example: All-gather (`lax.all_gather`)
+
+In this next example we will implement the all-gather collective operation, which has a JAX equivalent in `lax.all_gather`. In contrast with the right-permute example from above, which only involves a pair of source and destination neighbors, an all-gather operation requires communication between all devices and therefore we must think about how data is routed between them. The specifics of how we implement this are dictated by the device topology, which we assume is a ring.
+
+#### Ring Communication Pattern
+
+We will write our kernel assuming a ring topology. Rings are a natural fit for TPUs as slicing along any dimension of a torus produces a ring. When writing collectives, we often only need to think about 1D slices of our torus at a time because the different dimensions of the torus are reserved for different types of parallelism (data vs. model, for example).
+
+The strategy we will use is to write a looped kernel, where on each iteration a device receives one slice of the sharded array from its left neighbor, and copies the previously received slice to its right neighbor. After `num_devices` iterations, each device will have a copy of the entire array in its local HBM.
+
+![all_gather](../../_static/pallas/distributed/all_gather.svg)
+
+We can re-purpose Pallas's `grid` argument to implement the loop. Rather than iterating over tiles of an array as we have done in previous tutorials, we instead set the grid to `(num_devices,)` to indicate that we want to loop over the number of devices and use `pl.program_id` to obtain the loop iteration inside of the Pallas kernel. The following code snippet demonstrates how to implement this:
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 812
+ status: ok
+ timestamp: 1722904804531
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: ojQEZB5mBRqM
+outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b
+---
+partition = P('x', None)
+devices = mesh_utils.create_device_mesh((num_devices, 1))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+# Create an input array that shards the first dimension across
+# all devices.
+input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))
+input_arr = jax.device_put(input_arr, sharding)
+
+
+def all_gather_kernel(input_ref,
+ output_ref,
+ local_copy_sem,
+ send_sem,
+ recv_sems):
+ outer_step = pl.program_id(0)
+ my_id = lax.axis_index('x')
+ right_neighbor = lax.rem(my_id + 1, num_devices)
+ copy_slot = my_id - outer_step
+ copy_slot = lax.rem(copy_slot + num_devices, num_devices)
+
+ @pl.when(outer_step == 0)
+ def _():
+ local_copy_op = pltpu.make_async_copy(
+ src_ref=input_ref,
+ dst_ref=output_ref.at[my_id],
+ sem=local_copy_sem,
+ )
+ local_copy_op.start()
+ local_copy_op.wait()
+
+ # Copy to our right neighbor.
+ # Note that we will also be receiving data from our left neighbor,
+ # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact
+ # that the indices do not need to be symmetric between remote DMAs.
+ remote_copy_op = pltpu.make_async_remote_copy(
+ src_ref=output_ref.at[copy_slot],
+ dst_ref=output_ref.at[copy_slot],
+ send_sem=send_sem,
+ recv_sem=recv_sems.at[outer_step],
+ device_id=(right_neighbor, 0),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ remote_copy_op.start()
+ remote_copy_op.wait()
+
+out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=0,
+ in_specs=[
+ # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ scratch_shapes=(
+ # DMA semaphores are allocated in scratch memory.
+ # We allocated one semaphore for a local HBM-VMEM copy,
+ # and one for the remote send semaphore.
+ [pltpu.SemaphoreType.DMA] * 2
+ # We additionally allocate one receive semaphore per device.
+ # This is to avoid situations where we have multiple
+ # DMAs in flight, as we do not want to share a receive
+ # semaphore between the DMAs.
+ + [pltpu.SemaphoreType.DMA((num_devices-1,))]
+
+ ),
+ grid=(num_devices-1,)
+ )
+
+all_gather = pl.pallas_call(
+ all_gather_kernel,
+ out_shape=out_shape,
+ grid_spec=grid_spec,
+ )
+
+# Wrap the kernel within a shard_map to call.
+pallas_result = jax.jit(
+ shard_map.shard_map(
+ all_gather,
+ mesh=mesh,
+ in_specs=partition,
+ out_specs=partition,
+ check_rep=False
+ )
+)(input_arr)
+
+# Compare Pallas result to XLA shard_map result.
+xla_result = jax.jit(
+ shard_map.shard_map(
+ lambda x: lax.all_gather(x, 'x'),
+ mesh=mesh, in_specs=partition, out_specs=partition
+ )
+)(input_arr)
+
+print('Input: ', input_arr.shape, input_arr[::8, 0])
+print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])
+print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])
+print('Difference |Pallas - lax.all_gather| = ',
+ jnp.mean(jnp.abs(pallas_result - xla_result)))
+```
+
++++ {"id": "KgU7HI2pS4om"}
+
+A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have multiple DMAs in flight before the receiver has finished processing the first one (see the next section and the reduce-sum example, which discuss race conditions in more detail). In this situation, the same semaphore could end up being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure.
+
++++ {"id": "KgU7HI2pS4om"}
+
+## Advanced Techniques
+
+Now that we have seen how to write several basic kernels using remote DMA operations, we will go over more advanced techniques for synchronization and writing efficient kernels.
+
++++ {"id": "8M_kdl0FCtrL"}
+
+### Synchronization: Regular and Barrier Semaphores
+
+The examples we implemented in the basic tutorial do not require special handling of synchronization as all necessary communication writes to disjoint buffers. However, other operations may require more complex communication patterns that need additional synchronization primitives to avoid race conditions. Pallas provides two additional primitives to help with this: regular and barrier semaphores.
+
+#### Regular Semaphores
+
+Regular semaphores are the standard tool used to synchronize across multiple devices. Semaphores are fundamentally counters - they can be incremented by any device, and a device can block until the value of the semaphore reaches a specific value (at which point the value is decremented by that amount and the device proceeds).
+
+The three main operations that can be used on regular semaphores are signal, wait, and read:
+```python
+def semaphore_signal(
+ sem: Ref[SemaphoreType],
+ inc: int,
+ device_id: int | tuple[int, ...],
+ device_id_type: DeviceIdType
+) -> None:
+ ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.
+
+def semaphore_wait(
+ semaphore: Ref[SemaphoreType],
+ value: int,
+) -> None:
+ ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.
+
+def semaphore_read(
+ sem: Ref[SemaphoreType],
+) -> jax.Array:
+ ... # Returns the current value of `sem` as an `int32[]`.
+```
+
+In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.
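+
+For example, a regular semaphore can be requested in `scratch_shapes` alongside DMA semaphores (a minimal sketch; the particular mix of semaphores here is illustrative and mirrors the kernels later in this tutorial):
+
+```python
+grid_spec = pltpu.PrefetchScalarGridSpec(
+    num_scalar_prefetch=0,
+    in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],
+    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    scratch_shapes=(
+        [pltpu.SemaphoreType.DMA] * 2  # DMA semaphores for copies.
+        + [pltpu.SemaphoreType.REGULAR]  # A regular semaphore for synchronization.
+    ),
+)
+```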
+
+Semaphores must be zero at the end of a Pallas program to complete successfully. There are two error cases in which this can fail:
+ - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging, as a non-zero semaphore typically means there is a bug somewhere inside of the program.
+ - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted.
+
+#### Barrier Semaphores
+
+Barrier semaphores are globally-allocated semaphores used to synchronize devices across an entire program and ensure that all devices have entered the Pallas kernel.
+
+If a Pallas kernel is executed within the context of a larger XLA program, we need to ensure that all devices that communicate have entered the kernel. However, DMA and regular semaphores are both locally scoped - they are only understood by other devices that have entered the kernel. Barrier semaphores serve as a globally understood semaphore that can be used for synchronization no matter where in the XLA program the device is currently executing.
+
+By default, if you do not specify a barrier semaphore, Pallas will automatically insert a barrier semaphore at the beginning of your program. However, it can be more efficient to write your own. Barrier semaphores are similar to regular semaphores in that they are counters that can be incremented via `semaphore_signal` and can be decremented via `semaphore_wait`. They are created by calling `get_barrier_semaphore()` within a kernel. Typically, we use barriers once at the beginning of a kernel to synchronize with all devices we are communicating with.
+
+```python
+from jax.experimental.pallas import tpu as pltpu
+
+def example_kernel(...):
+ # Use barrier semaphores at the beginning of a kernel.
+ # is_start_of_kernel = ...
+ # right_neighbor = ...
+ # ...
+ @pl.when(is_start_of_kernel)
+ def _():
+ barrier_sem = pltpu.get_barrier_semaphore()
+ # Increment the semaphore of your right neighbor.
+ pltpu.semaphore_signal(
+ barrier_sem,
+ device_id=right_neighbor,
+ device_id_type=pltpu.DeviceIdType.LOGICAL,
+ )
+ # Wait until your left neighbor has incremented your semaphore
+ pltpu.semaphore_wait(barrier_sem, 1)
+ # ...
+```
+
+When using barrier semaphores, the `collective_id` compiler parameter must be passed to `pallas_call` to specify which barrier semaphore is being used. A TPU has a small, fixed number of barrier semaphores available (typically on the order of 20-30) and therefore they should be used sparingly. In order to ensure correctness, only kernels that share the same communication pattern should use the same `collective_id`. For example, if two kernels synchronize only with neighbors on the same mesh axis, they are allowed to share the same `collective_id`. However, if two kernels synchronize along different axes, they must have different `collective_id`s. Failure to do so may result in race conditions that are difficult to debug.
+
+```python
+kernel = pl.pallas_call(
+ example_kernel,
+ ...,
+ compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+)
+```
+
++++ {"id": "zy20AxN5TSLA"}
+
+### Double-buffering
+
+In order to avoid reading from a local `Ref` that is also being written into by another device and creating a race condition, a useful technique is the "double-buffered" strategy, where we allocate two `Ref`s for each destination value. On each iteration, one `Ref` will be designated as a "working" slot, and the other will be designated as a "receiving" slot. The device is free to use the working slot for computation, but will only copy data into its neighbor's receiving slot. The working and receiving slots alternate every iteration, so that once a copy is finished, the old receiving slot becomes the new working slot, and vice versa. If this scheme is used properly, data is never read from and written to the same buffer.
+
+The following code skeleton demonstrates how double-buffering can be used. We keep a running iteration counter in the variable `iteration`, and the `working_slot` and `receiving_slot` alternate between 0 and 1 every iteration. `dst_ref` is allocated as a double-buffer and has the size `[2, ...]`. On each iteration, we read from the working slot using `dst_ref.at[working_slot, ...]` and use the value to perform computation. Simultaneously, we copy to our neighbor's `dst_ref.at[receiving_slot]` to avoid overwriting their `working_slot` value. By structuring our communication in this fashion it is possible to overlap the communication latency of the remote DMA with local computation while minimizing the risk of race conditions.
+```python
+def kernel(...):
+ # ...
+ iteration = pl.program_id(0)
+ working_slot = lax.rem(iteration, 2)
+ receiving_slot = 1 - working_slot
+ # ...
+
+ local_copy_op = pltpu.make_async_copy(
+ src_ref=dst_ref.at[working_slot, ...],
+ dst_ref=local_scratch_ref,
+ sem=local_copy_sem,
+ )
+ local_copy_op.start()
+ remote_copy_op = pltpu.make_async_remote_copy(
+ src_ref=src_ref,
+ dst_ref=dst_ref.at[receiving_slot, ...],
+ send_sem=send_sem,
+ recv_sem=recv_sem,
+ device_id=target_device,
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ remote_copy_op.start()
+
+ local_copy_op.wait()
+ # ... do work on local_scratch while waiting for async_copy_op to finish.
+ remote_copy_op.wait()
+
+```
+
+In terms of synchronization, the double-buffered construction works if all devices are executing on the same iteration. If a sender manages to get one iteration ahead of its receiver, its `working_slot` and `receiving_slot` indices will be flipped relative to the receiver's, meaning that it could be writing into the `working_slot` at the same time the receiver is reading from it. In order to avoid this, it may be necessary to use a semaphore to synchronize the sender with the receiver, or add additional buffering slots ("triple", "quadruple", or N-buffered) to allow additional run-ahead at the cost of more memory. In our previous `all_gather` example, note that the kernel contained a receiving buffer with N slots, which avoids race conditions altogether. In our next kernel, we will instead go through an example which uses a double-buffer with explicit synchronization.
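+
+As a sketch, the slot bookkeeping above generalizes to N buffers as follows (this generalization is not used in the kernels below; it simply illustrates trading extra memory for up to N-1 iterations of sender run-ahead):
+
+```python
+N_BUFFERS = 3  # "Triple buffering".
+working_slot = lax.rem(iteration, N_BUFFERS)
+# The slot we copy into on our neighbor is the one it will read on its next
+# iteration, so a sender may run up to N_BUFFERS - 1 iterations ahead before
+# it would touch a slot the receiver is still working on.
+receiving_slot = lax.rem(iteration + 1, N_BUFFERS)
+```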
+
++++ {"id": "Or0Itv72No5d"}
+
+### Example: All-Reduce Sum (`lax.psum`)
+
+We will now implement an all-reduce sum kernel using double-buffering and semaphores for synchronization. For those familiar with collective operations in JAX, the equivalent operation is `lax.psum`. All-reduce is a standard collective operation where the objective is to reduce along an axis of an array, but the array is sharded across multiple devices.
+
+![reduce_sum_1](../../_static/pallas/distributed/reduce_sum_1.svg)
+
+In the above example, we have the array [5, 2, 1, 3] sharded across 4 devices. An all-reduce sum operation would sum all values and replicate the result on each device, leading to the result [11, 11, 11, 11] sharded across all 4 devices.
+
+The naive implementation of all-reduce would be to gather all required values onto each device, and then reduce. However, we can improve the performance of this implementation by interleaving communication with computation. An interleaved, single-direction all-reduce can be visualized as follows. On each iteration, we receive an input value from our left neighbor, and concurrently pass input along to our next neighbor while incrementing it with our local accumulator. After N-1 iterations, each device will have a copy of the full sum in its memory.
+
+![reduce_sum_2](../../_static/pallas/distributed/reduce_sum_2.svg)
+
+#### Putting it all together
+
+The following kernel demonstrates how to combine these principles into a functional kernel.
+
+The prologue (executed when `outer_step==0`) first initiates a barrier with both neighbors to ensure that they have also entered the kernel. It also initializes all `Ref`s and performs the first remote copy to the right neighbor's "working" slot.
+
+The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).
+
+A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device will block on a `REGULAR` semaphore before copying into the right neighbor's `dst_ref`, waiting until the neighbor has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but it can be triggered explicitly, for example by using a `pltpu.delay` instruction to artificially hang a device.
+
+Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections.
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 254
+ status: ok
+ timestamp: 1722904804952
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: XrY5bMlvBroQ
+outputId: 77497000-4496-462e-cc3c-73fb640cc14c
+---
+partition = P(None, 'x')
+devices = mesh_utils.create_device_mesh((1, num_devices))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
+input_arr = jax.device_put(input_arr, sharding)
+
+
+def all_reduce_kernel(
+ x_ref,
+ o_ref,
+ hbm_scratch,
+ copy_sem,
+ remote_recv_sem,
+ remote_send_sem,
+ capacity_sem,
+ receive_scratch,
+):
+ outer_step = pl.program_id(0)
+ working_slot = lax.rem(outer_step, 2)
+ receiving_slot = 1 - working_slot
+
+ my_id = lax.axis_index('x')
+ right_neighbor = lax.rem(my_id + 1, num_devices)
+ left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)
+
+ @pl.when(outer_step == 0)
+ def _():
+ # Barrier with both neighbors at the start, since we will be
+ # communicating with both.
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(
+ barrier_sem,
+ inc=1,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ pltpu.semaphore_signal(
+ barrier_sem,
+ inc=1,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ pltpu.semaphore_wait(barrier_sem, 2)
+
+ # Initialize o_ref, acc_scratch, and hbm_scratch.
+ o_ref[...] = jnp.zeros_like(o_ref)
+ receive_scratch[...] = jnp.zeros_like(receive_scratch)
+ initial_copy = pltpu.make_async_remote_copy(
+ src_ref=x_ref,
+ dst_ref=hbm_scratch.at[working_slot],
+ send_sem=remote_send_sem,
+ recv_sem=remote_recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ initial_copy.start()
+ initial_copy.wait()
+
+ # Signal to our left neighbor that we are ready to receive.
+ # Without this signal, our left neighbor can be >=1 iteration ahead,
+ # meaning it could write into our working slot.
+ pltpu.semaphore_signal(
+ capacity_sem,
+ inc=1,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ # Copy the partial result our left neighbor sent to us into VMEM for
+ # computation.
+ local_copy = pltpu.make_async_copy(
+ src_ref=hbm_scratch.at[working_slot],
+ dst_ref=receive_scratch,
+ sem=copy_sem,
+ )
+ local_copy.start()
+
+ # Block until our right neighbor is ready to receive.
+ pltpu.semaphore_wait(capacity_sem, 1)
+ # Pass the value to our right neighbor.
+ remote_copy = pltpu.make_async_remote_copy(
+ src_ref=hbm_scratch.at[working_slot],
+ dst_ref=hbm_scratch.at[receiving_slot],
+ send_sem=remote_send_sem,
+ recv_sem=remote_recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ remote_copy.start()
+ # Finish local copy and accumulate while remote_copy is happening.
+ local_copy.wait()
+ o_ref[...] += receive_scratch[...]
+ # Block until remote copy finishes.
+ remote_copy.wait()
+
+
+out_shape = (
+ jax.ShapeDtypeStruct((8, 128), jnp.float32),
+ # We allocate the double-buffer as a Pallas output so that it is
+ # resident in HBM.
+ jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch
+)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=0,
+ in_specs=[
+ # Our input lives in VMEM
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+ ],
+ out_specs=[
+ # Our output lives in VMEM
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+ # Our double-buffer lives in HBM
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ grid=(num_devices,),
+ scratch_shapes=(
+ [pltpu.SemaphoreType.DMA] * 3
+ + [pltpu.SemaphoreType.REGULAR] # capacity_sem
+ + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch
+ ),
+)
+
+kernel = pl.pallas_call(
+ all_reduce_kernel,
+ out_shape=out_shape,
+ grid_spec=grid_spec,
+ compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+)
+
+pallas_result = jax.jit(
+ shard_map.shard_map(
+ kernel,
+ mesh=mesh,
+ in_specs=partition,
+ out_specs=partition,
+ check_rep=False,
+ )
+)(input_arr)
+pallas_result = jax.block_until_ready(pallas_result)[0]
+
+
+def lax_sum(x):
+ return lax.psum(x, 'x')
+
+
+xla_result = jax.jit(
+ shard_map.shard_map(
+ lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')
+ )
+)(input_arr)
+
+print('Input = ', input_arr[0, ::128])
+print('Pallas result = ', pallas_result[0, ::128])
+print('lax.psum result = ', xla_result[0, ::128])
+difference = jnp.mean(jnp.abs(pallas_result - xla_result))
+print('Difference |Pallas - lax.psum| = ', difference)
+```
+
++++ {"id": "d8bsZAzQreC_"}
+
+### Run-ahead and Race Conditions
+
+As a general rule of thumb, to maximize performance we want to allow a device to run ahead of other devices without synchronization as much as possible without sacrificing the correctness of the program. While we could enforce a barrier across all devices at the beginning of each iteration, this bottlenecks the performance of the program to the slowest device on each loop. By relaxing synchronization and allowing a moderate amount of run-ahead, we can better accommodate variance in latency between iterations and devices because a device that is slow on one iteration could catch up on the next iteration.
+
+In the all-reduce kernel we wrote previously, we allow devices to run ahead but by less than one iteration compared to their neighbors (however, non-neighboring devices could be more than 1 iteration apart). To see why the semaphore synchronization is necessary, consider the case when one device (say device 2) hangs and falls behind the other devices. An RDMA has no "handshake" — only the receiver is blocked while waiting for the data to arrive. Therefore, each device can run up to one iteration ahead before it becomes blocked waiting for the next RDMA to arrive. If we have N devices, this means that the final device can be up to N iterations ahead of the first device.
+
+![race_condition](../../_static/pallas/distributed/race_condition.svg)
+
+Without adding synchronization in the other direction (forcing senders to block), device 1 could potentially run up to `N` iterations (`N = num_devices`) ahead of device 2, sending multiple writes and overwriting values in the process. To solve this in the `all_reduce` kernel we wrote previously, we implemented a "handshake" protocol where the receiver signals back to the sender that it is ready to receive, and only then does the sender begin issuing the next RDMA.
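+
+Schematically, the handshake in the all-reduce kernel above boils down to the following pairing (a sketch of just the synchronization logic; `capacity_sem`, `left_neighbor`, and `remote_copy` refer to the corresponding values in that kernel):
+
+```python
+# Receiver side: tell our left neighbor that our receiving slot is free,
+# i.e. that it may issue its next RDMA towards us.
+pltpu.semaphore_signal(
+    capacity_sem,
+    inc=1,
+    device_id=(0, left_neighbor),
+    device_id_type=pltpu.DeviceIdType.MESH,
+)
+
+# Sender side: block until our right neighbor has signaled that it is ready,
+# and only then start the RDMA into its receiving slot.
+pltpu.semaphore_wait(capacity_sem, 1)
+remote_copy.start()
+```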
+
++++ {"id": "UD8lNrqsUeXy"}
+
+### Bi-directional Communication
+
+In our previous kernels, we communicated in a single direction around a ring from left-to-right. However, as ICI connections are bi-directional, we are effectively wasting half of the total bandwidth by not sending values in the opposite direction from right-to-left. In this next kernel we will demonstrate an example which communicates in both directions to maximize ICI bandwidth.
+
++++ {"id": "4KjakLhbBk73"}
+
+### Example: Bi-directional Reduce-Scatter (`lax.psum_scatter`)
+
+A reduce-scatter operation is the combination of an all-reduce followed by a scatter. Equivalently, an all-reduce is the combination of a reduce-scatter followed by an all-gather.
+
+The following graphic depicts the semantics of this operation. We assume that each device starts with a collection of partial sums (denoted by a letter + number, such as `A0`). The goal is to reduce along one axis (numbers), while sharding along the other axis (letters).
+
+![reduce_scatter_1](../../_static/pallas/distributed/reduce_scatter_1.svg)
+
+In order to implement a bi-directional communication strategy, we slice each input block in half, and designate a direction for each half. The top half of each block will be passed from right-to-left, and the bottom half will be passed from left-to-right. A second deviation from the communication patterns of our previous all-reduce and all-gather kernels is that we will also pass around accumulators or partial sums and keep the inputs local to each device. This is in contrast to the previous examples where we passed around inputs but kept the accumulator local to the device. Passing around the accumulator is a more natural fit for this problem because, in contrast to all-reduce, most of the data in the inputs is not part of the output that will be stored locally on the device (e.g. `B0`, `C0`, and `D0` in the above graphic will not be stored on the device holding `A` at the end).
+
+The following diagram illustrates this communication pattern, where the colored boxes represent accumulators (not inputs!). Initially, the accumulator is simply the value that was contained in the input. At each iteration of the algorithm, we will receive a partial sum from our neighbors in each direction. We then compute the correct slice of our input to accumulate into the partial buffer, then pass the new partial sum along to our next neighbor. After N iterations, the accumulator will have passed through each device, meaning that it will hold the full sum in the end.
+
+![reduce_scatter_2](../../_static/pallas/distributed/reduce_scatter_2.svg)
+
+In terms of construction of the kernel, we introduce an additional `phase` dimension to the Pallas grid, which denotes which accumulator (left or right) we are currently computing on. We let `phase=0` denote the accumulator moving to the left, and `phase=1` denote the accumulator moving to the right. We then pipeline the two phases, such that while computing the result for one phase we are transferring our previously computed values in the opposite direction in preparation for the next phase. For example, when we are on `phase=0` (left), we first begin a DMA to transfer results we computed in the previous iteration to our right neighbor (right-DMA). Then, we accumulate into the left-buffer and save the result to HBM. We then wait for the right-DMA to complete so that it is ready for `phase=1` (right).
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 544
+ status: ok
+ timestamp: 1722904805699
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: nRauUAxNHg28
+---
+partition = P(None, 'x')
+devices = mesh_utils.create_device_mesh((1, num_devices))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+# We need a block size of (16, 128) to ensure that a half-slice is at least
+# of size (8, 128), which is the size of a VREG. This makes tiling easier
+# for the compiler.
+block_size = (16, 128)
+input_arr = jax.random.uniform(
+ jax.random.key(0),
+ shape=(block_size[0] * num_devices, block_size[1] * num_devices),
+)
+input_arr = jax.device_put(input_arr, sharding)
+
+LEFT = 0
+RIGHT = 1
+
+
+def mod(x, n):
+ return lax.rem(x + n, n)
+
+
+def signal(left_or_right, semaphore):
+ my_id = lax.axis_index('x')
+ if left_or_right == LEFT:
+ neighbor = mod(my_id - 1, num_devices)
+ else:
+ neighbor = mod(my_id + 1, num_devices)
+ pltpu.semaphore_signal(
+ semaphore,
+ inc=1,
+ device_id=(0, neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+
+def reduce_scatter_kernel(
+ x_ref,
+ o_ref,
+ hbm_scratch,
+ local_copy_sem,
+ left_recv_sem,
+ left_send_sem,
+ right_recv_sem,
+ right_send_sem,
+ left_capacity_sem,
+ right_capacity_sem,
+ accum_scratch,
+):
+ outer_step = pl.program_id(0)
+ phase = pl.program_id(1)
+ is_start = jnp.logical_and(outer_step == 0, phase == 0)
+ last_iteration = outer_step == pl.num_programs(0) - 1
+
+ working_slot = lax.rem(outer_step, 2)
+ receiving_slot = 1 - working_slot
+ my_id = lax.axis_index('x')
+ right_neighbor = mod(my_id + 1, num_devices)
+ left_neighbor = mod(my_id - 1, num_devices)
+
+ left_copy_device = mod(my_id + outer_step + 1, num_devices)
+ right_copy_device = mod(my_id - outer_step - 1, num_devices)
+ # Slices can be specified using pl.ds(start, size)
+ left_copy_slice = pl.ds(0, block_size[0] // 2)
+ right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)
+ current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)
+
+ initial_left_copy = pltpu.make_async_remote_copy(
+ src_ref=x_ref.at[my_id, left_copy_slice],
+ dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
+ send_sem=left_send_sem,
+ recv_sem=left_recv_sem,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ initial_right_copy = pltpu.make_async_remote_copy(
+ src_ref=x_ref.at[my_id, right_copy_slice],
+ dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
+ send_sem=right_send_sem,
+ recv_sem=right_recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ left_copy = pltpu.make_async_remote_copy(
+ src_ref=hbm_scratch.at[working_slot, left_copy_slice],
+ dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
+ send_sem=left_send_sem,
+ recv_sem=left_recv_sem,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ right_copy = pltpu.make_async_remote_copy(
+ # Note: Right copy is flipped with regards to slots since we are copying
+ # to the next outer_step iteration.
+ src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
+ dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
+ send_sem=right_send_sem,
+ recv_sem=right_recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ # --- Prologue ---
+ @pl.when(is_start)
+ def _():
+ # Barrier with both neighbors at the start, since we will be
+ # communicating with both.
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(
+ barrier_sem,
+ inc=1,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ pltpu.semaphore_signal(
+ barrier_sem,
+ inc=1,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ pltpu.semaphore_wait(barrier_sem, 2)
+
+ # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.
+ o_ref[...] = jnp.zeros_like(o_ref[...])
+ accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
+
+ initial_left_copy.start()
+ initial_left_copy.wait()
+ initial_right_copy.start()
+
+ # We tell our left neighbor that it is allowed to send to the right.
+ # (and vice versa for right neighbor)
+ signal(LEFT, right_capacity_sem)
+ signal(RIGHT, left_capacity_sem)
+
+ # --- Body ---
+ # At the beginning of our kernel body, we start a DMA which copies
+ # the result we computed in the previous phase to our neighbor.
+ # This allows us to overlap the communication of sending our previous phase
+ # with the computation for the current phase.
+ @pl.when(~is_start)
+ def _():
+ @pl.when(phase == LEFT)
+ def _():
+      # We block here until our right neighbor tells us we can send to
+ # the right.
+ pltpu.semaphore_wait(right_capacity_sem, 1)
+ right_copy.start()
+
+ @pl.when(phase == RIGHT)
+ def _():
+      # We block here until our left neighbor tells us we can send to
+ # the left.
+ pltpu.semaphore_wait(left_capacity_sem, 1)
+ left_copy.start()
+
+ local_copy = pltpu.make_async_copy(
+ src_ref=hbm_scratch.at[working_slot, current_phase_slice],
+ dst_ref=accum_scratch,
+ sem=local_copy_sem,
+ )
+ local_copy.start()
+ local_copy.wait()
+
+ @pl.when(~last_iteration)
+ def _():
+ @pl.when(phase == LEFT)
+ def _():
+ accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
+
+ @pl.when(phase == RIGHT)
+ def _():
+ accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
+
+ local_copy = pltpu.make_async_copy(
+ src_ref=accum_scratch,
+ dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
+ sem=local_copy_sem,
+ )
+ local_copy.start()
+ local_copy.wait()
+
+ @pl.when(is_start)
+ def _():
+ initial_right_copy.wait()
+
+ # At the end of our kernel body, we wait on the DMA of the previous phase
+ # to make sure the results are ready for the next phase.
+ @pl.when(~is_start)
+ def _():
+ @pl.when(phase == LEFT)
+ def _():
+ right_copy.wait()
+ signal(LEFT, right_capacity_sem)
+
+ @pl.when(phase == RIGHT)
+ def _():
+ left_copy.wait()
+ signal(RIGHT, left_capacity_sem)
+
+ # --- Epilogue ---
+ # Store result on last iteration.
+ @pl.when(last_iteration)
+ def _():
+ # Clean up semaphores so that they exit with a value of 0.
+ @pl.when(phase == LEFT)
+ def _():
+ o_ref[left_copy_slice, ...] = accum_scratch[...]
+ pltpu.semaphore_wait(right_capacity_sem, 1)
+
+ @pl.when(phase == RIGHT)
+ def _():
+ o_ref[right_copy_slice, ...] = accum_scratch[...]
+ pltpu.semaphore_wait(left_capacity_sem, 1)
+
+
+out_shape = (
+ jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output
+ # Shape: [working/recv, block[0], block[1]]
+ jax.ShapeDtypeStruct(
+ (2, block_size[0], block_size[1]), jnp.float32
+ ), # hbm_scratch
+)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=0,
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+ ],
+ out_specs=[
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ grid=(num_devices, 2),
+ scratch_shapes=(
+ [pltpu.SemaphoreType.DMA] * 5
+ + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
+ + [
+ pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)
+ ] # accum_scratch
+ ),
+)
+
+
+def pallas_reduce_scatter(input_arr):
+ input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])
+ return pl.pallas_call(
+ reduce_scatter_kernel,
+ out_shape=out_shape,
+ grid_spec=grid_spec,
+ compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+ )(input_arr)[0]
+
+
+pallas_result = jax.jit(
+ shard_map.shard_map(
+ pallas_reduce_scatter,
+ mesh=mesh,
+ in_specs=P(None, 'x'),
+ out_specs=P('x', None),
+ check_rep=False,
+ )
+)(input_arr)
+
+pallas_result = jax.block_until_ready(pallas_result)
+```
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 596
+ status: ok
+ timestamp: 1722904806442
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: E-NMh-_teoi4
+outputId: 24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0
+---
+# Compare our result to XLA.
+def lax_reduce_sum_scatter(x):
+ x = x.reshape(num_devices, block_size[0], block_size[1])
+ return lax.psum_scatter(x, 'x')
+
+
+xla_result = jax.jit(
+ shard_map.shard_map(
+ lax_reduce_sum_scatter,
+ mesh=mesh,
+ in_specs=P(None, 'x'),
+ out_specs=P('x', None),
+ )
+)(input_arr)
+
+print('Input:', input_arr.shape, input_arr[::4, 0])
+print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
+print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
+print(
+ 'Difference |Pallas - lax.psum_scatter|:',
+ jnp.max(jnp.abs(pallas_result - xla_result)),
+)
+```
+
++++ {"id": "ThKas40r40Ji"}
+
+### Nested Remote and Local DMA Pipelines
+
+A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in the working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute-bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB.
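+
+To make this concrete, here is a rough back-of-the-envelope estimate (illustrative numbers only, not tied to a specific TPU generation) of the arithmetic intensity of a single $N \times N$ matrix multiplication block in bfloat16:
+
+```python
+# Rough arithmetic-intensity estimate for one N x N x N bfloat16 matmul block.
+N = 4096                      # Example block side length (chosen for illustration).
+flops = 2 * N**3              # One multiply and one add per inner-product term.
+bytes_moved = 3 * (2 * N**2)  # Read A and B, write C, at 2 bytes per element.
+print(flops / bytes_moved)    # ~1365 FLOPs per byte; the ratio grows linearly with N.
+```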
+
+To address this problem, we need to be able to write an "inner kernel" that handles local HBM-VMEM pipelining inside of the "outer kernel" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs:
+
+```python
+def emit_pipeline(
+ kernel: Callable,
+ grid: tuple[int],
+ in_specs: PyTree[BlockSpec] = None,
+ out_specs: PyTree[BlockSpec] = None,
+ should_accumulate_out: bool = False,
+ dimension_semantics: tuple[GridDimensionSemantics] = None,
+) -> Callable:
+ ... # Returns a custom pipeline given an inner kernel and BlockSpecs.
+```
+
+Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern:
+
+```python
+
+def outer_kernel(...):
+ # ... do work to pipeline remote HBM-HBM transfers (outer kernel)
+
+ def inner_kernel(...):
+ # ... do work (inner kernel)
+ pltpu.emit_pipeline(
+ inner_kernel,
+ grid=inner_grid,
+ in_specs=...,
+ out_specs=...,
+ )(inner_kernel_args)
+ # ... do more work (outer kernel)
+
+pl.pallas_call(
+ outer_kernel,
+ grid=outer_grid,
+  in_specs=...,
+  out_specs=...,
+  scratch=inner_kernel_allocs,
+)
+```
+
++++ {"id": "DzFeQjYaasX5"}
+
+### Example: Reduce-Scatter with large HBM blocks
+
+In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect the operation to become compute-bound at larger block sizes; this example is purely meant to demonstrate how to use the pipeline emitter.
+
+We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.ANY`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.
+
+In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:
+
+```python
+local_copy = pltpu.make_async_copy(
+ src_ref=hbm_scratch.at[working_slot, current_phase_slice],
+ dst_ref=accum_scratch,
+ sem=local_copy_sem,
+)
+local_copy.start()
+local_copy.wait()
+@pl.when(~last_iteration)
+def _():
+ @pl.when(phase == LEFT)
+ def _():
+ accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]
+ @pl.when(phase == RIGHT)
+ def _():
+ accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]
+local_copy = pltpu.make_async_copy(
+ src_ref=accum_scratch,
+ dst_ref=hbm_scratch.at[working_slot, current_phase_slice],
+ sem=local_copy_sem,
+)
+local_copy.start()
+local_copy.wait()
+```
+
+Our new kernel replaces it with the following `emit_pipeline` call:
+
+```python
+def inner_kernel(input_ref, accum_ref):
+ accum_ref[...] = input_ref[...]
+accum_pipeline = pltpu.emit_pipeline(inner_kernel,
+ in_specs=[inner_block_spec],
+ out_specs=inner_block_spec,
+ should_accumulate_out=True,
+ grid=inner_grid)
+@pl.when(~last_iteration)
+def _():
+ @pl.when(phase == LEFT)
+ def _():
+ accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],
+ hbm_scratch.at[working_slot, left_copy_slice],
+ )
+ @pl.when(phase == RIGHT)
+ def _():
+ accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],
+ hbm_scratch.at[working_slot, right_copy_slice],
+ )
+```
+
+The full kernel is as follows:
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 1341
+ status: ok
+ timestamp: 1722904807930
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: 27jni-pSartL
+---
+partition = P(None, 'x')
+devices = mesh_utils.create_device_mesh((1, num_devices))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+# We pick a large outer kernel block size that we do not want to place
+# in VMEM. For pedagogical purposes we use (4096, 4096), although in
+# principle this can be much larger.
+outer_block_size = (4096, 4096)
+# We pick a smaller VMEM block size for the inner kernel.
+inner_block_size = (128, 128)
+input_arr = jax.random.uniform(
+ jax.random.key(0),
+ shape=(
+ outer_block_size[0] * num_devices,
+ outer_block_size[1] * num_devices,
+ ),
+)
+input_arr = jax.device_put(input_arr, sharding)
+
+
+inner_grid = (
+ outer_block_size[0] // inner_block_size[0] // 2,
+ outer_block_size[1] // inner_block_size[1],
+)
+inner_block_spec = pl.BlockSpec(
+ index_map=lambda i, j: (i, j),
+ block_shape=inner_block_size,
+ memory_space=pltpu.TPUMemorySpace.ANY,
+)
+
+
+def reduce_scatter_kernel(
+ x_ref,
+ o_ref,
+ hbm_scratch,
+ left_recv_sem,
+ left_send_sem,
+ copy_sem,
+ right_recv_sem,
+ right_send_sem,
+ left_capacity_sem,
+ right_capacity_sem,
+):
+ outer_step = pl.program_id(0)
+ phase = pl.program_id(1)
+ is_start = jnp.logical_and(outer_step == 0, phase == 0)
+ last_iteration = outer_step == pl.num_programs(0) - 1
+
+ working_slot = lax.rem(outer_step, 2)
+ receiving_slot = 1 - working_slot
+ my_id = lax.axis_index('x')
+ right_neighbor = mod(my_id + 1, num_devices)
+ left_neighbor = mod(my_id - 1, num_devices)
+
+ left_copy_device = mod(my_id + outer_step + 1, num_devices)
+ right_copy_device = mod(my_id - outer_step - 1, num_devices)
+ left_copy_slice = pl.ds(0, outer_block_size[0] // 2)
+ right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)
+ current_phase_slice = pl.ds(
+ phase * (outer_block_size[0] // 2), outer_block_size[0] // 2
+ )
+
+ initial_left_copy = pltpu.make_async_remote_copy(
+ src_ref=x_ref.at[my_id, left_copy_slice],
+ dst_ref=hbm_scratch.at[working_slot, left_copy_slice],
+ send_sem=left_send_sem,
+ recv_sem=left_recv_sem,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ initial_right_copy = pltpu.make_async_remote_copy(
+ src_ref=x_ref.at[my_id, right_copy_slice],
+ dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
+ send_sem=right_send_sem,
+ recv_sem=right_recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ left_copy = pltpu.make_async_remote_copy(
+ src_ref=hbm_scratch.at[working_slot, left_copy_slice],
+ dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],
+ send_sem=left_send_sem,
+ recv_sem=left_recv_sem,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ right_copy = pltpu.make_async_remote_copy(
+ src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],
+ dst_ref=hbm_scratch.at[working_slot, right_copy_slice],
+ send_sem=right_send_sem,
+ recv_sem=right_recv_sem,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+
+ # --- Prologue ---
+ @pl.when(is_start)
+ def _():
+ # Barrier with both neighbors at the start, since we will be
+ # communicating with both.
+ barrier_sem = pltpu.get_barrier_semaphore()
+ pltpu.semaphore_signal(
+ barrier_sem,
+ inc=1,
+ device_id=(0, left_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ pltpu.semaphore_signal(
+ barrier_sem,
+ inc=1,
+ device_id=(0, right_neighbor),
+ device_id_type=pltpu.DeviceIdType.MESH,
+ )
+ pltpu.semaphore_wait(barrier_sem, 2)
+
+ initial_left_copy.start()
+ initial_left_copy.wait()
+ initial_right_copy.start()
+
+ # We tell our left neighbor that it is allowed to send to the right.
+ # (and vice versa for right neighbor)
+ signal(LEFT, right_capacity_sem)
+ signal(RIGHT, left_capacity_sem)
+
+ @pl.when(~is_start)
+ def _():
+ @pl.when(phase == LEFT)
+ def _():
+      # We block here until our right neighbor tells us we can send to
+ # the right.
+ pltpu.semaphore_wait(right_capacity_sem, 1)
+ right_copy.start()
+
+ @pl.when(phase == RIGHT)
+ def _():
+      # We block here until our left neighbor tells us we can send to
+ # the left.
+ pltpu.semaphore_wait(left_capacity_sem, 1)
+ left_copy.start()
+
+ # --- Body ---
+ def inner_kernel(input_ref, accum_ref):
+ # We do not explicitly use += because we set should_accumulate_out=True.
+ accum_ref[...] = input_ref[...]
+
+ accum_pipeline = pltpu.emit_pipeline(
+ inner_kernel,
+ in_specs=[inner_block_spec],
+ out_specs=inner_block_spec,
+ should_accumulate_out=True,
+ grid=inner_grid,
+ )
+
+ @pl.when(~last_iteration)
+ def _():
+ @pl.when(phase == LEFT)
+ def _():
+ accum_pipeline(
+ x_ref.at[left_copy_device, left_copy_slice],
+ hbm_scratch.at[working_slot, left_copy_slice],
+ )
+
+ @pl.when(phase == RIGHT)
+ def _():
+ accum_pipeline(
+ x_ref.at[right_copy_device, right_copy_slice],
+ hbm_scratch.at[working_slot, right_copy_slice],
+ )
+
+ # --- Epilogue ---
+ @pl.when(is_start)
+ def _():
+ initial_right_copy.wait()
+
+ @pl.when(~is_start)
+ def _():
+ @pl.when(phase == LEFT)
+ def _():
+ right_copy.wait()
+ signal(LEFT, right_capacity_sem)
+
+ @pl.when(phase == RIGHT)
+ def _():
+ left_copy.wait()
+ signal(RIGHT, left_capacity_sem)
+
+ # Store result on last iteration.
+ @pl.when(last_iteration)
+ def _():
+ output_copy = pltpu.make_async_copy(
+ src_ref=hbm_scratch.at[working_slot, current_phase_slice],
+ dst_ref=o_ref.at[current_phase_slice],
+ sem=copy_sem,
+ )
+ output_copy.start()
+ output_copy.wait()
+
+ # Clean up semaphores so that they exit with a value of 0.
+ @pl.when(phase == LEFT)
+ def _():
+ pltpu.semaphore_wait(right_capacity_sem, 1)
+
+ @pl.when(phase == RIGHT)
+ def _():
+ pltpu.semaphore_wait(left_capacity_sem, 1)
+
+
+out_shape = (
+ jax.ShapeDtypeStruct(
+ (outer_block_size[0], outer_block_size[1]), jnp.float32
+ ),
+ # Shape: [working/recv, block[0], block[1]]
+ jax.ShapeDtypeStruct(
+ (2, outer_block_size[0], outer_block_size[1]), jnp.float32
+ ), # hbm_scratch
+)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=0,
+ in_specs=[
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ out_specs=[
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+ ],
+ grid=(num_devices, 2),
+ scratch_shapes=(
+ [pltpu.SemaphoreType.DMA] * 5
+ + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores
+ ),
+)
+
+
+def pallas_reduce_scatter(input_arr):
+ input_arr = input_arr.reshape(
+ num_devices, outer_block_size[0], outer_block_size[1]
+ )
+ return pl.pallas_call(
+ reduce_scatter_kernel,
+ out_shape=out_shape,
+ grid_spec=grid_spec,
+ compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+ )(input_arr)[0]
+
+
+pallas_result = jax.jit(
+ shard_map.shard_map(
+ pallas_reduce_scatter,
+ mesh=mesh,
+ in_specs=P(None, 'x'),
+ out_specs=P('x', None),
+ check_rep=False,
+ )
+)(input_arr)
+
+pallas_result = jax.block_until_ready(pallas_result)
+```
+
+```{code-cell} ipython3
+---
+executionInfo:
+ elapsed: 768
+ status: ok
+ timestamp: 1722904808851
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: cTEyiMDyx9Y0
+outputId: 1de26695-3713-430e-9ab4-4ea646691680
+---
+# Now we compare our result to XLA.
+def lax_reduce_sum_scatter(x):
+ x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])
+ return lax.psum_scatter(x, 'x')
+
+
+xla_result = jax.jit(
+ shard_map.shard_map(
+ lax_reduce_sum_scatter,
+ mesh=mesh,
+ in_specs=P(None, 'x'),
+ out_specs=P('x', None),
+ )
+)(input_arr)
+
+print('Input:', input_arr.shape, input_arr[::4, 0])
+print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])
+print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])
+print(
+ 'Difference |Pallas - lax.psum_scatter|:',
+ jnp.max(jnp.abs(pallas_result - xla_result)),
+)
+```
+
++++ {"id": "zz5AFbriliyv"}
+
+## Final Notes
+
+### Megacore
+
+Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and to only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (which can be obtained via `jax.devices()[0].num_cores`) and set `dimension_semantics` to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.
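+
+A minimal sketch of this pattern (with `...` standing in for the rest of the kernel arguments, grid dimensions, and `pallas_call` parameters) could look like the following:
+
+```python
+num_cores = jax.devices()[0].num_cores
+
+def megacore_kernel(...):
+  core_index = pl.program_id(0)  # The first grid axis ranges over cores.
+
+  @pl.when(core_index == 0)
+  def _():
+    ...  # Initiate remote HBM-HBM DMAs from a single core only.
+
+pl.pallas_call(
+  megacore_kernel,
+  grid=(num_cores, ...),
+  compiler_params=pltpu.TPUCompilerParams(
+      dimension_semantics=("parallel", ...)),
+  # Remaining in_specs/out_specs/out_shape arguments omitted in this sketch.
+)
+```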
+
+### Interaction with XLA
+
+In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX, such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved, so it is important to profile your program to see if this is an issue. Another example is that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.
+
+### Next Steps
+
+Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead.
diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst
index 5680481f3947..20abad5f610e 100644
--- a/docs/pallas/tpu/index.rst
+++ b/docs/pallas/tpu/index.rst
@@ -9,3 +9,6 @@ TPU specific documentation.
details
pipelining
matmul
+ sparse
+ distributed
+
diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb
index 0bd16095cb7e..51ce2ed6868f 100644
--- a/docs/pallas/tpu/matmul.ipynb
+++ b/docs/pallas/tpu/matmul.ipynb
@@ -210,8 +210,8 @@
" pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n",
" out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n",
" grid=(m // bm, n // bn, k // bk),\n",
- " compiler_params=dict(mosaic=dict(\n",
- " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+ " compiler_params=pltpu.TPUCompilerParams(\n",
+ " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
@@ -466,8 +466,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
- " compiler_params=dict(mosaic=dict(\n",
- " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+ " compiler_params=pltpu.TPUCompilerParams(\n",
+ " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
@@ -741,8 +741,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
- " compiler_params=dict(mosaic=dict(\n",
- " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+ " compiler_params=pltpu.TPUCompilerParams(\n",
+ " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
@@ -929,8 +929,8 @@
" grid=(m // bm, n // bn, k // bk),\n",
" ),\n",
" out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n",
- " compiler_params=dict(mosaic=dict(\n",
- " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n",
+ " compiler_params=pltpu.TPUCompilerParams(\n",
+ " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n",
" )(x, y)"
]
},
diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md
index a00880ebaf37..b8e6acbd45f9 100644
--- a/docs/pallas/tpu/matmul.md
+++ b/docs/pallas/tpu/matmul.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3 (ipykernel)
language: python
@@ -167,8 +167,8 @@ def matmul(
pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],
out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),
grid=(m // bm, n // bn, k // bk),
- compiler_params=dict(mosaic=dict(
- dimension_semantics=("parallel", "parallel", "arbitrary"))),
+ compiler_params=pltpu.TPUCompilerParams(
+ dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
@@ -321,8 +321,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
- compiler_params=dict(mosaic=dict(
- dimension_semantics=("parallel", "parallel", "arbitrary"))),
+ compiler_params=pltpu.TPUCompilerParams(
+ dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
@@ -489,8 +489,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
- compiler_params=dict(mosaic=dict(
- dimension_semantics=("parallel", "parallel", "arbitrary"))),
+ compiler_params=pltpu.TPUCompilerParams(
+ dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
@@ -613,8 +613,8 @@ def matmul(
grid=(m // bm, n // bn, k // bk),
),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),
- compiler_params=dict(mosaic=dict(
- dimension_semantics=("parallel", "parallel", "arbitrary"))),
+ compiler_params=pltpu.TPUCompilerParams(
+ dimension_semantics=("parallel", "parallel", "arbitrary")),
)(x, y)
```
diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb
index 275a72f3837b..b5f2c652b5a5 100644
--- a/docs/pallas/tpu/pipelining.ipynb
+++ b/docs/pallas/tpu/pipelining.ipynb
@@ -1,5 +1,13 @@
{
"cells": [
+ {
+ "cell_type": "markdown",
+ "id": "7704d3bb",
+ "metadata": {},
+ "source": [
+ "(pallas_tpu_pipelining)="
+ ]
+ },
{
"cell_type": "markdown",
"metadata": {
@@ -33,6 +41,7 @@
"\n",
"import jax\n",
"from jax.experimental import pallas as pl\n",
+ "from jax.experimental.pallas import tpu as pltpu\n",
"import jax.numpy as jnp\n",
"import numpy as np"
]
@@ -696,7 +705,7 @@
" in_specs=[block_spec, block_spec],\n",
" out_specs=block_spec,\n",
" grid=(2,),\n",
- " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n",
+ " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n",
" )(x, y)\n",
"\n",
"x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n",
diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md
index d753b404db1a..19150b3832fa 100644
--- a/docs/pallas/tpu/pipelining.md
+++ b/docs/pallas/tpu/pipelining.md
@@ -5,12 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
---
+(pallas_tpu_pipelining)=
+
+++ {"id": "teoJ_fUwlu0l"}
# Pipelining
@@ -29,6 +31,7 @@ pipelines in Pallas that overlap memory I/O with compute.
import jax
from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
import jax.numpy as jnp
import numpy as np
```
@@ -465,7 +468,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array:
in_specs=[block_spec, block_spec],
out_specs=block_spec,
grid=(2,),
- compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",)))
+ compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",))
)(x, y)
x, y = jnp.ones((512, 512)), jnp.ones((512, 512))
diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb
new file mode 100644
index 000000000000..a80ba4ebedbb
--- /dev/null
+++ b/docs/pallas/tpu/sparse.ipynb
@@ -0,0 +1,724 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ZHuzXqQ-9JUQ"
+ },
+ "source": [
+ "# Scalar Prefetch and Block-Sparse Computation\n",
+ "\n",
+ "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 56,
+ "status": "ok",
+ "timestamp": 1726001133029,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "ibeIs_6QFMAM",
+ "outputId": "d72edb91-4529-4650-c9e9-b96788608635"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Running on TPU v5 lite\n"
+ ]
+ }
+ ],
+ "source": [
+ "import functools\n",
+ "import timeit\n",
+ "import numpy as np\n",
+ "import jax\n",
+ "from jax import numpy as jnp\n",
+ "from jax import lax\n",
+ "from jax.experimental import checkify\n",
+ "from jax.experimental import pallas as pl\n",
+ "from jax.experimental.pallas import tpu as pltpu\n",
+ "\n",
+ "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n",
+ "print(\"Running on\", jax.devices()[0].device_kind)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "FIDGpPTEIcOa"
+ },
+ "source": [
+ "## Dynamic Block Indexing with Scalar Prefetch\n",
+ "\n",
+    "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n",
+ "\n",
+ "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n",
+ "\n",
+ "```python\n",
+ "class PrefetchScalarGridSpec:\n",
+ " def __init__(self,\n",
+ " num_scalar_prefetch: int,\n",
+ " grid: tuple[int, ...],\n",
+ " in_specs: PyTree[BlockSpec],\n",
+ " out_specs: PyTree[BlockSpec],\n",
+ " scratch_shapes: tuple[MemorySpace, ...]):\n",
+ " ...\n",
+ "```\n",
+ "\n",
+    "The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and the kernel is always fixed, as described below:\n",
+ "\n",
+ "- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:\n",
+ "```python\n",
+ "def index_map(*grid_indices, *prefetch_refs):\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.\n",
+ "```python\n",
+ "def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):\n",
+ " ...\n",
+ "```\n",
+ "\n",
+ "- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.\n",
+ "```python\n",
+ "kernel = pl.pallas_call(...)\n",
+ "result = kernel(*prefetch_args, *input_args)\n",
+ "```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "pA8RmHEA2HN3"
+ },
+ "source": [
+ "## Example: Block Dynamic Slice with Scalar Prefetch\n",
+ "\n",
+    "Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of a larger array based on user-specified indices:\n",
+ "\n",
+ "1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`\n",
+ "\n",
+ "2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.\n",
+ "\n",
+ "3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.\n",
+ "\n",
+    "Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size from the slice size and allow non-aligned start indices."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 143,
+ "status": "ok",
+ "timestamp": 1726003877561,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "FWeTBlEYlCGD",
+ "outputId": "4b04a441-c97c-4d0d-d167-c60d4d31fd2e"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Error |result - lax.dynamic_slice| = 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "def dynamic_slice_kernel(indices, x_ref, o_ref):\n",
+ " del indices\n",
+ " o_ref[...] = x_ref[...]\n",
+ "\n",
+ "@checkify.checkify\n",
+ "@functools.partial(jax.jit, static_argnums=(2,))\n",
+ "def block_dynamic_slice(x, starts, sizes):\n",
+ " grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=1,\n",
+ " grid=(1, 1),\n",
+ " in_specs=[pl.BlockSpec(\n",
+ " sizes,\n",
+ " lambda i, j, block_idx: (block_idx[0], block_idx[1]))],\n",
+ " out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),\n",
+ " )\n",
+ "\n",
+ " kernel = pl.pallas_call(\n",
+ " dynamic_slice_kernel,\n",
+ " grid_spec=grid_spec,\n",
+ " out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),\n",
+ " )\n",
+ " # Checkify inserts a runtime assert that starts are divisible by block size.\n",
+ " checkify.check(starts[0] % sizes[0] == 0, \"Starts must be divisible by size.\")\n",
+ " checkify.check(starts[1] % sizes[1] == 0, \"Starts must be divisible by size.\")\n",
+ " block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])\n",
+ " return kernel(block_idx, x)\n",
+ "\n",
+ "shape = (512, 512)\n",
+ "x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)\n",
+ "err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))\n",
+ "err.throw()\n",
+ "ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))\n",
+ "diff = jnp.max(jnp.abs(result - ref))\n",
+ "print(\"Error |result - lax.dynamic_slice| =\", diff)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "K2dod4lkoifa"
+ },
+ "source": [
+ "## Sparse Kernels: Representing Sparse Data\n",
+ "\n",
+ "Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.\n",
+ "\n",
+ "The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.\n",
+ "\n",
+ "![block_coo](../../_static/pallas/sparse/block_coo.svg)\n",
+ "\n",
+ "We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "1gLiSvgIYUEx"
+ },
+ "outputs": [],
+ "source": [
+ "def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):\n",
+ " \"\"\"Returns a sampled matrix and its block-sparse representation.\n",
+ "\n",
+ " Args:\n",
+ " key: RNG Key.\n",
+ " M: Major array dimension.\n",
+ " N: Minor array dimension.\n",
+ " blk_M: Block size along M dimension.\n",
+ " blk_N: Block size along N dimension.\n",
+ " p: Probability that a block will be non-zero.\n",
+ " dtype: dtype of the sampled matrix.\n",
+ "\n",
+ " Returns:\n",
+ " dense_mat: A (M, N) dense sampled array.\n",
+ " block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing\n",
+ " the non-zero blocks of the matrix.\n",
+ " indices_i: A (num_blocks,) array of block indices for the first axis.\n",
+ " indices_j: A (num_blocks,) array of block indices for the second axis.\n",
+ " \"\"\"\n",
+ " mask_key, blocks_key = jax.random.split(key)\n",
+ " num_blocks = (M // blk_M, N // blk_N)\n",
+ " # We first sample a block mask, denoting which blocks are nonzero.\n",
+ " block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)\n",
+ " num_blocks = jnp.sum(block_mask)\n",
+ " indices = jnp.where(block_mask)\n",
+ " # For each non-zero block, we sample a block of random values.\n",
+ " block_data = jax.random.uniform(blocks_key,\n",
+ " shape=(num_blocks, blk_M, blk_N),\n",
+ " dtype=dtype)\n",
+ " # For checking purposes, create the dense version of the sparse matrix.\n",
+ " dense_mat = jnp.zeros((M, N), dtype=dtype)\n",
+ " for blk in range(num_blocks):\n",
+ " idx_i = indices[0][blk]\n",
+ " idx_j = indices[1][blk]\n",
+ " slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)\n",
+ " slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)\n",
+ " dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])\n",
+ " return dense_mat, block_data, indices[0], indices[1]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "eFyoZSTOH9Fk"
+ },
+ "source": [
+ "## Example: Sparse @ Dense Matrix Multiplication\n",
+ "\n",
+    "In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n",
+ "\n",
+    "We will structure our kernel grid with 2 loops: an outer loop over the columns of the RHS/output, and an inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and look up the corresponding block in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n",
+ "\n",
+ "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n",
+ "\n",
+    "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output blocks, Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 673,
+ "status": "ok",
+ "timestamp": 1725919879291,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "WfyV2WWhjsyA",
+ "outputId": "fa4d4fff-bc6b-4dc9-ac14-63276ca14131"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean |result - ref|: 0\n"
+ ]
+ }
+ ],
+ "source": [
+ "M = N = K = 16384\n",
+ "blk_M = blk_N = blk_K = 512\n",
+ "\n",
+ "\n",
+ "def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.\n",
+ " x_ref, y_ref, _, o_ref, # Kernel inputs.\n",
+ " accum_scratch,\n",
+ " ):\n",
+ " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n",
+ " del idxs_k_ref\n",
+ " blk_idx = pl.program_id(0)\n",
+ " is_start = blk_idx == 0\n",
+ " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n",
+ " @pl.when(is_start | changed_blocks)\n",
+ " def _():\n",
+ " accum_scratch[...] = jnp.zeros_like(accum_scratch)\n",
+ " accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)\n",
+ "\n",
+    "  next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks-1)])\n",
+ " is_end = blk_idx == (num_blocks - 1)\n",
+ " @pl.when(is_end | next_block_change)\n",
+ " def _():\n",
+ " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n",
+ "\n",
+ "\n",
+ "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+ " del j, blk_idxs_i, blk_idxs_k\n",
+ " return (blk_idx, 0, 0)\n",
+ "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+ " del blk_idxs_i\n",
+ " return (blk_idxs_k[blk_idx], j)\n",
+ "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n",
+ " del blk_idxs_k\n",
+ " return (blk_idxs_i[blk_idx], j)\n",
+ "\n",
+ "(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(\n",
+ " jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)\n",
+ "num_blocks = X_blocks.shape[0]\n",
+ "Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n",
+ "zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)\n",
+ "out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=2,\n",
+ " # Note that while num_blocks is static here, Pallas does support\n",
+ " # dynamic grid sizes.\n",
+ " grid=(num_blocks, N // blk_N),\n",
+ " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n",
+ " pl.BlockSpec((blk_K, blk_N), y_map),\n",
+ " # Placeholder for a zeros-array used by input_output_aliases.\n",
+ " pl.BlockSpec((blk_M, blk_N), o_map),\n",
+ " ],\n",
+ " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n",
+ " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n",
+ ")\n",
+ "kernel = pl.pallas_call(\n",
+ " dsd_kernel,\n",
+ " grid_spec=grid_spec,\n",
+ " out_shape=out_shape,\n",
+ " # We use input-output aliases to zero-out o_ref for blocks that we never\n",
+ " # visit. By passing in an array of zeros we avoid having o_ref start with\n",
+ " # uninitialized values.\n",
+ " input_output_aliases={4: 0}, # Map zeros to o_ref.\n",
+ ")\n",
+ "args = (indices_i, indices_k, X_blocks, Y, zeros)\n",
+ "result = kernel(*args)\n",
+ "\n",
+ "ref = X_dense @ Y\n",
+ "diff = jnp.abs(ref - result)\n",
+ "print('mean |result - ref|:', jnp.mean(diff))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "2KDgPKF2tUjq"
+ },
+ "source": [
+    "We can do a quick benchmark to compare the performance of our sparse kernel against a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly 6x speed increase compared to the theoretical 10x from the sparsity factor.\n",
+ "\n",
+ "There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM:\n",
+ "- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half.\n",
+    "- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enable the data to be more sparse, so this is a parameter that should be selected carefully."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 6576,
+ "status": "ok",
+ "timestamp": 1725919886762,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "CkzjqnekpZbx",
+ "outputId": "1ae9031e-705a-4d05-f8b9-d09623918300"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sparse Kernel: 8.136 ms (avg over 100 trials)\n",
+ "Reference: 46.953 ms (avg over 100 trials)\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Benchmark Sparse Pallas kernel vs reference JAX implementation\n",
+ "\n",
+ "def benchmark(f, ntrials: int = 100):\n",
+ " def run(*args, **kwargs):\n",
+ " # Compile function first\n",
+ " jax.block_until_ready(f(*args, **kwargs))\n",
+ " # Time function\n",
+ " result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),\n",
+ " number=ntrials)\n",
+ " time = result / ntrials\n",
+ " return time\n",
+ " return run\n",
+ "\n",
+ "\n",
+ "n_trials = 100\n",
+ "\n",
+ "pallas_impl = lambda *args: kernel(*args)\n",
+ "time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)\n",
+ "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n",
+ "\n",
+ "ref_impl = jax.jit(lambda x, y: x @ y)\n",
+ "time = benchmark(ref_impl, n_trials)(X_dense, Y)\n",
+ "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "Q1KKd5vTCwnB"
+ },
+ "source": [
+ "## Sparse Access Patterns on Dense Data\n",
+ "\n",
+ "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n",
+ "\n",
+    "A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n",
+ "\n",
+    "The main performance consideration when dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline to instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contain indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n",
+ "\n",
+ "![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)\n",
+ "\n",
+ "*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*\n",
+ "\n",
+ "Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.\n",
+ "\n",
+ "```python\n",
+ "def mask_index_map(prefetch_map, i, j, ...):\n",
+ " next_nonzero_block = prefetch_map[i, j]\n",
+ " return (next_nonzero_block, 0, 0)\n",
+ "```\n",
+ "\n",
+    "We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example demonstrates how to use these prefetch maps."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ii7rzL5YIA8-"
+ },
+ "source": [
+ "## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "ecjiqWfA2RlV"
+ },
+ "source": [
+ "In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.\n",
+ "\n",
+ "As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:\n",
+ "- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel.\n",
+ "- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block.\n",
+ "- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask.\n",
+ "- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask.\n",
+ "- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "19zGcliL2SJy"
+ },
+ "outputs": [],
+ "source": [
+ "def sparsify_mask(mask: jax.Array,\n",
+ " block_shape: tuple[int, int]):\n",
+    "  \"\"\"Preprocesses a mask into a sparse representation.\n",
+ "\n",
+ " Args:\n",
+ " mask: A boolean array of shape [M, N]\n",
+ " block_shape: The size of a single block.\n",
+ "\n",
+ " Returns:\n",
+ " block_mask: A block_shape array of booleans indicating whether a block\n",
+ " is all-zeros (0) or contains non-zero elements (1).\n",
+ " prefetch_mask: A block_shape array of integers indicating the index of the\n",
+ " next non-zero block.\n",
+ " mask_data: A (num_blocks, block_shape) array containing\n",
+ " the data for non-zero blocks of the mask.\n",
+ " \"\"\"\n",
+ " M, N = mask.shape\n",
+ " bm, bn = block_shape\n",
+ "\n",
+ " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n",
+ " mask_types_finder = []\n",
+ " mask_data = []\n",
+ " mask_type_idxs = []\n",
+ "\n",
+ " next_mask_type_idx = 0\n",
+ " prefetch_mask = jnp.zeros_like(block_mask)\n",
+ " next_i = (M // bm) - 1\n",
+ " next_j = (N // bn) - 1\n",
+ " prefetch_i = jnp.zeros_like(block_mask)\n",
+ " prefetch_j = jnp.zeros_like(block_mask)\n",
+    "  for i in range(M // bm - 1, -1, -1):\n",
+    "    for j in range(N // bn - 1, -1, -1):\n",
+ " mask_block = mask[i * bm :(i + 1) * bm,\n",
+ " j * bn :(j + 1) * bn]\n",
+ " is_nonzero = jnp.any(mask_block)\n",
+ " if is_nonzero:\n",
+ " try:\n",
+ " type_index = mask_types_finder.index(str(mask_block))\n",
+ " except ValueError:\n",
+ " type_index = len(mask_types_finder)\n",
+ " mask_types_finder.append(str(mask_block))\n",
+ " mask_data.append(mask_block)\n",
+ " next_mask_type_idx = type_index\n",
+ " next_i = i\n",
+ " next_j = j\n",
+ " else:\n",
+ " type_index = -1\n",
+ " mask_type_idxs.append(type_index)\n",
+ " block_mask = block_mask.at[i, j].set(is_nonzero)\n",
+ " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n",
+ " prefetch_i = prefetch_i.at[i, j].set(next_i)\n",
+ " prefetch_j = prefetch_j.at[i, j].set(next_j)\n",
+ " return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "w4b7ckKq67Xw"
+ },
+ "source": [
+    "In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials, with 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block is all zeros. If the mask is all zeros, we can skip computation and move on to the next block; otherwise we need to compute the matrix multiplication and then mask the result."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 5374,
+ "status": "ok",
+ "timestamp": 1725919713252,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "4YQ9OmbTCSjT",
+ "outputId": "2d752609-34f2-4059-e8ba-4d80afe8cb26"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "mean |result - ref|: 1.0252e-05\n"
+ ]
+ }
+ ],
+ "source": [
+ "M = N = K = 16384\n",
+ "blk_M = blk_N = 512\n",
+ "blk_K = 1024\n",
+ "\n",
+ "def sparse_mask_matmul(\n",
+ " block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.\n",
+ " x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.\n",
+ " accum_scratch\n",
+ " ):\n",
+ " del prefetch_mask, prefetch_i, prefetch_j\n",
+ " i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)\n",
+ " should_compute = block_mask_ref[i, j] != 0\n",
+ " @pl.when(k == 0)\n",
+ " def _():\n",
+ " o_ref[...] = jnp.zeros_like(o_ref)\n",
+ " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n",
+ "\n",
+ " # We only compute the output for blocks with non-zero masks.\n",
+ " # Otherwise we skip the computation entirely.\n",
+ " @pl.when(should_compute)\n",
+ " def _():\n",
+ " result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)\n",
+ " accum_scratch[...] += result\n",
+ " @pl.when(k == pl.num_programs(2) - 1)\n",
+ " def _():\n",
+ " o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)\n",
+ "\n",
+ "X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)\n",
+ "Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n",
+ "mask = jnp.ones((M, N), dtype=jnp.int32)\n",
+ "mask = jnp.tril(mask)\n",
+ "block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))\n",
+ "\n",
+ "def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n",
+ " del prefetch_mask, prefetch_j\n",
+ " # Zero-out the k index if the mask is zero, to avoid constantly fetching\n",
+ " # new blocks in the inner loop for blocks we are skipping.\n",
+ " k_fetch = (block_mask[i, j] != 0) * k\n",
+ " return (prefetch_i[i, j], k_fetch)\n",
+ "\n",
+ "def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n",
+ " del prefetch_mask, prefetch_i\n",
+ " k_fetch = (block_mask[i, j] != 0) * k\n",
+ " return (k_fetch, prefetch_j[i, j])\n",
+ "\n",
+ "def mask_map(i, j, k, block_mask, prefetch_mask, *_):\n",
+ " del k, block_mask\n",
+ " return (prefetch_mask[i, j], 0, 0)\n",
+ "\n",
+ "def o_map(i, j, k, *_):\n",
+ " del k\n",
+ " return (i, j)\n",
+ "\n",
+ "grid_spec = pltpu.PrefetchScalarGridSpec(\n",
+ " num_scalar_prefetch=4,\n",
+ " grid=(M // blk_M, N // blk_N, K // blk_K),\n",
+ " in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),\n",
+ " pl.BlockSpec((blk_K, blk_N), y_map),\n",
+ " pl.BlockSpec((1, blk_M, blk_N), mask_map)],\n",
+ " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n",
+ " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n",
+ ")\n",
+ "kernel = pl.pallas_call(\n",
+ " sparse_mask_matmul,\n",
+ " grid_spec=grid_spec,\n",
+ " out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),\n",
+ ")\n",
+ "args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n",
+ "result = kernel(*args)\n",
+ "\n",
+ "ref = mask * (X @ Y)\n",
+ "diff = jnp.abs(ref - result)\n",
+ "print('mean |result - ref|:', jnp.mean(diff))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "uutNGgjZGGhB"
+ },
+ "source": [
+    "Now let's compare performance versus a naive dense implementation. On TPU v5e, we achieve around a 1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.\n",
+ "\n",
+ "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n",
+    "- We skip slightly less than half of the computation since the blocks along the diagonal are a mix of 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n",
+    "- The pipeline bubble also accounts for a smaller percentage of the overall runtime as inputs become larger."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "executionInfo": {
+ "elapsed": 8877,
+ "status": "ok",
+ "timestamp": 1725917397452,
+ "user": {
+ "displayName": "Justin Fu",
+ "userId": "17543197034567316452"
+ },
+ "user_tz": 420
+ },
+ "id": "MAT9JjGNvsx8",
+ "outputId": "a32d56fb-a71b-4007-c6a5-e5270dcaa6cf"
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Sparse Kernel: 28.648 ms (avg over 100 trials)\n",
+ "Reference: 49.988 ms (avg over 100 trials)\n"
+ ]
+ }
+ ],
+ "source": [
+ "n_trials = 100\n",
+ "\n",
+ "pallas_impl = lambda *args: kernel(*args)\n",
+ "time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n",
+ "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n",
+ "\n",
+ "ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))\n",
+ "time = benchmark(ref_impl, n_trials)(mask, X, Y)\n",
+ "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))"
+ ]
+ }
+ ],
+ "metadata": {
+ "jupytext": {
+ "formats": "ipynb,md:myst",
+ "main_language": "python"
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md
new file mode 100644
index 000000000000..2ac25edb5064
--- /dev/null
+++ b/docs/pallas/tpu/sparse.md
@@ -0,0 +1,567 @@
+---
+jupytext:
+ formats: ipynb,md:myst
+ main_language: python
+ text_representation:
+ extension: .md
+ format_name: myst
+ format_version: 0.13
+ jupytext_version: 1.16.4
+kernelspec:
+ display_name: Python 3
+ name: python3
+---
+
++++ {"id": "ZHuzXqQ-9JUQ"}
+
+# Scalar Prefetch and Block-Sparse Computation
+
+In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 56
+ status: ok
+ timestamp: 1726001133029
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: ibeIs_6QFMAM
+outputId: d72edb91-4529-4650-c9e9-b96788608635
+---
+import functools
+import timeit
+import numpy as np
+import jax
+from jax import numpy as jnp
+from jax import lax
+from jax.experimental import checkify
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
+print("Running on", jax.devices()[0].device_kind)
+```
+
++++ {"id": "FIDGpPTEIcOa"}
+
+## Dynamic Block Indexing with Scalar Prefetch
+
+We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.
+
+To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:
+
+```python
+class PrefetchScalarGridSpec:
+ def __init__(self,
+ num_scalar_prefetch: int,
+ grid: tuple[int, ...],
+ in_specs: PyTree[BlockSpec],
+ out_specs: PyTree[BlockSpec],
+ scratch_shapes: tuple[MemorySpace, ...]):
+ ...
+```
+
+The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed into the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks, as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and the kernel is always fixed, as described below:
+
+- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:
+```python
+def index_map(*grid_indices, *prefetch_refs):
+ ...
+```
+
+- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.
+```python
+def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
+ ...
+```
+
+- When calling the kernel built with `pallas_call`, the returned function also expects the scalar prefetch arguments to come before the inputs, e.g.
+```python
+kernel = pl.pallas_call(...)
+result = kernel(*prefetch_args, *input_args)
+```
+
++++ {"id": "pA8RmHEA2HN3"}
+
+## Example: Block Dynamic Slice with Scalar Prefetch
+
+Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of a larger array based on user-specified indices:
+
+1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`
+
+2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.
+
+3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.
+
+Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size from the slice size and allow non-aligned start indices.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 143
+ status: ok
+ timestamp: 1726003877561
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: FWeTBlEYlCGD
+outputId: 4b04a441-c97c-4d0d-d167-c60d4d31fd2e
+---
+def dynamic_slice_kernel(indices, x_ref, o_ref):
+ del indices
+ o_ref[...] = x_ref[...]
+
+@checkify.checkify
+@functools.partial(jax.jit, static_argnums=(2,))
+def block_dynamic_slice(x, starts, sizes):
+ grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=1,
+ grid=(1, 1),
+ in_specs=[pl.BlockSpec(
+ sizes,
+ lambda i, j, block_idx: (block_idx[0], block_idx[1]))],
+ out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),
+ )
+
+ kernel = pl.pallas_call(
+ dynamic_slice_kernel,
+ grid_spec=grid_spec,
+ out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),
+ )
+ # Checkify inserts a runtime assert that starts are divisible by block size.
+ checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.")
+ checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.")
+ block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])
+ return kernel(block_idx, x)
+
+shape = (512, 512)
+x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)
+err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))
+err.throw()
+ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))
+diff = jnp.max(jnp.abs(result - ref))
+print("Error |result - lax.dynamic_slice| =", diff)
+```
+
++++ {"id": "K2dod4lkoifa"}
+
+## Sparse Kernels: Representing Sparse Data
+
+Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.
+
+The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.
+
+![block_coo](../../_static/pallas/sparse/block_coo.svg)
+
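+To make the format concrete, here is a minimal NumPy sketch (not used by the kernels below, and with made-up values rather than the exact values from the figure) that converts a tiny 4x4 matrix with 2x2 blocks into `(block_index, block_data)` pairs:
+
+```python
+import numpy as np
+
+# A tiny 4x4 matrix whose upper-right 2x2 block is all zeros.
+dense = np.array([[1, 2, 0, 0],
+                  [3, 4, 0, 0],
+                  [5, 6, 7, 8],
+                  [9, 1, 2, 3]])
+blk = 2
+
+block_indices = []  # (i, j) index of each stored block.
+block_data = []     # The corresponding (blk, blk) data blocks.
+for i in range(dense.shape[0] // blk):
+  for j in range(dense.shape[1] // blk):
+    block = dense[i * blk:(i + 1) * blk, j * blk:(j + 1) * blk]
+    if np.any(block):  # Only store blocks with non-zero entries.
+      block_indices.append((i, j))
+      block_data.append(block)
+
+print(block_indices)  # [(0, 0), (1, 0), (1, 1)] -- the all-zero (0, 1) block is dropped.
+```
+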
+We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis.
+
+```{code-cell}
+:id: 1gLiSvgIYUEx
+
+def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):
+ """Returns a sampled matrix and its block-sparse representation.
+
+ Args:
+ key: RNG Key.
+ M: Major array dimension.
+ N: Minor array dimension.
+ blk_M: Block size along M dimension.
+ blk_N: Block size along N dimension.
+ p: Probability that a block will be non-zero.
+ dtype: dtype of the sampled matrix.
+
+ Returns:
+ dense_mat: A (M, N) dense sampled array.
+ block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing
+ the non-zero blocks of the matrix.
+ indices_i: A (num_blocks,) array of block indices for the first axis.
+ indices_j: A (num_blocks,) array of block indices for the second axis.
+ """
+ mask_key, blocks_key = jax.random.split(key)
+ num_blocks = (M // blk_M, N // blk_N)
+ # We first sample a block mask, denoting which blocks are nonzero.
+ block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
+ num_blocks = jnp.sum(block_mask)
+ indices = jnp.where(block_mask)
+ # For each non-zero block, we sample a block of random values.
+ block_data = jax.random.uniform(blocks_key,
+ shape=(num_blocks, blk_M, blk_N),
+ dtype=dtype)
+ # For checking purposes, create the dense version of the sparse matrix.
+ dense_mat = jnp.zeros((M, N), dtype=dtype)
+ for blk in range(num_blocks):
+ idx_i = indices[0][blk]
+ idx_j = indices[1][blk]
+ slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
+ slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
+ dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
+ return dense_mat, block_data, indices[0], indices[1]
+```
+
++++ {"id": "eFyoZSTOH9Fk"}
+
+## Example: Sparse @ Dense Matrix Multiplication
+
+In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.
+
+We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and the inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and look up the corresponding block in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:
+
+![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)
+
+It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel, for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and this is easy to do if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, so we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. Once we change output blocks, Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values, even though the kernel is otherwise logically correct.
+
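+As an aside, the block indices returned by our helper are already grouped by row (because `jnp.where` scans the mask in row-major order). If you were starting from ungrouped block-COO data, one way to reorder it outside the kernel is a lexicographic sort on `(i, k)`, sketched below with hypothetical indices:
+
+```python
+import numpy as np
+
+# Hypothetical ungrouped block-COO indices (and a stand-in for the block data).
+indices_i = np.array([3, 0, 2, 0, 3, 1])
+indices_k = np.array([1, 2, 0, 0, 3, 1])
+block_data = np.arange(6)  # Stand-in for a (num_blocks, blk_M, blk_K) array.
+
+# Sort primarily by the row index i, secondarily by the contracting index k.
+order = np.lexsort((indices_k, indices_i))
+indices_i, indices_k, block_data = (
+    indices_i[order], indices_k[order], block_data[order])
+print(indices_i)  # [0 0 1 2 3 3] -- rows are now grouped.
+```
+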
+```{code-cell}
+---
+executionInfo:
+ elapsed: 673
+ status: ok
+ timestamp: 1725919879291
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: WfyV2WWhjsyA
+outputId: fa4d4fff-bc6b-4dc9-ac14-63276ca14131
+---
+M = N = K = 16384
+blk_M = blk_N = blk_K = 512
+
+
+def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
+ x_ref, y_ref, _, o_ref, # Kernel inputs.
+ accum_scratch,
+ ):
+ """A DSD (Dense = Sparse @ Dense) matmul kernel."""
+ del idxs_k_ref
+ blk_idx = pl.program_id(0)
+ is_start = blk_idx == 0
+ changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
+ @pl.when(is_start | changed_blocks)
+ def _():
+ accum_scratch[...] = jnp.zeros_like(accum_scratch)
+ accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)
+
+ next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])
+ is_end = blk_idx == (num_blocks - 1)
+ @pl.when(is_end | next_block_change)
+ def _():
+ o_ref[...] = accum_scratch[...].astype(o_ref.dtype)
+
+
+def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+ del j, blk_idxs_i, blk_idxs_k
+ return (blk_idx, 0, 0)
+def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+ del blk_idxs_i
+ return (blk_idxs_k[blk_idx], j)
+def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):
+ del blk_idxs_k
+ return (blk_idxs_i[blk_idx], j)
+
+(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(
+ jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)
+num_blocks = X_blocks.shape[0]
+Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
+zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)
+out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=2,
+ # Note that while num_blocks is static here, Pallas does support
+ # dynamic grid sizes.
+ grid=(num_blocks, N // blk_N),
+ in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),
+ pl.BlockSpec((blk_K, blk_N), y_map),
+ # Placeholder for a zeros-array used by input_output_aliases.
+ pl.BlockSpec((blk_M, blk_N), o_map),
+ ],
+ out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
+ scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
+)
+kernel = pl.pallas_call(
+ dsd_kernel,
+ grid_spec=grid_spec,
+ out_shape=out_shape,
+ # We use input-output aliases to zero-out o_ref for blocks that we never
+ # visit. By passing in an array of zeros we avoid having o_ref start with
+ # uninitialized values.
+ input_output_aliases={4: 0}, # Map zeros to o_ref.
+)
+args = (indices_i, indices_k, X_blocks, Y, zeros)
+result = kernel(*args)
+
+ref = X_dense @ Y
+diff = jnp.abs(ref - result)
+print('mean |result - ref|:', jnp.mean(diff))
+```
+
++++ {"id": "2KDgPKF2tUjq"}
+
+We can do a quick benchmark to compare the performance of our sparse kernel against a dense matmul in JAX. On a TPU v5e chip, this kernel achieves roughly a 6x speed increase, compared to the theoretical 10x suggested by the sparsity factor.
+
+There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM and VMEM:
+- Using `dtype=jnp.bfloat16` is critical for performance since it halves the memory bandwidth requirements.
+- Using larger block sizes also helps, since matrix multiplication is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument in practice is that smaller block sizes also enable the data to be more sparse, so block size is a parameter that should be selected carefully.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 6576
+ status: ok
+ timestamp: 1725919886762
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: CkzjqnekpZbx
+outputId: 1ae9031e-705a-4d05-f8b9-d09623918300
+---
+# Benchmark Sparse Pallas kernel vs reference JAX implementation
+
+def benchmark(f, ntrials: int = 100):
+ def run(*args, **kwargs):
+ # Compile function first
+ jax.block_until_ready(f(*args, **kwargs))
+ # Time function
+ result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
+ number=ntrials)
+ time = result / ntrials
+ return time
+ return run
+
+
+n_trials = 100
+
+pallas_impl = lambda *args: kernel(*args)
+time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+
+ref_impl = jax.jit(lambda x, y: x @ y)
+time = benchmark(ref_impl, n_trials)(X_dense, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+```
+
++++ {"id": "Q1KKd5vTCwnB"}
+
+## Sparse Access Patterns on Dense Data
+
+In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).
+
+A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).
+
+The main performance consideration when dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline to instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contain indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.
+
+![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)
+
+*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*
+
+Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.
+
+```python
+def mask_index_map(prefetch_map, i, j, ...):
+ next_nonzero_block = prefetch_map[i, j]
+ return (next_nonzero_block, 0, 0)
+```
+
+We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example demonstrates how to use these prefetch maps.
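+
+To make the idea behind the diagram concrete, here is a small standalone sketch of how such a "next non-zero index" map could be computed for a single row of a block mask; the `sparsify_mask` helper in the next example generalizes this to 2D and additionally records the next `i` and `j` coordinates:
+
+```python
+import numpy as np
+
+# Hypothetical block mask for one row of the grid: 1 = block we need to compute.
+block_mask_row = np.array([0, 1, 0, 0, 1, 1])
+
+# Walk backwards and record, for each grid position, the index of the current
+# or next non-zero block. Positions after the last non-zero block simply keep
+# pointing at the final grid position.
+prefetch_next = np.zeros_like(block_mask_row)
+next_idx = len(block_mask_row) - 1
+for j in range(len(block_mask_row) - 1, -1, -1):
+  if block_mask_row[j]:
+    next_idx = j
+  prefetch_next[j] = next_idx
+print(prefetch_next)  # [1 1 4 4 4 5]
+```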
+
++++ {"id": "ii7rzL5YIA8-"}
+
+## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask
+
++++ {"id": "ecjiqWfA2RlV"}
+
+In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.
+
+As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:
+- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel.
+- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block.
+- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask.
+- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask.
+- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask.
+
+```{code-cell}
+:id: 19zGcliL2SJy
+
+def sparsify_mask(mask: jax.Array,
+ block_shape: tuple[int, int]):
+ """Preprocesses a mask into a sparse reprentation.
+
+ Args:
+ mask: A boolean array of shape [M, N]
+ block_shape: The size of a single block.
+
+ Returns:
+ block_mask: A (M // bm, N // bn) array indicating whether a block
+ is all-zeros (0) or contains non-zero elements (1).
+ prefetch_mask: A (M // bm, N // bn) array of integers indicating the index in
+ `mask_data` of the next non-zero block.
+ prefetch_i: A (M // bm, N // bn) array of integers indicating the next
+ non-masked `i` index of the mask.
+ prefetch_j: A (M // bm, N // bn) array of integers indicating the next
+ non-masked `j` index of the mask.
+ mask_data: A (num_blocks, *block_shape) array containing
+ the data for non-zero blocks of the mask.
+ """
+ M, N = mask.shape
+ bm, bn = block_shape
+
+ block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)
+ mask_types_finder = []
+ mask_data = []
+ mask_type_idxs = []
+
+ next_mask_type_idx = 0
+ prefetch_mask = jnp.zeros_like(block_mask)
+ next_i = (M // bm) - 1
+ next_j = (N // bn) - 1
+ prefetch_i = jnp.zeros_like(block_mask)
+ prefetch_j = jnp.zeros_like(block_mask)
+ for i in range(M // bm, -1, -1):
+ for j in range(N // bn, -1, -1):
+ mask_block = mask[i * bm :(i + 1) * bm,
+ j * bn :(j + 1) * bn]
+ is_nonzero = jnp.any(mask_block)
+ if is_nonzero:
+ try:
+ type_index = mask_types_finder.index(str(mask_block))
+ except ValueError:
+ type_index = len(mask_types_finder)
+ mask_types_finder.append(str(mask_block))
+ mask_data.append(mask_block)
+ next_mask_type_idx = type_index
+ next_i = i
+ next_j = j
+ else:
+ type_index = -1
+ mask_type_idxs.append(type_index)
+ block_mask = block_mask.at[i, j].set(is_nonzero)
+ prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)
+ prefetch_i = prefetch_i.at[i, j].set(next_i)
+ prefetch_j = prefetch_j.at[i, j].set(next_j)
+ return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)
+```
+
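+As a quick illustration (not used by the benchmark below), here is what this helper produces for a tiny 4x4 lower-triangular mask with 2x2 blocks:
+
+```python
+small_mask = jnp.tril(jnp.ones((4, 4), dtype=jnp.int32))
+blk_mask, pf_mask, pf_i, pf_j, blk_data = sparsify_mask(small_mask, (2, 2))
+print(blk_mask)        # [[1 0]
+                       #  [1 1]] -- the upper-right block is all zeros.
+print(blk_data.shape)  # Identical non-zero blocks are de-duplicated, so this
+                       # can be smaller than the number of non-zero blocks.
+```
+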
++++ {"id": "w4b7ckKq67Xw"}
+
+In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials, with 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block is all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 5374
+ status: ok
+ timestamp: 1725919713252
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: 4YQ9OmbTCSjT
+outputId: 2d752609-34f2-4059-e8ba-4d80afe8cb26
+---
+M = N = K = 16384
+blk_M = blk_N = 512
+blk_K = 1024
+
+def sparse_mask_matmul(
+ block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.
+ x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.
+ accum_scratch
+ ):
+ del prefetch_mask, prefetch_i, prefetch_j
+ i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)
+ should_compute = block_mask_ref[i, j] != 0
+ @pl.when(k == 0)
+ def _():
+ o_ref[...] = jnp.zeros_like(o_ref)
+ accum_scratch[...] = jnp.zeros_like(accum_scratch[...])
+
+ # We only compute the output for blocks with non-zero masks.
+ # Otherwise we skip the computation entirely.
+ @pl.when(should_compute)
+ def _():
+ result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)
+ accum_scratch[...] += result
+ @pl.when(k == pl.num_programs(2) - 1)
+ def _():
+ o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)
+
+X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)
+Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)
+mask = jnp.ones((M, N), dtype=jnp.int32)
+mask = jnp.tril(mask)
+block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))
+
+def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
+ del prefetch_mask, prefetch_j
+ # Zero-out the k index if the mask is zero, to avoid constantly fetching
+ # new blocks in the inner loop for blocks we are skipping.
+ k_fetch = (block_mask[i, j] != 0) * k
+ return (prefetch_i[i, j], k_fetch)
+
+def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
+ del prefetch_mask, prefetch_i
+ k_fetch = (block_mask[i, j] != 0) * k
+ return (k_fetch, prefetch_j[i, j])
+
+def mask_map(i, j, k, block_mask, prefetch_mask, *_):
+ del k, block_mask
+ return (prefetch_mask[i, j], 0, 0)
+
+def o_map(i, j, k, *_):
+ del k
+ return (i, j)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+ num_scalar_prefetch=4,
+ grid=(M // blk_M, N // blk_N, K // blk_K),
+ in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
+ pl.BlockSpec((blk_K, blk_N), y_map),
+ pl.BlockSpec((1, blk_M, blk_N), mask_map)],
+ out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
+ scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
+)
+kernel = pl.pallas_call(
+ sparse_mask_matmul,
+ grid_spec=grid_spec,
+ out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
+)
+args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+result = kernel(*args)
+
+ref = mask * (X @ Y)
+diff = jnp.abs(ref - result)
+print('mean |result - ref|:', jnp.mean(diff))
+```
+
++++ {"id": "uutNGgjZGGhB"}
+
+Now let's compare performance versus a naive dense implementation. On TPU v5e, we achieve roughly a 1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.
+
+We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since the main reasons we fall short of the theoretical peak become less significant at larger sizes:
+- We skip slightly less than half of the computation, since the blocks along the diagonal contain a mix of 0s and 1s and for mixed blocks we need to compute the entire block. With larger inputs, the overhead for mixed blocks becomes smaller relative to the overall computation.
+- The pipeline bubble also accounts for a smaller percentage of the overall runtime as inputs become larger.
+
+```{code-cell}
+---
+executionInfo:
+ elapsed: 8877
+ status: ok
+ timestamp: 1725917397452
+ user:
+ displayName: Justin Fu
+ userId: '17543197034567316452'
+ user_tz: 420
+id: MAT9JjGNvsx8
+outputId: a32d56fb-a71b-4007-c6a5-e5270dcaa6cf
+---
+n_trials = 100
+
+pallas_impl = lambda *args: kernel(*args)
+time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+
+ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
+time = benchmark(ref_impl, n_trials)(mask, X, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+```
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
index af20aa7bba24..47a7587b620f 100644
--- a/docs/persistent_compilation_cache.md
+++ b/docs/persistent_compilation_cache.md
@@ -1,4 +1,4 @@
-# Persistent Compilation Cache
+# Persistent compilation cache
@@ -29,7 +29,7 @@ f(x)
### Setting cache directory
The compilation cache is enabled when the
-[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
+[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
is set. This should be done prior to the first compilation. Set the location as
follows:
@@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
```
-(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
+(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18)
```python
from jax.experimental.compilation_cache import compilation_cache as cc
diff --git a/docs/profiling.md b/docs/profiling.md
index 6eceec8f54b8..91f4d61b21b6 100644
--- a/docs/profiling.md
+++ b/docs/profiling.md
@@ -1,4 +1,4 @@
-# Profiling JAX programs
+# Profiling computation
diff --git a/docs/quickstart.md b/docs/quickstart.md
index e071a7ce7555..77cbb9d46ab8 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
@@ -16,7 +16,7 @@ kernelspec:
-**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
+**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**.
This document provides a quick overview of essential JAX features, so you can get started with JAX quickly:
@@ -88,8 +88,8 @@ _ = selu_jit(x) # compiles on first call
%timeit selu_jit(x).block_until_ready()
```
-The above timing represent execution on CPU, but the same code can be run on GPU or TPU,
-typically for an even greater speedup.
+The above timing represents execution on CPU, but the same code can be run on GPU or
+TPU, typically for an even greater speedup.
For more on JIT compilation in JAX, check out {ref}`jit-compilation`.
@@ -183,7 +183,7 @@ print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```
-A programmer familiar with the the `jnp.dot` function might recognize that `apply_matrix` can
+A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can
be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`:
```{code-cell}
diff --git a/docs/random-numbers.md b/docs/random-numbers.md
index 85bb5ce01974..2ad1eadb0968 100644
--- a/docs/random-numbers.md
+++ b/docs/random-numbers.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb
index 8fa2107795fd..cdfda63c6f13 100644
--- a/docs/sharded-computation.ipynb
+++ b/docs/sharded-computation.ipynb
@@ -5,7 +5,7 @@
"metadata": {},
"source": [
"(sharded-computation)=\n",
- "# Introduction to sharded computation\n",
+ "# Introduction to parallel programming\n",
"\n",
"\n",
"\n",
@@ -60,7 +60,7 @@
"\n",
"Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n",
"\n",
- "How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n",
+ "How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n",
"\n",
"In the simplest cases, arrays are sharded on a single device, as demonstrated below:"
]
@@ -188,15 +188,9 @@
}
],
"source": [
- "# Pardon the boilerplate; constructing a sharding will become easier in future!\n",
- "from jax.sharding import Mesh\n",
- "from jax.sharding import PartitionSpec\n",
- "from jax.sharding import NamedSharding\n",
- "from jax.experimental import mesh_utils\n",
+ "from jax.sharding import PartitionSpec as P\n",
"\n",
- "P = jax.sharding.PartitionSpec\n",
- "devices = mesh_utils.create_device_mesh((2, 4))\n",
- "mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n",
+ "mesh = jax.make_mesh((2, 4), ('x', 'y'))\n",
"sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n",
"print(sharding)"
]
@@ -366,7 +360,7 @@
"\n",
"## 2. Semi-automated sharding with constraints\n",
"\n",
- "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
+ "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n",
"\n",
"For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:"
]
@@ -405,9 +399,7 @@
"@jax.jit\n",
"def f_contract_2(x):\n",
" out = x.sum(axis=0)\n",
- " # mesh = jax.create_mesh((8,), 'x')\n",
- " devices = mesh_utils.create_device_mesh(8)\n",
- " mesh = jax.sharding.Mesh(devices, 'x')\n",
+ " mesh = jax.make_mesh((8,), ('x',))\n",
" sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
" return jax.lax.with_sharding_constraint(out, sharding)\n",
"\n",
@@ -460,8 +452,7 @@
],
"source": [
"from jax.experimental.shard_map import shard_map\n",
- "P = jax.sharding.PartitionSpec\n",
- "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n",
+ "mesh = jax.make_mesh((8,), ('x',))\n",
"\n",
"f_elementwise_sharded = shard_map(\n",
" f_elementwise,\n",
@@ -659,8 +650,7 @@
}
],
"source": [
- "P = jax.sharding.PartitionSpec\n",
- "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n",
+ "mesh = jax.make_mesh((8,), ('x',))\n",
"sharding = jax.sharding.NamedSharding(mesh, P('x'))\n",
"\n",
"x_sharded = jax.device_put(x, sharding)\n",
diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md
index 345ca7987b41..84516a557166 100644
--- a/docs/sharded-computation.md
+++ b/docs/sharded-computation.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
name: python3
---
(sharded-computation)=
-# Introduction to sharded computation
+# Introduction to parallel programming
@@ -39,7 +39,7 @@ jax.devices()
Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.
-How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.
+How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.
In the simplest cases, arrays are sharded on a single device, as demonstrated below:
@@ -72,15 +72,9 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens
```{code-cell}
:outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5
-# Pardon the boilerplate; constructing a sharding will become easier in future!
-from jax.sharding import Mesh
-from jax.sharding import PartitionSpec
-from jax.sharding import NamedSharding
-from jax.experimental import mesh_utils
+from jax.sharding import PartitionSpec as P
-P = jax.sharding.PartitionSpec
-devices = mesh_utils.create_device_mesh((2, 4))
-mesh = jax.sharding.Mesh(devices, ('x', 'y'))
+mesh = jax.make_mesh((2, 4), ('x', 'y'))
sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))
print(sharding)
```
@@ -139,7 +133,7 @@ The result is partially replicated: that is, the first two elements of the array
## 2. Semi-automated sharding with constraints
-If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.
+If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.
For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:
@@ -149,9 +143,7 @@ For example, suppose that within `f_contract` above, you'd prefer the output not
@jax.jit
def f_contract_2(x):
out = x.sum(axis=0)
- # mesh = jax.create_mesh((8,), 'x')
- devices = mesh_utils.create_device_mesh(8)
- mesh = jax.sharding.Mesh(devices, 'x')
+ mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
return jax.lax.with_sharding_constraint(out, sharding)
@@ -177,8 +169,7 @@ In the automatic parallelism methods explored above, you can write a function as
:outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2
from jax.experimental.shard_map import shard_map
-P = jax.sharding.PartitionSpec
-mesh = jax.sharding.Mesh(jax.devices(), 'x')
+mesh = jax.make_mesh((8,), ('x',))
f_elementwise_sharded = shard_map(
f_elementwise,
@@ -268,8 +259,7 @@ If you shard the leading axis of both `x` and `weights` in the same way, then th
```{code-cell}
:outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5
-P = jax.sharding.PartitionSpec
-mesh = jax.sharding.Mesh(jax.devices(), 'x')
+mesh = jax.make_mesh((8,), ('x',))
sharding = jax.sharding.NamedSharding(mesh, P('x'))
x_sharded = jax.device_put(x, sharding)
diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py
index 3a78557632a7..7cce8b88254d 100644
--- a/docs/sphinxext/jax_extensions.py
+++ b/docs/sphinxext/jax_extensions.py
@@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None,
:jax-issue:`1234`
This will output a hyperlink of the form
- `#1234 <https://github.com/google/jax/issues/1234>`_. These links work even
+ `#1234 <https://github.com/jax-ml/jax/issues/1234>`_. These links work even
for PR numbers.
"""
text = text.lstrip('#')
if not text.isdigit():
raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.")
options = {} if options is None else options
- url = f"https://github.com/google/jax/issues/{text}"
+ url = f"https://github.com/jax-ml/jax/issues/{text}"
node = nodes.reference(rawtext, '#' + text, refuri=url, **options)
return [node], []
diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md
index 5a8af2b74142..2ff82e0431e2 100644
--- a/docs/stateful-computations.md
+++ b/docs/stateful-computations.md
@@ -5,14 +5,14 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
name: python3
---
-# Stateful Computations
+# Stateful computations
@@ -144,7 +144,7 @@ This is because, like the strategy we just applied, object-oriented programming
In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
-Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the :ref:`pseudorandom-numbers` section.
+Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the {ref}`pseudorandom-numbers` section.
Unlike Numpy, which manages random state using implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.
@@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b
2) Are we supposed to pipe all these things around manually?
-The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples.
+The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples.
diff --git a/docs/tutorials.rst b/docs/tutorials.rst
index 2f90e4226e50..a31517155e1a 100644
--- a/docs/tutorials.rst
+++ b/docs/tutorials.rst
@@ -1,7 +1,7 @@
.. _jax-tutorials:
-JAX tutorials
-=============
+Tutorials
+=========
.. toctree::
:maxdepth: 1
@@ -16,3 +16,13 @@ JAX tutorials
working-with-pytrees
sharded-computation
stateful-computations
+
+.. toctree::
+ :maxdepth: 1
+ :caption: Advanced tutorials
+
+ advanced-autodiff
+ external-callbacks
+ gradient-checkpointing
+ jax-primitives
+ jaxpr
diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst
index 103a8331df2b..d3724745fe08 100644
--- a/docs/type_promotion.rst
+++ b/docs/type_promotion.rst
@@ -218,7 +218,7 @@ Strict dtype promotion
----------------------
In some contexts it can be useful to disable implicit type promotion behavior, and
instead require all promotions to be explicit. This can be done in JAX by setting the
-``jax_numpy_dtype_promtion`` flag to ``'strict'``. Locally, it can be done with a\
+``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a\
context manager:
.. code-block:: python
diff --git a/docs/user_guides.rst b/docs/user_guides.rst
index 57913bf6d4c8..6481da7a31dd 100644
--- a/docs/user_guides.rst
+++ b/docs/user_guides.rst
@@ -1,6 +1,6 @@
.. _user-guides:
-User Guides
+User guides
===========
User guides are deeper dives into particular topics within JAX
@@ -9,7 +9,7 @@ or deployed codebases.
.. toctree::
:maxdepth: 1
- :caption: Debugging and Performance
+ :caption: Debugging and performance
notebooks/thinking_in_jax
profiling
@@ -20,25 +20,26 @@ or deployed codebases.
.. toctree::
:maxdepth: 1
- :caption: Development
+ :caption: Interfaces
- jaxpr
- notebooks/external_callbacks
- type_promotion
pytrees
-
-.. toctree::
- :maxdepth: 1
- :caption: Run Time
-
+ errors
aot
export/index
- errors
+ type_promotion
transfer_guard
.. toctree::
:maxdepth: 1
- :caption: Custom Operations
+ :caption: Custom operations
pallas/index
ffi
+
+.. toctree::
+ :caption: Example applications
+ :maxdepth: 1
+
+ notebooks/neural_network_with_tfds_data
+ notebooks/Neural_Network_and_Data_Loading
+ notebooks/vmapped_log_probs
diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md
index 2bd1cc08ecdf..e41179996bc4 100644
--- a/docs/working-with-pytrees.md
+++ b/docs/working-with-pytrees.md
@@ -5,7 +5,7 @@ jupytext:
extension: .md
format_name: myst
format_version: 0.13
- jupytext_version: 1.16.1
+ jupytext_version: 1.16.4
kernelspec:
display_name: Python 3
language: python
diff --git a/docs/xla_flags.md b/docs/xla_flags.md
new file mode 100644
index 000000000000..b332940ccb9d
--- /dev/null
+++ b/docs/xla_flags.md
@@ -0,0 +1,89 @@
+# List of XLA compiler flags
+
+
+
+## Introduction
+This guide gives a brief overview of XLA and how XLA relates to JAX; for in-depth details please refer to the [XLA documentation](https://openxla.org/xla).
+It then lists commonly-used XLA compiler flags that can be used to optimize the performance of JAX programs.
+
+## XLA: The powerhouse behind JAX
+XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in JAX's performance and flexibility. It enables JAX to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions.
+
+JAX uses XLA's JIT compilation capabilities to transform your Python functions into optimized XLA computations at runtime.
+
+## Configuring XLA in JAX
+You can influence XLA's behavior in JAX by setting the `XLA_FLAGS` environment variable before running your Python script or Colab notebook.
+
+For Colab notebooks:
+
+Provide flags using `os.environ['XLA_FLAGS']`:
+
+
+```python
+import os
+
+# Set multiple flags separated by spaces
+os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
+```
+
+For Python scripts:
+
+Specify `XLA_FLAGS` as part of the CLI command:
+
+```bash
+XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
+```
+
+**Important Notes:**
+
+* Set `XLA_FLAGS` before importing JAX or other relevant libraries. Changing `XLA_FLAGS` after the backend has been initialized will have no effect, and since the backend initialization time is not clearly defined it is usually safer to set `XLA_FLAGS` before executing any JAX code (see the sketch below).
+* Experiment with different flags to optimize performance for your specific use case.
+
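+For example, a minimal pattern combining both notes might look like the following (the specific flags here are just illustrations taken from the GPU table below):
+
+```python
+import os
+
+# Set the flags before JAX is imported so the backend picks them up.
+os.environ["XLA_FLAGS"] = (
+    "--xla_gpu_enable_latency_hiding_scheduler=true "
+    "--xla_gpu_enable_triton_gemm=false"
+)
+
+import jax  # Imported only after XLA_FLAGS has been set.
+print(jax.devices())
+```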
+
+**For further information:**
+* Complete and up-to-date documentation about XLA can be found in the official [XLA documentation](https://openxla.org/xla).
+
+* For backends supported by the open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in [xla/debug_options_flags.cc](https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc), and a complete list of flags can be found [here](https://github.com/openxla/xla/blob/main/xla/xla.proto).
+* TPU compiler flags are not part of [OpenXLA](https://github.com/openxla/xla), but commonly-used options are listed below.
+
+* Please note that this list of flags is not exhaustive and is subject to change. These flags are implementation details, and there is no guarantee that they will remain available or maintain their current behavior.
+
+### Common XLA flags
+| Flag | Type | Notes |
+| ---- | ---- | ----- |
+| `xla_dump_to` | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see [XLA Tools](https://openxla.org/xla/tools)). |
+| `xla_enable_async_collective_permute` | TristateFlag (true/false/auto) | Rewrites all collective-permute operations to their asynchronous variants. When set to `auto`, XLA can turn on async collective based on other configurations or conditions automatically. |
+| `xla_enable_async_all_gather` | TristateFlag (true/false/auto) | If set to true, enables async all gather. If `auto`, enables only for platforms that implement async all-gather. The implementation (such as BC-offload or continuation fusion) is chosen based on other flag values. |
+| `xla_disable_hlo_passes` | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass name (no whitespace around commas). |
+
+### TPU XLA flags
+| Flag | Type | Notes |
+| ---- | ---- | ----- |
+| `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data parallel sharding. |
+| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. |
+| `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. |
+| `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass. If set to `auto`, it will be enabled based on the target. |
+| `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. |
+| `xla_tpu_overlap_compute_collective_tc` | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., one core equivalent of MegaCore fusion. |
+| `xla_tpu_spmd_rng_bit_generator_unsafe` | Boolean (true/false) | Whether to run RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. |
+| `xla_tpu_megacore_fusion_allow_ags` | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. |
+| `xla_tpu_enable_ag_backward_pipelining` | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. |
+
+### GPU XLA flags
+| Flag | Type | Notes |
+| ---- | ---- | ----- |
+| `xla_gpu_enable_latency_hiding_scheduler` | Boolean (true/false) | This flag enables latency hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. |
+| `xla_gpu_enable_triton_gemm` | Boolean (true/false) | Use Triton-based matrix multiplication. |
+| `xla_gpu_graph_level` | Flag (0-3) | The legacy flag for setting GPU graph level. Use xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture fusions and memcpys; 2 = capture gemms; 3 = capture convolutions. |
+| `xla_gpu_all_reduce_combine_threshold_bytes` | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer’s weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256. |
+| `xla_gpu_all_gather_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
+| `xla_gpu_reduce_scatter_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
+| `xla_gpu_enable_pipelined_all_gather` | Boolean (true/false) | Enable pipelining of all-gather instructions. |
+| `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. |
+| `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelining of all-reduce instructions. |
+| `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loop. |
+| `xla_gpu_enable_triton_softmax_fusion` | Boolean (true/false) | Use Triton-based Softmax fusion. |
+| `xla_gpu_enable_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension or irrespective of their dimension. |
+| `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension or irrespective of their dimension. |
+
+**Additional reading:**
+* [GPU performance tips](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags)
diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt
new file mode 100644
index 000000000000..8d9b811374d1
--- /dev/null
+++ b/examples/ffi/CMakeLists.txt
@@ -0,0 +1,15 @@
+cmake_minimum_required(VERSION 3.15...3.30)
+project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
+
+find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
+execute_process(
+ COMMAND "${Python_EXECUTABLE}"
+ "-c" "from jax.extend import ffi; print(ffi.include_dir())"
+ OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR)
+message(STATUS "XLA include directory: ${XLA_DIR}")
+
+find_package(nanobind CONFIG REQUIRED)
+
+nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
+target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
+install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
diff --git a/examples/ffi/README.md b/examples/ffi/README.md
new file mode 100644
index 000000000000..cc7018782a25
--- /dev/null
+++ b/examples/ffi/README.md
@@ -0,0 +1,9 @@
+# End-to-end example usage for JAX's foreign function interface
+
+This directory includes an example project demonstrating the use of JAX's
+foreign function interface (FFI). The JAX docs provide more information about
+this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
+but the example in this directory explicitly demonstrates:
+
+1. One way to package and distribute FFI targets, and
+2. Some more advanced use cases.
diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml
new file mode 100644
index 000000000000..130dd91bbc70
--- /dev/null
+++ b/examples/ffi/pyproject.toml
@@ -0,0 +1,12 @@
+[build-system]
+requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"]
+build-backend = "scikit_build_core.build"
+
+[project]
+name = "jax_ffi_example"
+version = "0.0.1"
+requires-python = ">=3.10"
+dependencies = ["jax"]
+
+[project.optional-dependencies]
+test = ["pytest", "absl-py"]
diff --git a/examples/ffi/src/jax_ffi_example/__init__.py b/examples/ffi/src/jax_ffi_example/__init__.py
new file mode 100644
index 000000000000..862a661e24b9
--- /dev/null
+++ b/examples/ffi/src/jax_ffi_example/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc
new file mode 100644
index 000000000000..2fb8d96c8461
--- /dev/null
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc
@@ -0,0 +1,157 @@
+/* Copyright 2024 The JAX Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include
+#include
+#include