diff --git a/.bazelrc b/.bazelrc index cb223ef02f78..5b7bc653373b 100644 --- a/.bazelrc +++ b/.bazelrc @@ -57,27 +57,26 @@ build:native_arch_posix --host_copt=-march=native build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1 +# Disable clang extension that rejects type definitions within offsetof. +# This was added in clang-16 by https://reviews.llvm.org/D133574. +# Can be removed once upb is updated, since a type definition is used within +# offsetof in the current version of upb. +# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. +build:clang --copt=-Wno-gnu-offsetof-extensions +# Disable clang extension that rejects unknown arguments. +build:clang --copt=-Qunused-arguments + build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --repo_env TF_NCCL_USE_STUB=1 # "sm" means we emit only cubin, which is forward compatible within a GPU generation. # "compute" means we emit both cubin and PTX, which is larger but also forward compatible to future GPU generations. -build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" +build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="sm_50,sm_60,sm_70,sm_80,compute_90" build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda build:cuda --@xla//xla/python:jax_cuda_pip_rpaths=true - -# Build with nvcc for CUDA and clang for host -build:nvcc_clang --config=cuda -# Unfortunately, cuda_configure.bzl demands this for using nvcc + clang -build:nvcc_clang --action_env=TF_CUDA_CLANG="1" -build:nvcc_clang --action_env=TF_NVCC_CLANG="1" -build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc - -# Requires MSVC and LLVM to be installed -build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl -build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl -build:win_clang --compiler=clang-cl - +# Default hermetic CUDA and CUDNN versions. +build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" # Force the linker to set RPATH, not RUNPATH. When resolving dynamic libraries, # ld.so prefers in order: RPATH, LD_LIBRARY_PATH, RUNPATH. JAX sets RPATH to # point to the $ORIGIN-relative location of the pip-installed NVIDIA CUDA @@ -92,17 +91,19 @@ build:win_clang --compiler=clang-cl # The list of CUDA pip packages that JAX depends on are present in setup.py. build:cuda --linkopt=-Wl,--disable-new-dtags +# This flag is needed to include CUDA libraries for bazel tests. +test:cuda --@local_config_cuda//cuda:include_cuda_libs=true + +build:cuda_clang --config=clang build:cuda_clang --@local_config_cuda//:cuda_compiler=clang build:cuda_clang --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" -build:cuda_clang --action_env=TF_CUDA_CLANG="1" -# Disable clang extention that rejects type definitions within offsetof. -# This was added in clang-16 by https://reviews.llvm.org/D133574. -# Can be removed once upb is updated, since a type definition is used within -# offset of in the current version of ubp. -# See https://github.com/protocolbuffers/upb/blob/9effcbcb27f0a665f9f345030188c0b291e32482/upb/upb.c#L183. -build:cuda_clang --copt=-Wno-gnu-offsetof-extensions -# Disable clang extention that rejects unknown arguments. 
-build:cuda_clang --copt=-Qunused-arguments + +# Build with NVCC for CUDA +build:cuda_nvcc --config=cuda +build:cuda_nvcc --config=clang +build:cuda_nvcc --@local_config_cuda//:cuda_compiler=nvcc +build:cuda_nvcc --action_env=TF_NVCC_CLANG="1" +build:cuda_nvcc --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true @@ -111,6 +112,11 @@ build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1 build:nonccl --define=no_nccl_support=true +# Requires MSVC and LLVM to be installed +build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +build:win_clang --extra_execution_platforms=//jax/tools/toolchains:x64_windows-clang-cl +build:win_clang --compiler=clang-cl + # Windows has a relatively short command line limit, which JAX has begun to hit. # See https://docs.bazel.build/versions/main/windows.html build:windows --features=compiler_param_file @@ -198,48 +204,46 @@ build:rbe_linux --host_linkopt=-lm # https://github.com/bazelbuild/bazel/issues/13623 build:rbe_cpu_linux_base --config=rbe_linux build:rbe_cpu_linux_base --config=cuda_clang -build:rbe_cpu_linux_base --action_env=TF_NVCC_CLANG="1" -build:rbe_cpu_linux_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_cpu_linux_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_cpu_linux_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_cpu_linux_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_cpu_linux_base --repo_env=TF_SYSROOT="/dt9" build:rbe_cpu_linux_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_cpu_linux_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" +build:rbe_cpu_linux_py3.10 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" +build:rbe_cpu_linux_py3.11 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12" +build:rbe_cpu_linux_py3.12 --config=rbe_cpu_linux_base build:rbe_cpu_linux_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" +build:rbe_cpu_linux_py3.13 --config=rbe_cpu_linux_base +build:rbe_cpu_linux_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" build:rbe_linux_cuda_base --config=rbe_linux build:rbe_linux_cuda_base --config=cuda build:rbe_linux_cuda_base --repo_env=REMOTE_GPU_TESTING=1 
build:rbe_linux_cuda12.3_nvcc_base --config=rbe_linux_cuda_base -build:rbe_linux_cuda12.3_nvcc_base --config=cuda_clang -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_NVCC_CLANG="1" -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDA_VERSION=12 -build:rbe_linux_cuda12.3_nvcc_base --action_env=TF_CUDNN_VERSION=9 -build:rbe_linux_cuda12.3_nvcc_base --action_env=CUDA_TOOLKIT_PATH="/usr/local/cuda-12" -build:rbe_linux_cuda12.3_nvcc_base --action_env=LD_LIBRARY_PATH="/usr/local/cuda:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:/usr/local/tensorrt/lib" -build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain" -build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda12.3_nvcc_base --config=cuda_nvcc +build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDA_VERSION="12.3.2" +build:rbe_linux_cuda12.3_nvcc_base --repo_env=HERMETIC_CUDNN_VERSION="9.1.1" +build:rbe_linux_cuda12.3_nvcc_base --host_crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda12.3_nvcc_base --crosstool_top="@local_config_cuda//crosstool:toolchain" +build:rbe_linux_cuda12.3_nvcc_base --extra_toolchains="@local_config_cuda//crosstool:toolchain-linux-x86_64" +build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_SYSROOT="/dt9" build:rbe_linux_cuda12.3_nvcc_base --extra_execution_platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --host_platform="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" build:rbe_linux_cuda12.3_nvcc_base --platforms="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_platform//:platform" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_CUDA_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_cuda" -build:rbe_linux_cuda12.3_nvcc_base --repo_env=TF_NCCL_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_nccl" -# RBE machines have an older CUDA driver version, so we have to enable driver forward compatibility -build:rbe_linux_cuda12.3_nvcc_base --test_env=LD_LIBRARY_PATH=/usr/local/cuda/compat -build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.10" +build:rbe_linux_cuda12.3_nvcc_py3.10 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.10 --repo_env HERMETIC_PYTHON_VERSION="3.10" -build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.11" +build:rbe_linux_cuda12.3_nvcc_py3.11 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.11 --repo_env HERMETIC_PYTHON_VERSION="3.11" -build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base --repo_env=TF_PYTHON_CONFIG_REPO="@ubuntu20.04-clang_manylinux2014-cuda12.3-cudnn9.1_config_python3.12" +build:rbe_linux_cuda12.3_nvcc_py3.12 --config=rbe_linux_cuda12.3_nvcc_base build:rbe_linux_cuda12.3_nvcc_py3.12 --repo_env HERMETIC_PYTHON_VERSION="3.12" +build:rbe_linux_cuda12.3_nvcc_py3.13 --config=rbe_linux_cuda12.3_nvcc_base 
+build:rbe_linux_cuda12.3_nvcc_py3.13 --repo_env HERMETIC_PYTHON_VERSION="3.13" # These you may need to change for your own GCP project. build:tensorflow_testing_rbe --project_id=tensorflow-testing diff --git a/.github/ISSUE_TEMPLATE/bug-report.yml b/.github/ISSUE_TEMPLATE/bug-report.yml index c19832e63163..628310519b66 100644 --- a/.github/ISSUE_TEMPLATE/bug-report.yml +++ b/.github/ISSUE_TEMPLATE/bug-report.yml @@ -20,11 +20,11 @@ body: * If you prefer a non-templated issue report, click [here][Raw report]. - [Discussions]: https://github.com/google/jax/discussions + [Discussions]: https://github.com/jax-ml/jax/discussions - [issue search]: https://github.com/google/jax/search?q=is%3Aissue&type=issues + [issue search]: https://github.com/jax-ml/jax/search?q=is%3Aissue&type=issues - [Raw report]: http://github.com/google/jax/issues/new + [Raw report]: http://github.com/jax-ml/jax/issues/new - type: textarea attributes: label: Description diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index cabbed58967a..f078e8e94182 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,5 @@ blank_issues_enabled: false contact_links: - name: Have questions or need support? - url: https://github.com/google/jax/discussions + url: https://github.com/jax-ml/jax/discussions about: Please ask questions on the Discussions tab diff --git a/.github/workflows/ci-build.yaml b/.github/workflows/ci-build.yaml index c01d44af9cf6..a2e45c1a8fc6 100644 --- a/.github/workflows/ci-build.yaml +++ b/.github/workflows/ci-build.yaml @@ -20,20 +20,19 @@ permissions: contents: read # to fetch code actions: write # to cancel previous workflows +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: lint_and_typecheck: if: false runs-on: ubuntu-latest timeout-minutes: 5 steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python 3.11 - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: 3.11 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # ratchet: pre-commit/action@v3.0.1 @@ -60,14 +59,9 @@ jobs: prng-upgrade: 0 num_generated_cases: 1 steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -113,14 +107,9 @@ jobs: matrix: python-version: ['3.10'] steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: 
styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -156,14 +145,9 @@ jobs: matrix: python-version: ['3.10'] steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -198,14 +182,9 @@ jobs: enable-x64: 0 num_generated_cases: 10 steps: - - name: Cancel previous - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - if: ${{github.ref != 'refs/heads/main'}} - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Get pip cache dir @@ -236,4 +215,37 @@ jobs: echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py - \ No newline at end of file + + ffi: + name: FFI example + runs-on: ubuntu-latest + timeout-minutes: 5 + steps: + - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 + - name: Set up Python 3.11 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 + with: + python-version: 3.11 + - name: Get pip cache dir + id: pip-cache + run: | + python -m pip install --upgrade pip wheel + echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT + - name: pip cache + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4 + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }} + - name: Install JAX + run: pip install . + - name: Build and install example project + run: python -m pip install -v ./examples/ffi[test] + env: + # We test building using GCC instead of clang. All other JAX builds use + # clang, but it is useful to make sure that FFI users can compile using + # a different toolchain. GCC is the default compiler on the + # 'ubuntu-latest' runner, but we still set this explicitly just to be + # clear. 
+ CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++ + - name: Run tests + run: python -m pytest examples/ffi/tests diff --git a/.github/workflows/jax-array-api.yml b/.github/workflows/jax-array-api.yml index 78cddb411feb..f1dc8eee8a75 100644 --- a/.github/workflows/jax-array-api.yml +++ b/.github/workflows/jax-array-api.yml @@ -1,13 +1,16 @@ name: JAX Array API on: - workflow_dispatch: # allows triggering the workflow run manually - pull_request: # Automatically trigger on pull requests affecting particular files + push: branches: - main - paths: - - '**workflows/jax-array-api.yml' - - '**experimental/array_api/**' + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true jobs: build: @@ -25,11 +28,11 @@ jobs: with: repository: data-apis/array-api-tests # TODO(jakevdp) update this to a stable release/tag when available. - ref: '33f2d2ea2f3dd2b3ceeeb4519d55e08096184149' # Latest commit as of 2024-05-28 + ref: 'b4c0823469c02d6ce6e512ad4c2bd8ba42b1b4b2' # Latest commit as of 2024-09-09 submodules: 'true' path: 'array-api-tests' - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install dependencies diff --git a/.github/workflows/metal_plugin_ci.yml b/.github/workflows/metal_plugin_ci.yml index 0c739619df1a..75f4bba1a367 100644 --- a/.github/workflows/metal_plugin_ci.yml +++ b/.github/workflows/metal_plugin_ci.yml @@ -11,6 +11,10 @@ on: paths: - '**workflows/metal_plugin_ci.yml' +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + jobs: jax-metal-plugin-test: diff --git a/.github/workflows/self_hosted_runner_utils/setup_runner.sh b/.github/workflows/self_hosted_runner_utils/setup_runner.sh index 79c1224c13cc..ef501784b45e 100755 --- a/.github/workflows/self_hosted_runner_utils/setup_runner.sh +++ b/.github/workflows/self_hosted_runner_utils/setup_runner.sh @@ -31,7 +31,7 @@ runner_token="$3" # - sets empty string as default to avoid unbound variable error from set -u jax_repo_url="${4-}" if [ -z "${jax_repo_url}" ]; then - jax_repo_url="https://github.com/google/jax" + jax_repo_url="https://github.com/jax-ml/jax" fi # Create `runner` user. 
This user won't have sudo access unless you ssh into the @@ -67,7 +67,7 @@ cd ~/ git clone ${jax_repo_url} -# Based on https://github.com/google/jax/settings/actions/runners/new +# Based on https://github.com/jax-ml/jax/settings/actions/runners/new # (will be 404 for github users with insufficient repo permissions) mkdir actions-runner && cd actions-runner curl -o actions-runner-linux-x64.tar.gz -L ${actions_runner_download} diff --git a/.github/workflows/upstream-nightly.yml b/.github/workflows/upstream-nightly.yml index 9e0386f75053..2bdd8ba5192e 100644 --- a/.github/workflows/upstream-nightly.yml +++ b/.github/workflows/upstream-nightly.yml @@ -38,7 +38,7 @@ jobs: steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install JAX test requirements @@ -84,8 +84,8 @@ jobs: failure() && steps.status.outcome == 'failure' && github.event_name == 'schedule' - && github.repository == 'google/jax' - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # ratchet: actions/upload-artifact@v4 + && github.repository == 'jax-ml/jax' + uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 with: name: output-${{ matrix.python-version }}-log.jsonl path: output-${{ matrix.python-version }}-log.jsonl @@ -107,7 +107,7 @@ jobs: shell: bash steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: "3.x" - uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # ratchet:actions/download-artifact@v4 diff --git a/.github/workflows/wheel_win_x64.yml b/.github/workflows/wheel_win_x64.yml index 8dad541b81d0..f4fb7727da6b 100644 --- a/.github/workflows/wheel_win_x64.yml +++ b/.github/workflows/wheel_win_x64.yml @@ -2,6 +2,10 @@ name: Wheel build - Windows CPU x86_64 on: workflow_dispatch: # allows triggering the workflow run manually +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + env: DISTUTILS_USE_SDK: 1 MSSdk: 1 @@ -13,22 +17,17 @@ jobs: matrix: os: [windows-2019-32core] arch: [AMD64] - pyver: ['3.10', '3.11', '3.12'] + pyver: ['3.10', '3.11', '3.12', '3.13.0-rc.2'] name: ${{ matrix.os }} ${{ matrix.pyver }} jaxlib wheel build runs-on: ${{ matrix.os }} steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4 - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -46,7 
+45,7 @@ jobs: --bazel_options=--config=win_clang ` --verbose - - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 with: name: wheels-${{ matrix.os }}-${{ matrix.pyver }} path: ${{ github.workspace }}\dist\*.whl @@ -58,7 +57,7 @@ jobs: JAX_SKIP_SLOW_TESTS: true PY_COLORS: 1 run: | + python -m pip install --find-links ${{ github.workspace }}\dist jaxlib python -m pip install -e ${{ github.workspace }} - python -m pip install --no-index --find-links ${{ github.workspace }}\dist jaxlib echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS" pytest -n auto --tb=short tests examples diff --git a/.github/workflows/windows_ci.yml b/.github/workflows/windows_ci.yml index 60bebd32ea76..03a6876cdbb1 100644 --- a/.github/workflows/windows_ci.yml +++ b/.github/workflows/windows_ci.yml @@ -6,6 +6,10 @@ on: pull_request: types: [ labeled ] # allow force-windows-run label +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} + cancel-in-progress: true + env: DISTUTILS_USE_SDK: 1 MSSdk: 1 @@ -23,10 +27,6 @@ jobs: runs-on: ${{ matrix.os }} steps: - - name: Cancel Previous Runs - uses: styfle/cancel-workflow-action@85880fa0301c86cca9da44039ee3bb12d3bedbfa # ratchet: styfle/cancel-workflow-action@0.12.1 - with: - access_token: ${{ github.token }} - name: Install LLVM/Clang run: choco install llvm --version=18.1.4 --yes --no-progress --allow-downgrade @@ -35,7 +35,7 @@ jobs: with: path: jax - - uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5 + - uses: actions/setup-python@f677139bbe7f9c59b41e40162b753c062f5d49a3 # ratchet:actions/setup-python@v5 with: python-version: ${{ matrix.pyver }} cache: 'pip' @@ -53,7 +53,7 @@ jobs: --bazel_options=--color=yes ` --bazel_options=--config=win_clang - - uses: actions/upload-artifact@89ef406dd8d7e03cfd12d9e0a4a378f454709029 # ratchet: actions/upload-artifact@v4 + - uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # ratchet: actions/upload-artifact@v4 with: name: wheels path: ${{ github.workspace }}\jax\dist\*.whl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 355f134f0551..c89aa934d95d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,7 +9,7 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: 2c9f875913ee60ca25ce70243dc24d5b6415598c # frozen: v4.6.0 hooks: - id: check-ast - id: check-merge-conflict @@ -26,12 +26,12 @@ repos: files: \.py$ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.4.4 + rev: 8b5112a3b2ad121439a2092f8ff548c0d80f2514 # frozen: v0.6.1 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.11.0' + rev: 'd4911cfb7f1010759fde68da196036feeb25b99d' # frozen: v1.11.2 hooks: - id: mypy files: (jax/|tests/typing_test\.py) @@ -40,7 +40,7 @@ repos: args: [--config=pyproject.toml] - repo: https://github.com/mwouts/jupytext - rev: v1.16.1 + rev: 8ed836db64ad5d304f2315e6bfd9049c9142e190 # frozen: v1.16.4 hooks: - id: jupytext files: docs/ diff --git a/CHANGELOG.md b/CHANGELOG.md index 7fbe947fa7de..079e055aa994 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,15 +10,103 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. 
--> -## jax 0.4.32 +## jax 0.4.34 + +* New Functionality + * This release includes wheels for Python 3.13. Free-threading mode is not yet + supported. + * `jax.errors.JaxRuntimeError` has been added as a public alias for the + formerly private `XlaRuntimeError` type. + +* Breaking changes + * `jax_pmap_no_rank_reduction` flag is set to `True` by default. + * array[0] on a pmap result now introduces a reshape (use array[0:1] + instead). + * The per-shard shape (accessible via jax_array.addressable_shards or + jax_array.addressable_data(0)) now has a leading (1, ...). Update code + that directly accesses shards accordingly. The rank of the per-shard shape + now matches that of the global shape, which is the same behavior as jit. + This avoids costly reshapes when passing results from pmap into jit. + +* Deprecations + * In {func}`jax.numpy.trim_zeros`, non-arraylike arguments or arraylike + arguments with `ndim != 1` are now deprecated, and in the future will result + in an error. + * Internal pretty-printing tools `jax.core.pp_*` have been removed, after + being deprecated in JAX v0.4.30. + * `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead. + * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use + `jax.errors.JaxRuntimeError` instead. + +* Deletion: + * `jax.xla_computation` is deleted. It's been 3 months since its deprecation + in the 0.4.30 JAX release. + Please use the AOT APIs to get the same functionality as `jax.xla_computation` (see the sketch below). + * `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with + `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`. + * You can also use the `.out_info` property of `jax.stages.Lowered` to get the + output information (like tree structure, shape and dtype). + * For cross-backend lowering, you can replace + `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with + `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`. + * {class}`jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. + The argument was only used by `xmap`, which was removed in 0.4.31. + * `jax.tree.map(f, None, non-None)`, which previously emitted a + `DeprecationWarning`, now raises an error. `None` + is only a tree-prefix of itself. To preserve the current behavior, you can + ask `jax.tree.map` to treat `None` as a leaf value by writing: + `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`. + +* Bug fixes + * Fixed a bug where {func}`jax.numpy.cumsum` would produce incorrect outputs + if a non-boolean input was provided and `dtype=bool` was specified. + +## jax 0.4.33 (September 16, 2024) + +This is a patch release on top of jax 0.4.32 that fixes two bugs found in that +release. + +A TPU-only data corruption bug was found in the version of libtpu pinned by +JAX 0.4.32, which manifested only if multiple TPU slices were present in the +same job, for example, if training on multiple v5e slices. +This release fixes that issue by pinning a fixed version of `libtpu`. + +This release fixes an inaccurate result for F64 tanh on CPU (#23590). + +## jax 0.4.32 (September 11, 2024) + +Note: This release was yanked from PyPI because of a data corruption bug on TPU. +See the 0.4.33 release notes for more details. + +* New Functionality + * Added {func}`jax.extend.ffi.ffi_call` and {func}`jax.extend.ffi.ffi_lowering` + to support the use of the new {ref}`ffi-tutorial` to interface with custom + C++ and CUDA code from JAX. 
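As a worked illustration of the `jax.xla_computation` migration described in the Deletion notes above, here is a minimal sketch using only the AOT calls named there. The function `f` and its example argument are placeholders introduced for illustration, not part of the release notes.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) ** 2

# Former usage: jax.xla_computation(f)(jnp.arange(4.0))
lowered = jax.jit(f).lower(jnp.arange(4.0))
print(lowered.compiler_ir('hlo'))  # HLO, as produced by the AOT replacement
print(lowered.out_info)            # output tree structure, shapes and dtypes

# Cross-backend lowering, per the note above:
# jax.jit(f).trace(jnp.arange(4.0)).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
```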
* Changes + * `jax_enable_memories` flag is set to `True` by default. * {mod}`jax.numpy` now supports v2023.12 of the Python Array API Standard. See {ref}`python-array-api` for more information. * Computations on the CPU backend may now be dispatched asynchronously in more cases. Previously non-parallel computations were always dispatched synchronously. You can recover the old behavior by setting `jax.config.update('jax_cpu_enable_async_dispatch', False)`. + * Added new {func}`jax.process_indices` function to replace the + `jax.host_ids()` function that was deprecated in JAX v0.2.13. + * To align with the behavior of `numpy.fabs`, `jax.numpy.fabs` has been + modified to no longer support complex dtypes. + * ``jax.tree_util.register_dataclass`` now checks that ``data_fields`` + and ``meta_fields`` include all dataclass fields with ``init=True`` + and only them, if ``nodetype`` is a dataclass. + * Several {mod}`jax.numpy` functions now have full {class}`~jax.numpy.ufunc` + interfaces, including {obj}`~jax.numpy.add`, {obj}`~jax.numpy.multiply`, + {obj}`~jax.numpy.bitwise_and`, {obj}`~jax.numpy.bitwise_or`, + {obj}`~jax.numpy.bitwise_xor`, {obj}`~jax.numpy.logical_and`, + {obj}`~jax.numpy.logical_or`, and {obj}`~jax.numpy.logical_xor`. + In future releases we plan to expand these to other ufuncs. + * Added {func}`jax.lax.optimization_barrier`, which allows users to prevent + compiler optimizations such as common-subexpression elimination and to + control scheduling. * Breaking changes * The MHLO MLIR dialect (`jax.extend.mlir.mhlo`) has been removed. Use the @@ -34,8 +122,37 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The `jax.experimental.array_api` module is deprecated, and importing it is no longer required to use the Array API. `jax.numpy` supports the array API directly; see {ref}`python-array-api` for more information. + * The internal utilities `jax.core.check_eqn`, `jax.core.check_type`, and + `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in + the future. + * `jax.numpy.round_` has been deprecated, following removal of the corresponding + API in NumPy 2.0. Use {func}`jax.numpy.round` instead. + * Passing a DLPack capsule to {func}`jax.dlpack.from_dlpack` is deprecated. + The argument to {func}`jax.dlpack.from_dlpack` should be an array from + another framework that implements the ``__dlpack__`` protocol. + +## jaxlib 0.4.32 (September 11, 2024) + +Note: This release was yanked from PyPI because of a data corruption bug on TPU. +See the 0.4.33 release notes for more details. -## jaxlib 0.4.32 +* Breaking changes + * This release of jaxlib switched to a new version of the CPU backend, which + should compile faster and leverage parallelism better. If you experience + any problems due to this change, you can temporarily enable the old CPU + backend by setting the environment variable + `XLA_FLAGS=--xla_cpu_use_thunk_runtime=false`. If you need to do this, + please file a JAX bug with instructions to reproduce. + * Hermetic CUDA support is added. + Hermetic CUDA uses a specific downloadable version of CUDA instead of the + user’s locally installed CUDA. Bazel will download CUDA, CUDNN and NCCL + distributions, and then use CUDA libraries and tools as dependencies in + various Bazel targets. This enables more reproducible builds for JAX and its + supported CUDA versions. + +* Changes + * SparseCore profiling is added. 
+ * JAX now supports profiling [SparseCore](https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#sparsecore) on TPUv5p chips. These traces will be viewable in Tensorboard Profiler's [TraceViewer](https://www.tensorflow.org/guide/profiler#trace_viewer). ## jax 0.4.31 (July 29, 2024) @@ -179,7 +296,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. which manifested as an incorrect output for cumulative reductions (#21403). * Fixed a bug where XLA:CPU miscompiled certain matmul fusions (https://github.com/openxla/xla/pull/13301). - * Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396). + * Fixes a compiler crash on GPU (https://github.com/jax-ml/jax/issues/21396). * Deprecations * `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will @@ -301,7 +418,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. branch consistent with that of NumPy 2.0. * The behavior of `lax.rng_bit_generator`, and in turn the `'rbg'` and `'unsafe_rbg'` PRNG implementations, under `jax.vmap` [has - changed](https://github.com/google/jax/issues/19085) so that + changed](https://github.com/jax-ml/jax/issues/19085) so that mapping over keys results in random generation only from the first key in the batch. * Docs now use `jax.random.key` for construction of PRNG key arrays @@ -333,7 +450,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * JAX export does not support older serialization versions anymore. Version 9 has been supported since October 27th, 2023 and has become the default since February 1, 2024. - See [a description of the versions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + See [a description of the versions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). This change could break clients that set a specific JAX serialization version lower than 9. @@ -406,7 +523,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * added the ability to specify symbolic constraints on the dimension variables. This makes shape polymorphism more expressive, and gives a way to workaround limitations in the reasoning about inequalities. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. * with the addition of symbolic constraints ({jax-issue}`#19235`) we now consider dimension variables from different scopes to be different, even if they have the same name. Symbolic expressions from different scopes @@ -416,7 +533,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. The scope of a symbolic expression `e` can be read with `e.scope` and passed into the above functions to direct them to construct symbolic expressions in a given scope. - See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. 
* simplified and faster equality comparisons, where we consider two symbolic dimensions to be equal if the normalized form of their difference reduces to 0 ({jax-issue}`#19231`; note that this may result in user-visible behavior @@ -435,7 +552,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. strings for polymorphic shapes specifications ({jax-issue}`#19284`). * JAX default native serialization version is now 9. This is relevant for {mod}`jax.experimental.jax2tf` and {mod}`jax.experimental.export`. - See [description of version numbers](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). + See [description of version numbers](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#native-serialization-versions). * Refactored the API for `jax.experimental.export`. Instead of `from jax.experimental.export import export` you should use now `from jax.experimental import export`. The old way of importing will @@ -681,19 +798,19 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * When not running under IPython: when an exception is raised, JAX now filters out the entirety of its internal frames from tracebacks. (Without the "unfiltered stack trace" that previously appeared.) This should produce much friendlier-looking tracebacks. See - [here](https://github.com/google/jax/pull/16949) for an example. + [here](https://github.com/jax-ml/jax/pull/16949) for an example. This behavior can be changed by setting `JAX_TRACEBACK_FILTERING=remove_frames` (for two separate unfiltered/filtered tracebacks, which was the old behavior) or `JAX_TRACEBACK_FILTERING=off` (for one unfiltered traceback). * jax2tf default serialization version is now 7, which introduces new shape - [safety assertions](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). + [safety assertions](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). * Devices passed to `jax.sharding.Mesh` should be hashable. This specifically applies to mock devices or user created devices. `jax.devices()` are already hashable. * Breaking changes: * jax2tf now uses native serialization by default. See - the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) for details and for mechanisms to override the default. * The option `--jax_coordination_service` has been removed. It is now always `True`. @@ -822,7 +939,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. arguments will always resolve to the "common operands" `cond` behavior (as documented) if the second and third arguments are callable, even if other operands are callable as well. See - [#16413](https://github.com/google/jax/issues/16413). + [#16413](https://github.com/jax-ml/jax/issues/16413). * The deprecated config options `jax_array` and `jax_jit_pjit_api_merge`, which did nothing, have been removed. These options have been true by default for many releases. @@ -833,7 +950,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. serialization version ({jax-issue}`#16746`). * jax2tf in presence of shape polymorphism now generates code that checks certain shape constraints, if the serialization version is at least 7. 
- See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism. + See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism. ## jaxlib 0.4.14 (July 27, 2023) @@ -995,14 +1112,14 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. {func}`jax.experimental.host_callback` is no longer supported on Cloud TPU with the new runtime component. Please file an issue on the [JAX issue - tracker](https://github.com/google/jax/issues) if the new `jax.debug` APIs + tracker](https://github.com/jax-ml/jax/issues) if the new `jax.debug` APIs are insufficient for your use case. The old runtime component will be available for at least the next three months by setting the environment variable `JAX_USE_PJRT_C_API_ON_TPU=false`. If you find you need to disable the new runtime for any reason, please let us know on the [JAX issue - tracker](https://github.com/google/jax/issues). + tracker](https://github.com/jax-ml/jax/issues). * Changes * The minimum jaxlib version has been bumped from 0.4.6 to 0.4.7. @@ -1026,7 +1143,7 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. StableHLO module for the entire JAX function instead of lowering each JAX primitive to a TensorFlow op. This simplifies the internals and increases the confidence that what you serialize matches the JAX native semantics. - See [documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + See [documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). As part of this change the config flag `--jax2tf_default_experimental_native_lowering` has been renamed to `--jax2tf_native_serialization`. * JAX now depends on `ml_dtypes`, which contains definitions of NumPy types @@ -1303,7 +1420,7 @@ Changes: ## jaxlib 0.3.22 (Oct 11, 2022) ## jax 0.3.21 (Sep 30, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.20...jax-v0.3.21). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.20...jax-v0.3.21). * Changes * The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}`#12582`), so program execution can continue @@ -1317,18 +1434,18 @@ Changes: * Fix incorrect `pip` url in `setup.py` comment ({jax-issue}`#12528`). ## jaxlib 0.3.20 (Sep 28, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.15...jaxlib-v0.3.20). * Bug fixes * Fixes support for limiting the visible CUDA devices via `jax_cuda_visible_devices` in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU ({jax-issue}`#12533`). ## jax 0.3.19 (Sep 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.18...jax-v0.3.19). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.18...jax-v0.3.19). * Fixes required jaxlib version. ## jax 0.3.18 (Sep 26, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.17...jax-v0.3.18). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.17...jax-v0.3.18). * Changes * Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}`#7733`) is stable and public. See [the @@ -1346,7 +1463,7 @@ Changes: would have been provided. 
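One entry above notes that `jax.experimental.host_callback` is no longer supported on Cloud TPU with the new runtime component and points users at the `jax.debug` APIs. Below is a small, hedged sketch of that kind of replacement; the format string and the host callback are illustrative placeholders rather than anything prescribed by the notes.

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Print an intermediate value from inside a jitted function.
    jax.debug.print("sum of x: {}", jnp.sum(x))
    # Or hand a traced value to an arbitrary Python callable on the host.
    jax.debug.callback(lambda v: print("host received:", v), x)
    return x * 2

f(jnp.arange(3.0))
```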
## jax 0.3.17 (Aug 31, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...jax-v0.3.17). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.16...jax-v0.3.17). * Bugs * Fix corner case issue in gradient of `lax.pow` with an exponent of zero ({jax-issue}`12041`) @@ -1362,7 +1479,7 @@ Changes: * `DeviceArray.to_py()` has been deprecated. Use `np.asarray(x)` instead. ## jax 0.3.16 -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.15...main). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.15...main). * Breaking changes * Support for NumPy 1.19 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). @@ -1386,7 +1503,7 @@ Changes: deprecated; see [JEP 11830](https://jax.readthedocs.io/en/latest/jep/11830-new-remat-checkpoint.html). ## jax 0.3.15 (July 22, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.14...jax-v0.3.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.14...jax-v0.3.15). * Changes * `JaxTestCase` and `JaxTestLoader` have been removed from `jax.test_util`. These classes have been deprecated since v0.3.1 ({jax-issue}`#11248`). @@ -1407,10 +1524,10 @@ Changes: following a similar deprecation in {func}`scipy.linalg.solve`. ## jaxlib 0.3.15 (July 22, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.14...jaxlib-v0.3.15). ## jax 0.3.14 (June 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.13...jax-v0.3.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.13...jax-v0.3.14). * Breaking changes * {func}`jax.experimental.compilation_cache.initialize_cache` does not support `max_cache_size_ bytes` anymore and will not get that as an input. @@ -1463,22 +1580,22 @@ Changes: coefficients have leading zeros ({jax-issue}`#11215`). ## jaxlib 0.3.14 (June 27, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.10...jaxlib-v0.3.14). * x86-64 Mac wheels now require Mac OS 10.14 (Mojave) or newer. Mac OS 10.14 was released in 2018, so this should not be a very onerous requirement. * The bundled version of NCCL was updated to 2.12.12, fixing some deadlocks. * The Python flatbuffers package is no longer a dependency of jaxlib. ## jax 0.3.13 (May 16, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.12...jax-v0.3.13). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.12...jax-v0.3.13). ## jax 0.3.12 (May 15, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.11...jax-v0.3.12). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.11...jax-v0.3.12). * Changes - * Fixes [#10717](https://github.com/google/jax/issues/10717). + * Fixes [#10717](https://github.com/jax-ml/jax/issues/10717). ## jax 0.3.11 (May 15, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.10...jax-v0.3.11). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.10...jax-v0.3.11). * Changes * {func}`jax.lax.eigh` now accepts an optional `sort_eigenvalues` argument that allows users to opt out of eigenvalue sorting on TPU. @@ -1492,22 +1609,22 @@ Changes: scipy API, is deprecated. Use {func}`jax.scipy.linalg.polar` instead. 
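The jax 0.3.14 note above says {func}`jax.experimental.compilation_cache.initialize_cache` no longer accepts `max_cache_size_bytes`. A minimal sketch of the call as described there follows; it reflects the API of that era (newer releases configure the cache directory via `jax.config` instead), and the path is an arbitrary example.

```python
# Persistent compilation cache setup per the jax 0.3.14-era API described above.
from jax.experimental.compilation_cache import compilation_cache as cc

cc.initialize_cache("/tmp/jax_compilation_cache")  # path only; no size argument
```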
## jax 0.3.10 (May 3, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.9...jax-v0.3.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.9...jax-v0.3.10). ## jaxlib 0.3.10 (May 3, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jaxlib-v0.3.7...jaxlib-v0.3.10). * Changes * [TF commit](https://github.com/tensorflow/tensorflow/commit/207d50d253e11c3a3430a700af478a1d524a779a) fixes an issue in the MHLO canonicalizer that caused constant folding to take a long time or crash for certain programs. ## jax 0.3.9 (May 2, 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.8...jax-v0.3.9). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.8...jax-v0.3.9). * Changes * Added support for fully asynchronous checkpointing for GlobalDeviceArray. ## jax 0.3.8 (April 29 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.7...jax-v0.3.8). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.3.7...jax-v0.3.8). * Changes * {func}`jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver. * {func}`jax.numpy.linalg.cond` on TPUs now accepts complex input. @@ -1566,7 +1683,7 @@ Changes: ## jax 0.3.7 (April 15, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.6...jax-v0.3.7). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.6...jax-v0.3.7). * Changes: * Fixed a performance problem if the indices passed to {func}`jax.numpy.take_along_axis` were broadcasted ({jax-issue}`#10281`). @@ -1584,17 +1701,17 @@ Changes: ## jax 0.3.6 (April 12, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.5...jax-v0.3.6). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.5...jax-v0.3.6). * Changes: * Upgraded libtpu wheel to a version that fixes a hang when initializing a TPU - pod. Fixes [#10218](https://github.com/google/jax/issues/10218). + pod. Fixes [#10218](https://github.com/jax-ml/jax/issues/10218). * Deprecations: * {mod}`jax.experimental.loops` is being deprecated. See {jax-issue}`#10278` for an alternative API. ## jax 0.3.5 (April 7, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.4...jax-v0.3.5). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.4...jax-v0.3.5). * Changes: * added {func}`jax.random.loggamma` & improved behavior of {func}`jax.random.beta` and {func}`jax.random.dirichlet` for small parameter values ({jax-issue}`#9906`). @@ -1617,17 +1734,17 @@ Changes: ## jax 0.3.4 (March 18, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.3...jax-v0.3.4). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.3...jax-v0.3.4). ## jax 0.3.3 (March 17, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.2...jax-v0.3.3). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.2...jax-v0.3.3). ## jax 0.3.2 (March 16, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.1...jax-v0.3.2). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.1...jax-v0.3.2). * Changes: * The functions `jax.ops.index_update`, `jax.ops.index_add`, which were deprecated in 0.2.22, have been removed. Please use @@ -1651,7 +1768,7 @@ Changes: ## jax 0.3.1 (Feb 18, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.3.0...jax-v0.3.1). + commits](https://github.com/jax-ml/jax/compare/jax-v0.3.0...jax-v0.3.1). 
* Changes: * `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. @@ -1674,7 +1791,7 @@ Changes: ## jax 0.3.0 (Feb 10, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.28...jax-v0.3.0). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.28...jax-v0.3.0). * Changes * jax version has been bumped to 0.3.0. Please see the [design doc](https://jax.readthedocs.io/en/latest/design_notes/jax_versioning.html) @@ -1688,7 +1805,7 @@ Changes: ## jax 0.2.28 (Feb 1, 2022) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.27...jax-v0.2.28). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.27...jax-v0.2.28). * `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no `dialect=` is passed. * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR @@ -1713,7 +1830,7 @@ Changes: * The JAX jit cache requires two static arguments to have identical types for a cache hit (#9311). ## jax 0.2.27 (Jan 18 2022) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.26...jax-v0.2.27). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.26...jax-v0.2.27). * Breaking changes: * Support for NumPy 1.18 has been dropped, per the @@ -1758,7 +1875,7 @@ Changes: ## jax 0.2.26 (Dec 8, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.25...jax-v0.2.26). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.25...jax-v0.2.26). * Bug fixes: * Out-of-bounds indices to `jax.ops.segment_sum` will now be handled with @@ -1775,7 +1892,7 @@ Changes: ## jax 0.2.25 (Nov 10, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.24...jax-v0.2.25). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.24...jax-v0.2.25). * New features: * (Experimental) `jax.distributed.initialize` exposes multi-host GPU backend. @@ -1789,7 +1906,7 @@ Changes: ## jax 0.2.24 (Oct 19, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.22...jax-v0.2.24). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.22...jax-v0.2.24). * New features: * `jax.random.choice` and `jax.random.permutation` now support @@ -1823,7 +1940,7 @@ Changes: ## jax 0.2.22 (Oct 12, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.21...jax-v0.2.22). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.21...jax-v0.2.22). * Breaking Changes * Static arguments to `jax.pmap` must now be hashable. @@ -1858,13 +1975,13 @@ Changes: * Support for CUDA 10.2 and CUDA 10.1 has been dropped. Jaxlib now supports CUDA 11.1+. * Bug fixes: - * Fixes https://github.com/google/jax/issues/7461, which caused wrong + * Fixes https://github.com/jax-ml/jax/issues/7461, which caused wrong outputs on all platforms due to incorrect buffer aliasing inside the XLA compiler. ## jax 0.2.21 (Sept 23, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.20...jax-v0.2.21). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.20...jax-v0.2.21). * Breaking Changes * `jax.api` has been removed. Functions that were available as `jax.api.*` were aliases for functions in `jax.*`; please use the functions in @@ -1892,7 +2009,7 @@ Changes: ## jax 0.2.20 (Sept 2, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.19...jax-v0.2.20). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.19...jax-v0.2.20). 
* Breaking Changes * `jnp.poly*` functions now require array-like inputs ({jax-issue}`#7732`) * `jnp.unique` and other set-like operations now require array-like inputs @@ -1905,7 +2022,7 @@ Changes: ## jax 0.2.19 (Aug 12, 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.18...jax-v0.2.19). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.18...jax-v0.2.19). * Breaking changes: * Support for NumPy 1.17 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). @@ -1942,7 +2059,7 @@ Changes: called in sequence. ## jax 0.2.18 (July 21 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.17...jax-v0.2.18). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.17...jax-v0.2.18). * Breaking changes: * Support for Python 3.6 has been dropped, per the @@ -1965,7 +2082,7 @@ Changes: * Fix bugs in TFRT CPU backend that results in incorrect results. ## jax 0.2.17 (July 9 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.16...jax-v0.2.17). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.16...jax-v0.2.17). * Bug fixes: * Default to the older "stream_executor" CPU runtime for jaxlib <= 0.1.68 to work around #7229, which caused wrong outputs on CPU due to a concurrency @@ -1982,12 +2099,12 @@ Changes: ## jax 0.2.16 (June 23 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.15...jax-v0.2.16). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.15...jax-v0.2.16). ## jax 0.2.15 (June 23 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.14...jax-v0.2.15). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.14...jax-v0.2.15). * New features: - * [#7042](https://github.com/google/jax/pull/7042) Turned on TFRT CPU backend + * [#7042](https://github.com/jax-ml/jax/pull/7042) Turned on TFRT CPU backend with significant dispatch performance improvements on CPU. * The {func}`jax2tf.convert` supports inequalities and min/max for booleans ({jax-issue}`#6956`). @@ -2007,7 +2124,7 @@ Changes: CPU. ## jax 0.2.14 (June 10 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.13...jax-v0.2.14). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.13...jax-v0.2.14). * New features: * The {func}`jax2tf.convert` now has support for `pjit` and `sharded_jit`. * A new configuration option JAX_TRACEBACK_FILTERING controls how JAX filters @@ -2065,7 +2182,7 @@ Changes: {func}`jit` transformed functions. ## jax 0.2.13 (May 3 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.12...jax-v0.2.13). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.12...jax-v0.2.13). * New features: * When combined with jaxlib 0.1.66, {func}`jax.jit` now supports static keyword arguments. A new `static_argnames` option has been added to specify @@ -2109,7 +2226,7 @@ Changes: ## jaxlib 0.1.65 (April 7 2021) ## jax 0.2.12 (April 1 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.11...v0.2.12). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.11...v0.2.12). 
* New features * New profiling APIs: {func}`jax.profiler.start_trace`, {func}`jax.profiler.stop_trace`, and {func}`jax.profiler.trace` @@ -2122,7 +2239,7 @@ Changes: * `TraceContext` --> {func}`~jax.profiler.TraceAnnotation` * `StepTraceContext` --> {func}`~jax.profiler.StepTraceAnnotation` * `trace_function` --> {func}`~jax.profiler.annotate_function` - * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md) + * Omnistaging can no longer be disabled. See [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md) for more information. * Python integers larger than the maximum `int64` value will now lead to an overflow in all cases, rather than being silently converted to `uint64` in some cases ({jax-issue}`#6047`). @@ -2136,23 +2253,23 @@ Changes: ## jax 0.2.11 (March 23 2021) * [GitHub - commits](https://github.com/google/jax/compare/jax-v0.2.10...jax-v0.2.11). + commits](https://github.com/jax-ml/jax/compare/jax-v0.2.10...jax-v0.2.11). * New features: - * [#6112](https://github.com/google/jax/pull/6112) added context managers: + * [#6112](https://github.com/jax-ml/jax/pull/6112) added context managers: `jax.enable_checks`, `jax.check_tracer_leaks`, `jax.debug_nans`, `jax.debug_infs`, `jax.log_compiles`. - * [#6085](https://github.com/google/jax/pull/6085) added `jnp.delete` + * [#6085](https://github.com/jax-ml/jax/pull/6085) added `jnp.delete` * Bug fixes: - * [#6136](https://github.com/google/jax/pull/6136) generalized + * [#6136](https://github.com/jax-ml/jax/pull/6136) generalized `jax.flatten_util.ravel_pytree` to handle integer dtypes. - * [#6129](https://github.com/google/jax/issues/6129) fixed a bug with handling + * [#6129](https://github.com/jax-ml/jax/issues/6129) fixed a bug with handling some constants like `enum.IntEnums` - * [#6145](https://github.com/google/jax/pull/6145) fixed batching issues with + * [#6145](https://github.com/jax-ml/jax/pull/6145) fixed batching issues with incomplete beta functions - * [#6014](https://github.com/google/jax/pull/6014) fixed H2D transfers during + * [#6014](https://github.com/jax-ml/jax/pull/6014) fixed H2D transfers during tracing - * [#6165](https://github.com/google/jax/pull/6165) avoids OverflowErrors when + * [#6165](https://github.com/jax-ml/jax/pull/6165) avoids OverflowErrors when converting some large Python integers to floats * Breaking changes: * The minimum jaxlib version is now 0.1.62. @@ -2164,13 +2281,13 @@ Changes: ## jax 0.2.10 (March 5 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.9...jax-v0.2.10). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.9...jax-v0.2.10). * New features: * {func}`jax.scipy.stats.chi2` is now available as a distribution with logpdf and pdf methods. * {func}`jax.scipy.stats.betabinom` is now available as a distribution with logpmf and pmf methods. * Added {func}`jax.experimental.jax2tf.call_tf` to call TensorFlow functions from JAX ({jax-issue}`#5627`) - and [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)). + and [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax)). * Extended the batching rule for `lax.pad` to support batching of the padding values. 
* Bug fixes: * {func}`jax.numpy.take` properly handles negative indices ({jax-issue}`#5768`) @@ -2214,7 +2331,7 @@ Changes: ## jax 0.2.9 (January 26 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.8...jax-v0.2.9). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.8...jax-v0.2.9). * New features: * Extend the {mod}`jax.experimental.loops` module with support for pytrees. Improved error checking and error messages. @@ -2230,7 +2347,7 @@ Changes: ## jax 0.2.8 (January 12 2021) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.7...jax-v0.2.8). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.7...jax-v0.2.8). * New features: * Add {func}`jax.closure_convert` for use with higher-order custom derivative functions. ({jax-issue}`#5244`) @@ -2262,7 +2379,7 @@ Changes: ## jax 0.2.7 (Dec 4 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.6...jax-v0.2.7). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.6...jax-v0.2.7). * New features: * Add `jax.device_put_replicated` * Add multi-host support to `jax.experimental.sharded_jit` @@ -2282,14 +2399,14 @@ Changes: ## jax 0.2.6 (Nov 18 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.5...jax-v0.2.6). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.5...jax-v0.2.6). * New Features: * Add support for shape-polymorphic tracing for the jax.experimental.jax2tf converter. - See [README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). + See [README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). * Breaking change cleanup * Raise an error on non-hashable static arguments for jax.jit and - xla_computation. See [cb48f42](https://github.com/google/jax/commit/cb48f42). + xla_computation. See [cb48f42](https://github.com/jax-ml/jax/commit/cb48f42). * Improve consistency of type promotion behavior ({jax-issue}`#4744`): * Adding a complex Python scalar to a JAX floating point number respects the precision of the JAX float. For example, `jnp.float32(1) + 1j` now returns `complex64`, where previously @@ -2341,15 +2458,15 @@ Changes: ## jax 0.2.5 (October 27 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.4...jax-v0.2.5). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.4...jax-v0.2.5). * Improvements: * Ensure that `check_jaxpr` does not perform FLOPS. See {jax-issue}`#4650`. * Expanded the set of JAX primitives converted by jax2tf. - See [primitives_with_limited_support.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md). + See [primitives_with_limited_support.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/primitives_with_limited_support.md). ## jax 0.2.4 (October 19 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.3...jax-v0.2.4). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.3...jax-v0.2.4). * Improvements: * Add support for `remat` to jax.experimental.host_callback. See {jax-issue}`#4608`. * Deprecations @@ -2361,17 +2478,17 @@ Changes: ## jax 0.2.3 (October 14 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.2...jax-v0.2.3). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.2...jax-v0.2.3). 
* The reason for another release so soon is we need to temporarily roll back a new jit fastpath while we look into a performance degradation ## jax 0.2.2 (October 13 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.1...jax-v0.2.2). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.1...jax-v0.2.2). ## jax 0.2.1 (October 6 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.0...jax-v0.2.1). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.2.0...jax-v0.2.1). * Improvements: * As a benefit of omnistaging, the host_callback functions are executed (in program order) even if the result of the {py:func}`jax.experimental.host_callback.id_print`/ @@ -2379,10 +2496,10 @@ Changes: ## jax (0.2.0) (September 23 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.77...jax-v0.2.0). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.77...jax-v0.2.0). * Improvements: * Omnistaging on by default. See {jax-issue}`#3370` and - [omnistaging](https://github.com/google/jax/blob/main/docs/design_notes/omnistaging.md) + [omnistaging](https://github.com/jax-ml/jax/blob/main/docs/design_notes/omnistaging.md) ## jax (0.1.77) (September 15 2020) @@ -2396,11 +2513,11 @@ Changes: ## jax 0.1.76 (September 8, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.75...jax-v0.1.76). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.75...jax-v0.1.76). ## jax 0.1.75 (July 30, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.74...jax-v0.1.75). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.74...jax-v0.1.75). * Bug Fixes: * make jnp.abs() work for unsigned inputs (#3914) * Improvements: @@ -2408,7 +2525,7 @@ Changes: ## jax 0.1.74 (July 29, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.73...jax-v0.1.74). * New Features: * BFGS (#3101) * TPU support for half-precision arithmetic (#3878) @@ -2425,7 +2542,7 @@ Changes: ## jax 0.1.73 (July 22, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.72...jax-v0.1.73). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.72...jax-v0.1.73). * The minimum jaxlib version is now 0.1.51. * New Features: * jax.image.resize. (#3703) @@ -2463,14 +2580,14 @@ Changes: ## jax 0.1.72 (June 28, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.71...jax-v0.1.72). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.71...jax-v0.1.72). * Bug fixes: * Fix an odeint bug introduced in the previous release, see {jax-issue}`#3587`. ## jax 0.1.71 (June 25, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.70...jax-v0.1.71). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.70...jax-v0.1.71). * The minimum jaxlib version is now 0.1.48. * Bug fixes: * Allow `jax.experimental.ode.odeint` dynamics functions to close over @@ -2506,7 +2623,7 @@ Changes: ## jax 0.1.70 (June 8, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.69...jax-v0.1.70). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.69...jax-v0.1.70). 
* New features: * `lax.switch` introduces indexed conditionals with multiple branches, together with a generalization of the `cond` @@ -2515,11 +2632,11 @@ Changes: ## jax 0.1.69 (June 3, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.68...jax-v0.1.69). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.68...jax-v0.1.69). ## jax 0.1.68 (May 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.67...jax-v0.1.68). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.67...jax-v0.1.68). * New features: * {func}`lax.cond` supports a single-operand form, taken as the argument to both branches @@ -2530,7 +2647,7 @@ Changes: ## jax 0.1.67 (May 12, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.66...jax-v0.1.67). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.66...jax-v0.1.67). * New features: * Support for reduction over subsets of a pmapped axis using `axis_index_groups` {jax-issue}`#2382`. @@ -2548,7 +2665,7 @@ Changes: ## jax 0.1.66 (May 5, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.65...jax-v0.1.66). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.65...jax-v0.1.66). * New features: * Support for `in_axes=None` on {func}`pmap` {jax-issue}`#2896`. @@ -2561,7 +2678,7 @@ Changes: ## jax 0.1.65 (April 30, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.64...jax-v0.1.65). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.64...jax-v0.1.65). * New features: * Differentiation of determinants of singular matrices {jax-issue}`#2809`. @@ -2579,7 +2696,7 @@ Changes: ## jax 0.1.64 (April 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.63...jax-v0.1.64). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.63...jax-v0.1.64). * New features: * Add syntactic sugar for functional indexed updates {jax-issue}`#2684`. @@ -2606,7 +2723,7 @@ Changes: ## jax 0.1.63 (April 12, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.62...jax-v0.1.63). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.62...jax-v0.1.63). * Added `jax.custom_jvp` and `jax.custom_vjp` from {jax-issue}`#2026`, see the [tutorial notebook](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html). Deprecated `jax.custom_transforms` and removed it from the docs (though it still works). * Add `scipy.sparse.linalg.cg` {jax-issue}`#2566`. * Changed how Tracers are printed to show more useful information for debugging {jax-issue}`#2591`. @@ -2627,7 +2744,7 @@ Changes: ## jax 0.1.62 (March 21, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.61...jax-v0.1.62). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.61...jax-v0.1.62). * JAX has dropped support for Python 3.5. Please upgrade to Python 3.6 or newer. * Removed the internal function `lax._safe_mul`, which implemented the convention `0. * nan == 0.`. This change means some programs when @@ -2645,13 +2762,13 @@ Changes: ## jax 0.1.61 (March 17, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.60...jax-v0.1.61). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.60...jax-v0.1.61). * Fixes Python 3.5 support. This will be the last JAX or jaxlib release that supports Python 3.5. ## jax 0.1.60 (March 17, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.59...jax-v0.1.60). 
+* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.59...jax-v0.1.60). * New features: * {py:func}`jax.pmap` has `static_broadcast_argnums` argument which allows the user to specify arguments that should be treated as compile-time @@ -2677,7 +2794,7 @@ Changes: ## jax 0.1.59 (February 11, 2020) -* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.58...jax-v0.1.59). +* [GitHub commits](https://github.com/jax-ml/jax/compare/jax-v0.1.58...jax-v0.1.59). * Breaking changes * The minimum jaxlib version is now 0.1.38. @@ -2709,7 +2826,7 @@ Changes: ## jax 0.1.58 (January 28, 2020) -* [GitHub commits](https://github.com/google/jax/compare/46014da21...jax-v0.1.58). +* [GitHub commits](https://github.com/jax-ml/jax/compare/46014da21...jax-v0.1.58). * Breaking changes * JAX has dropped Python 2 support, because Python 2 reached its end of life on diff --git a/CITATION.bib b/CITATION.bib index 88049a1469d1..777058b5aaa9 100644 --- a/CITATION.bib +++ b/CITATION.bib @@ -1,7 +1,7 @@ @software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, - url = {http://github.com/google/jax}, + url = {http://github.com/jax-ml/jax}, version = {0.3.13}, year = {2018}, } diff --git a/README.md b/README.md index b19d7b9ff128..d67bdac82414 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@
-logo
+logo
-# JAX: Autograd and XLA +# Transformable numerical computing at scale -![Continuous integration](https://github.com/google/jax/actions/workflows/ci-build.yaml/badge.svg) +![Continuous integration](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml/badge.svg) ![PyPI version](https://img.shields.io/pypi/v/jax) [**Quickstart**](#quickstart-colab-in-the-cloud) @@ -50,7 +50,7 @@ parallel programming of multiple accelerators, with more to come. This is a research project, not an official Google product. Expect bugs and [sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). Please help by trying it out, [reporting -bugs](https://github.com/google/jax/issues), and letting us know what you +bugs](https://github.com/jax-ml/jax/issues), and letting us know what you think! ```python @@ -84,16 +84,16 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: - [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) **JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU -Colabs](https://github.com/google/jax/tree/main/cloud_tpu_colabs). +Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs). For a deeper dive into JAX: - [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) - [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) - See the [full list of -notebooks](https://github.com/google/jax/tree/main/docs/notebooks). +notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks). ## Transformations @@ -273,7 +273,7 @@ from jax import random, pmap import jax.numpy as jnp # Create 8 random 5000 x 6000 matrices, one per GPU -keys = random.split(random.PRNGKey(0), 8) +keys = random.split(random.key(0), 8) mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys) # Run a local matmul on each device in parallel (no data transfer) @@ -300,7 +300,7 @@ print(normalize(jnp.arange(4.))) # prints [0. 0.16666667 0.33333334 0.5 ] ``` -You can even [nest `pmap` functions](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more +You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more sophisticated communication patterns. It all composes, so you're free to differentiate through parallel computations: @@ -333,9 +333,9 @@ When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the backward pass of the computation is parallelized just like the forward pass. 
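The README example above (and the benchmark hunks further down) migrate from `jax.random.PRNGKey` to the typed-key constructor `jax.random.key`. A minimal sketch of the migrated pattern, assuming a recent JAX with typed PRNG keys and mirroring the pmap snippet above (array shapes and the lambda name are illustrative):

```python
import jax

# New-style typed key; replaces jax.random.PRNGKey(0) in the examples above.
key = jax.random.key(0)

# split() still derives independent per-device keys for pmap-style use.
per_device_keys = jax.random.split(key, jax.local_device_count())
mats = jax.pmap(lambda k: jax.random.normal(k, (128, 128)))(per_device_keys)
print(mats.shape)
```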
See the [SPMD -Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) and the [SPMD MNIST classifier from scratch -example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) +example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) for more. ## Current gotchas @@ -349,7 +349,7 @@ Some standouts: 1. [In-place mutating updates of arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically. 1. [Random numbers are - different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/google/jax/blob/main/docs/jep/263-prng.md). + different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#random-numbers), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md). 1. If you're looking for [convolution operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html), they're in the `jax.lax` package. @@ -437,7 +437,7 @@ To cite this repository: @software{jax2018github, author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, - url = {http://github.com/google/jax}, + url = {http://github.com/jax-ml/jax}, version = {0.3.13}, year = {2018}, } diff --git a/WORKSPACE b/WORKSPACE index 57e84b12ddf1..ed284acadf81 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -37,9 +37,10 @@ install_deps() load("@xla//third_party/py:python_repo.bzl", "custom_python_interpreter") custom_python_interpreter( name = "python_dev", - urls = ["https://www.python.org/ftp/python/3.13.0/Python-{version}.tgz"], - strip_prefix = "Python-{version}", - version = "3.13.0a6", + urls = ["https://www.python.org/ftp/python/{version}/Python-{version_variant}.tgz"], + strip_prefix = "Python-{version_variant}", + version = "3.13.0", + version_variant = "3.13.0rc2", ) load("@xla//:workspace4.bzl", "xla_workspace4") @@ -59,3 +60,50 @@ xla_workspace0() load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo") flatbuffers() + +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_json_init_repository.bzl", + "cuda_json_init_repository", +) + +cuda_json_init_repository() + +load( + "@cuda_redist_json//:distributions.bzl", + "CUDA_REDISTRIBUTIONS", + "CUDNN_REDISTRIBUTIONS", +) +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_redist_init_repositories.bzl", + "cuda_redist_init_repositories", + "cudnn_redist_init_repository", +) + +cuda_redist_init_repositories( + cuda_redistributions = CUDA_REDISTRIBUTIONS, +) + +cudnn_redist_init_repository( + cudnn_redistributions = CUDNN_REDISTRIBUTIONS, +) + +load( + "@tsl//third_party/gpus/cuda/hermetic:cuda_configure.bzl", + "cuda_configure", +) + +cuda_configure(name = "local_config_cuda") + +load( + "@tsl//third_party/nccl/hermetic:nccl_redist_init_repository.bzl", + "nccl_redist_init_repository", +) + +nccl_redist_init_repository() + 
+load( + "@tsl//third_party/nccl/hermetic:nccl_configure.bzl", + "nccl_configure", +) + +nccl_configure(name = "local_config_nccl") diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py index c68dab85dc8e..df9528ada9ff 100644 --- a/benchmarks/api_benchmark.py +++ b/benchmarks/api_benchmark.py @@ -566,7 +566,7 @@ def bench_repeated_static_slicing(state): while state: jax.block_until_ready([x[i:i + 2] for i in range(0, 1000, 2)]) -def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False): +def pjit_simple_benchmark(state, num_devices, num_args, use_aot=False): spec = jax.sharding.PartitionSpec('x') mesh = create_mesh((num_devices,), ('x',), state) if mesh is None: @@ -601,8 +601,7 @@ def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit, use_aot=False): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) def pjit_simple_1_device(state): - pjit_simple_benchmark( - state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1)) + pjit_simple_benchmark(state, num_devices=1, num_args=state.range(0)) @google_benchmark.register @google_benchmark.option.arg_names(['num_args']) @@ -610,8 +609,7 @@ def pjit_simple_1_device(state): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) def pjit_simple_4_device(state): - pjit_simple_benchmark( - state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1)) + pjit_simple_benchmark(state, num_devices=4, num_args=state.range(0)) @google_benchmark.register @google_benchmark.option.arg_names(['num_args']) @@ -619,8 +617,7 @@ def pjit_simple_4_device(state): @google_benchmark.option.args([10]) @google_benchmark.option.args([100]) def pjit_simple_4000_device(state): - pjit_simple_benchmark( - state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1)) + pjit_simple_benchmark(state, num_devices=4000, num_args=state.range(0)) @google_benchmark.register @@ -633,7 +630,6 @@ def pjit_aot_1_device(state): state, num_devices=1, num_args=state.range(0), - cpp_jit=state.range(1), use_aot=True) @@ -647,7 +643,6 @@ def pjit_aot_4_device(state): state, num_devices=4, num_args=state.range(0), - cpp_jit=state.range(1), use_aot=True) @@ -661,7 +656,6 @@ def pjit_aot_4000_device(state): state, num_devices=4000, num_args=state.range(0), - cpp_jit=state.range(1), use_aot=True) @@ -697,6 +691,8 @@ def device_put_from_numpy_array(state): @google_benchmark.option.args([100]) @google_benchmark.option.args([1000]) def device_put_from_jax_array(state): + if len(jax.devices()) < 2: + state.skip_with_error('requires 2 devices') x = [np.array(1, np.int32)] * state.range(0) x = jax.block_until_ready(jax.device_put(x, device=jax.devices()[0])) d = jax.devices()[1] @@ -839,7 +835,7 @@ def f(x): out = out + y * x[0] return out - x = jax.random.normal(jax.random.PRNGKey(0), (2, 2)) + x = jax.random.normal(jax.random.key(0), (2, 2)) f(x).block_until_ready() # compile while state: f(x).block_until_ready() @@ -929,7 +925,7 @@ def jit_add_chain(state): def g(x, y): return lax.add(x, y) - x = jax.random.normal(jax.random.PRNGKey(0), (2, 2)) + x = jax.random.normal(jax.random.key(0), (2, 2)) while state: @jax.jit def f(x): diff --git a/benchmarks/mosaic/BUILD b/benchmarks/mosaic/BUILD index 027da12ce6d3..39c7aa5f3395 100644 --- a/benchmarks/mosaic/BUILD +++ b/benchmarks/mosaic/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -28,29 +28,15 @@ package( jax_generate_backend_suites() 
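The `api_benchmark.py` hunks above drop the unused `cpp_jit` parameter from `pjit_simple_benchmark` and add a device-count guard to `device_put_from_jax_array`. For orientation, a self-contained benchmark in the same style is sketched below; only `google_benchmark.register`, the `while state:` loop, `state.skip_with_error`, and `block_until_ready` come from the file above, and the rest (names, shapes) is illustrative:

```python
import google_benchmark
import jax
import jax.numpy as jnp

@google_benchmark.register
def jit_add_one(state):
  # Compile outside the timed loop, then measure steady-state dispatch.
  f = jax.jit(lambda x: x + 1)
  x = jnp.zeros((1024,), jnp.float32)
  f(x).block_until_ready()
  while state:
    f(x).block_until_ready()

@google_benchmark.register
def device_put_second_device(state):
  # Mirrors the new guard: skip instead of failing on single-device hosts.
  if len(jax.devices()) < 2:
    state.skip_with_error('requires 2 devices')
    return
  x = jnp.ones((1024,))
  while state:
    jax.block_until_ready(jax.device_put(x, jax.devices()[1]))

if __name__ == '__main__':
  google_benchmark.main()
```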
-DISABLED_BACKENDS = [ - "cpu", - "tpu", -] - -DISABLED_CONFIGS = [ - "gpu", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_x32", - "gpu_pjrt_c_api", -] - -jax_test( +jax_multiplatform_test( name = "matmul_bench", srcs = ["matmul_bench.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], tags = ["notap"], deps = [ + "//jax:mosaic_gpu", + "//jax/experimental/mosaic/gpu/examples:matmul", "//third_party/py/google_benchmark", - "//third_party/py/jax:mosaic_gpu", - "//third_party/py/jax/experimental/mosaic/gpu/examples:matmul", ] + py_deps("absl/testing") + py_deps("numpy"), ) diff --git a/benchmarks/sparse_benchmark.py b/benchmarks/sparse_benchmark.py index 65550b9cfee0..d6328881d5c6 100644 --- a/benchmarks/sparse_benchmark.py +++ b/benchmarks/sparse_benchmark.py @@ -109,7 +109,7 @@ def sparse_bcoo_todense_compile(state): def _sparse_bcoo_matvec(state, jit: bool = False, compile: bool = False): shape = (2000, 2000) nse = 10000 - key = jax.random.PRNGKey(1701) + key = jax.random.key(1701) mat = sparse.random_bcoo( key, nse=nse, diff --git a/build/build.py b/build/build.py index 7a418c9b4d78..de0d5a9817fb 100755 --- a/build/build.py +++ b/build/build.py @@ -218,7 +218,7 @@ def get_clang_path_or_exit(): return str(pathlib.Path(which_clang_output).resolve()) else: print( - "--use_clang set, but --clang_path is unset and clang cannot be found" + "--clang_path is unset and clang cannot be found" " on the PATH. Please pass --clang_path directly." ) sys.exit(-1) @@ -236,16 +236,14 @@ def get_clang_major_version(clang_path): return major_version - def write_bazelrc(*, remote_build, - cuda_toolkit_path, cudnn_install_path, cuda_version, cudnn_version, rocm_toolkit_path, cpu, cuda_compute_capabilities, rocm_amdgpu_targets, target_cpu_features, wheel_cpu, enable_mkl_dnn, use_clang, clang_path, - clang_major_version, enable_cuda, enable_nccl, enable_rocm, - python_version): - tf_cuda_paths = [] + clang_major_version, python_version, + enable_cuda, enable_nccl, enable_rocm, + use_cuda_nvcc): with open("../.jax_configure.bazelrc", "w") as f: if not remote_build: @@ -263,28 +261,6 @@ def write_bazelrc(*, remote_build, # https://github.com/openxla/xla/blob/c4277a076e249f5b97c8e45c8cb9d1f554089d76/.bazelrc#L505 f.write("build --copt=-Wno-gnu-offsetof-extensions\n") - if cuda_toolkit_path: - tf_cuda_paths.append(cuda_toolkit_path) - f.write("build --action_env CUDA_TOOLKIT_PATH=\"{cuda_toolkit_path}\"\n" - .format(cuda_toolkit_path=cuda_toolkit_path)) - if cudnn_install_path: - # see https://github.com/tensorflow/tensorflow/issues/51040 - if cudnn_install_path not in tf_cuda_paths: - tf_cuda_paths.append(cudnn_install_path) - f.write("build --action_env CUDNN_INSTALL_PATH=\"{cudnn_install_path}\"\n" - .format(cudnn_install_path=cudnn_install_path)) - if len(tf_cuda_paths): - f.write("build --action_env TF_CUDA_PATHS=\"{tf_cuda_paths}\"\n" - .format(tf_cuda_paths=",".join(tf_cuda_paths))) - if cuda_version: - f.write("build --action_env TF_CUDA_VERSION=\"{cuda_version}\"\n" - .format(cuda_version=cuda_version)) - if cudnn_version: - f.write("build --action_env TF_CUDNN_VERSION=\"{cudnn_version}\"\n" - .format(cudnn_version=cudnn_version)) - if cuda_compute_capabilities: - f.write( - f'build:cuda --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if rocm_toolkit_path: f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" .format(rocm_toolkit_path=rocm_toolkit_path)) @@ -308,11 +284,22 
@@ def write_bazelrc(*, remote_build, f.write("build --config=mkl_open_source_only\n") if enable_cuda: f.write("build --config=cuda\n") + if use_cuda_nvcc: + f.write("build --config=cuda_nvcc\n") + else: + f.write("build --config=cuda_clang\n") + f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") if not enable_nccl: f.write("build --config=nonccl\n") - if use_clang: - f.write("build --config=nvcc_clang\n") - f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n") + if cuda_version: + f.write("build --repo_env HERMETIC_CUDA_VERSION=\"{cuda_version}\"\n" + .format(cuda_version=cuda_version)) + if cudnn_version: + f.write("build --repo_env HERMETIC_CUDNN_VERSION=\"{cudnn_version}\"\n" + .format(cudnn_version=cudnn_version)) + if cuda_compute_capabilities: + f.write( + f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n') if enable_rocm: f.write("build --config=rocm\n") if not enable_nccl: @@ -364,6 +351,15 @@ def add_boolean_argument(parser, name, default=False, help_str=None): group.add_argument("--no" + name, dest=name, action="store_false") +def _get_editable_output_paths(output_path): + """Returns the paths to the editable wheels.""" + return ( + os.path.join(output_path, "jaxlib"), + os.path.join(output_path, "jax_gpu_pjrt"), + os.path.join(output_path, "jax_gpu_plugin"), + ) + + def main(): cwd = os.getcwd() parser = argparse.ArgumentParser( @@ -397,16 +393,16 @@ def main(): add_boolean_argument( parser, "use_clang", + default = "true", help_str=( - "Should we build using clang as the host compiler? Requires " - "clang to be findable via the PATH, or a path to be given via " - "--clang_path." + "DEPRECATED: This flag is redundant because clang is " + "always used as default compiler." ), ) parser.add_argument( "--clang_path", help=( - "Path to clang binary to use if --use_clang is set. The default is " + "Path to clang binary to use. The default is " "to find clang via the PATH." ), ) @@ -419,7 +415,18 @@ def main(): add_boolean_argument( parser, "enable_cuda", - help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN.") + help_str="Should we build with CUDA enabled? Requires CUDA and CuDNN." + ) + add_boolean_argument( + parser, + "use_cuda_nvcc", + default=True, + help_str=( + "Should we build CUDA code using NVCC compiler driver? The default value " + "is true. If --nouse_cuda_nvcc flag is used then CUDA code is built " + "by clang compiler." + ), + ) add_boolean_argument( parser, "build_gpu_plugin", @@ -449,7 +456,7 @@ def main(): ) parser.add_argument( "--gpu_plugin_cuda_version", - choices=["11", "12"], + choices=["12"], default="12", help="Which CUDA major version the gpu plugin is for.") parser.add_argument( @@ -472,22 +479,14 @@ def main(): "remote_build", default=False, help_str="Should we build with RBE (Remote Build Environment)?") - parser.add_argument( - "--cuda_path", - default=None, - help="Path to the CUDA toolkit.") - parser.add_argument( - "--cudnn_path", - default=None, - help="Path to CUDNN libraries.") parser.add_argument( "--cuda_version", default=None, - help="CUDA toolkit version, e.g., 11.1") + help="CUDA toolkit version, e.g., 12.3.2") parser.add_argument( "--cudnn_version", default=None, - help="CUDNN version, e.g., 8") + help="CUDNN version, e.g., 8.9.7.29") # Caution: if changing the default list of CUDA capabilities, you should also # update the list in .bazelrc, which is used for wheel builds. 
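With the hermetic CUDA switch above, `write_bazelrc` no longer emits `CUDA_TOOLKIT_PATH`/`CUDNN_INSTALL_PATH` action envs; it chooses the `cuda_nvcc` or `cuda_clang` config and pins toolchain versions via repo envs. A condensed sketch of that mapping is below; the function name and default values are illustrative, while the emitted flag strings are taken from the hunk above:

```python
def write_cuda_bazelrc(f, *, use_cuda_nvcc=True,
                       clang_path="/usr/lib/llvm-18/bin/clang",
                       cuda_version="12.3.2", cudnn_version="9.1.1",
                       cuda_compute_capabilities="sm_80,compute_90"):
  f.write("build --config=cuda\n")
  # NVCC compiles device code by default; --nouse_cuda_nvcc selects clang.
  f.write("build --config=cuda_nvcc\n" if use_cuda_nvcc
          else "build --config=cuda_clang\n")
  f.write(f"build --action_env=CLANG_CUDA_COMPILER_PATH={clang_path}\n")
  # Hermetic CUDA/CUDNN: versions are pinned instead of pointing at a local
  # toolkit or CUDNN install path.
  f.write(f'build --repo_env HERMETIC_CUDA_VERSION="{cuda_version}"\n')
  f.write(f'build --repo_env HERMETIC_CUDNN_VERSION="{cudnn_version}"\n')
  f.write('build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES='
          f'"{cuda_compute_capabilities}"\n')

with open(".jax_configure.bazelrc", "w") as f:
  write_cuda_bazelrc(f)
```

Per the new `--use_cuda_nvcc` help text, passing `--nouse_cuda_nvcc` to `build/build.py` takes the clang branch of this mapping.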
parser.add_argument( @@ -553,12 +552,6 @@ def main(): if args.verbose: logger.setLevel(logging.DEBUG) - if is_windows() and args.enable_cuda: - if args.cuda_version is None: - parser.error("--cuda_version is needed for Windows CUDA build.") - if args.cudnn_version is None: - parser.error("--cudnn_version is needed for Windows CUDA build.") - if args.enable_cuda and args.enable_rocm: parser.error("--enable_cuda and --enable_rocm cannot be enabled at the same time.") @@ -606,15 +599,9 @@ def main(): print(f"Target CPU: {wheel_cpu}") print(f"Target CPU features: {args.target_cpu_features}") - cuda_toolkit_path = args.cuda_path - cudnn_install_path = args.cudnn_path rocm_toolkit_path = args.rocm_path print("CUDA enabled: {}".format("yes" if args.enable_cuda else "no")) if args.enable_cuda: - if cuda_toolkit_path: - print(f"CUDA toolkit path: {cuda_toolkit_path}") - if cudnn_install_path: - print(f"CUDNN library path: {cudnn_install_path}") if args.cuda_compute_capabilities is not None: print(f"CUDA compute capabilities: {args.cuda_compute_capabilities}") if args.cuda_version: @@ -631,8 +618,6 @@ def main(): write_bazelrc( remote_build=args.remote_build, - cuda_toolkit_path=cuda_toolkit_path, - cudnn_install_path=cudnn_install_path, cuda_version=args.cuda_version, cudnn_version=args.cudnn_version, rocm_toolkit_path=rocm_toolkit_path, @@ -645,26 +630,24 @@ def main(): use_clang=args.use_clang, clang_path=clang_path, clang_major_version=clang_major_version, + python_version=python_version, enable_cuda=args.enable_cuda, enable_nccl=args.enable_nccl, enable_rocm=args.enable_rocm, - python_version=python_version, + use_cuda_nvcc=args.use_cuda_nvcc, ) - if args.requirements_update: + if args.requirements_update or args.requirements_nightly_update: + if args.requirements_update: + task = "//build:requirements.update" + else: # args.requirements_nightly_update + task = "//build:requirements_nightly.update" update_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", "//build:requirements.update"]) + ["run", "--verbose_failures=true", task, *args.bazel_options]) print(" ".join(update_command)) shell(update_command) return - if args.requirements_nightly_update: - update_nightly_command = ([bazel_path] + args.bazel_startup_options + - ["run", "--verbose_failures=true", "//build:requirements_nightly.update"]) - print(" ".join(update_nightly_command)) - shell(update_nightly_command) - return - if args.configure_only: return @@ -678,11 +661,20 @@ def main(): *args.bazel_options, ) + if args.build_gpu_plugin and args.editable: + output_path_jaxlib, output_path_jax_pjrt, output_path_jax_kernel = ( + _get_editable_output_paths(output_path) + ) + else: + output_path_jaxlib = output_path + output_path_jax_pjrt = output_path + output_path_jax_kernel = output_path + if args.build_gpu_kernel_plugin == "" and not args.build_gpu_pjrt_plugin: build_cpu_wheel_command = [ *command_base, "//jaxlib/tools:build_wheel", "--", - f"--output_path={output_path}", + f"--output_path={output_path_jaxlib}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}" ] @@ -698,7 +690,7 @@ def main(): build_gpu_kernels_command = [ *command_base, "//jaxlib/tools:build_gpu_kernels_wheel", "--", - f"--output_path={output_path}", + f"--output_path={output_path_jax_kernel}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", ] @@ -719,7 +711,7 @@ def main(): build_pjrt_plugin_command = [ *command_base, "//jaxlib/tools:build_gpu_plugin_wheel", "--", - f"--output_path={output_path}", + 
f"--output_path={output_path_jax_pjrt}", f"--jaxlib_git_hash={get_githash()}", f"--cpu={wheel_cpu}", ] diff --git a/build/requirements.in b/build/requirements.in index add6b8577350..a8d81fa5c670 100644 --- a/build/requirements.in +++ b/build/requirements.in @@ -11,14 +11,17 @@ matplotlib; python_version>="3.11" # # build deps # -numpy~=2.0.0 +numpy~=2.0.0; python_version<="3.12" +numpy~=2.1.0; python_version>="3.13" # # runtime deps # -scipy~=1.13.1 +scipy>=1.13.1 ml_dtypes>=0.4.0 opt_einsum zstandard etils[epath] +# TODO(ybaturina): remove setuptools version +setuptools<71.0.0 diff --git a/build/requirements_lock_3_13.txt b/build/requirements_lock_3_13.txt index 62b5e14e65b4..e2369a8001bb 100644 --- a/build/requirements_lock_3_13.txt +++ b/build/requirements_lock_3_13.txt @@ -1,52 +1,423 @@ # -# This file is autogenerated by pip-compile with Python 3.12 +# This file is autogenerated by pip-compile with Python 3.13 # by the following command: # -# bazel run //build:requirements_dev.update +# bazel run //build:requirements.update # -absl-py==2.1.0 +absl-py==2.1.0 \ + --hash=sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308 \ + --hash=sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff # via -r build/test-requirements.txt -attrs==23.2.0 +attrs==24.2.0 \ + --hash=sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346 \ + --hash=sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2 # via hypothesis -build==1.2.1 +build==1.2.2 \ + --hash=sha256:119b2fb462adef986483438377a13b2f42064a2a3a4161f24a0cca698a07ac8c \ + --hash=sha256:277ccc71619d98afdd841a0e96ac9fe1593b823af481d3b0cea748e8894e0613 # via -r build/test-requirements.txt -cloudpickle==3.0.0 +cloudpickle==3.0.0 \ + --hash=sha256:246ee7d0c295602a036e86369c77fecda4ab17b506496730f2f576d9016fd9c7 \ + --hash=sha256:996d9a482c6fb4f33c1a35335cf8afd065d2a56e973270364840712d9131a882 # via -r build/test-requirements.txt -colorama==0.4.6 +colorama==0.4.6 \ + --hash=sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44 \ + --hash=sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6 # via -r build/test-requirements.txt -contourpy==1.2.1 +contourpy==1.3.0 \ + --hash=sha256:00ccd0dbaad6d804ab259820fa7cb0b8036bda0686ef844d24125d8287178ce0 \ + --hash=sha256:0be4d8425bfa755e0fd76ee1e019636ccc7c29f77a7c86b4328a9eb6a26d0639 \ + --hash=sha256:0dce35502151b6bd35027ac39ba6e5a44be13a68f55735c3612c568cac3805fd \ + --hash=sha256:0fa4c02abe6c446ba70d96ece336e621efa4aecae43eaa9b030ae5fb92b309ad \ + --hash=sha256:14e262f67bd7e6eb6880bc564dcda30b15e351a594657e55b7eec94b6ef72843 \ + --hash=sha256:167d6c890815e1dac9536dca00828b445d5d0df4d6a8c6adb4a7ec3166812fa8 \ + --hash=sha256:1ec4dc6bf570f5b22ed0d7efba0dfa9c5b9e0431aeea7581aa217542d9e809a4 \ + --hash=sha256:303c252947ab4b14c08afeb52375b26781ccd6a5ccd81abcdfc1fafd14cf93c1 \ + --hash=sha256:31cd3a85dbdf1fc002280c65caa7e2b5f65e4a973fcdf70dd2fdcb9868069294 \ + --hash=sha256:32b238b3b3b649e09ce9aaf51f0c261d38644bdfa35cbaf7b263457850957a84 \ + --hash=sha256:33c92cdae89ec5135d036e7218e69b0bb2851206077251f04a6c4e0e21f03927 \ + --hash=sha256:345af746d7766821d05d72cb8f3845dfd08dd137101a2cb9b24de277d716def8 \ + --hash=sha256:3634b5385c6716c258d0419c46d05c8aa7dc8cb70326c9a4fb66b69ad2b52e09 \ + --hash=sha256:364174c2a76057feef647c802652f00953b575723062560498dc7930fc9b1cb7 \ + --hash=sha256:36e0cff201bcb17a0a8ecc7f454fe078437fa6bda730e695a92f2d9932bd507f \ + 
--hash=sha256:36f965570cff02b874773c49bfe85562b47030805d7d8360748f3eca570f4cab \ + --hash=sha256:3bb3808858a9dc68f6f03d319acd5f1b8a337e6cdda197f02f4b8ff67ad2057b \ + --hash=sha256:3e1c7fa44aaae40a2247e2e8e0627f4bea3dd257014764aa644f319a5f8600e3 \ + --hash=sha256:3faeb2998e4fcb256542e8a926d08da08977f7f5e62cf733f3c211c2a5586223 \ + --hash=sha256:420d39daa61aab1221567b42eecb01112908b2cab7f1b4106a52caaec8d36973 \ + --hash=sha256:4553c421929ec95fb07b3aaca0fae668b2eb5a5203d1217ca7c34c063c53d087 \ + --hash=sha256:4865cd1d419e0c7a7bf6de1777b185eebdc51470800a9f42b9e9decf17762081 \ + --hash=sha256:4cfb5c62ce023dfc410d6059c936dcf96442ba40814aefbfa575425a3a7f19dc \ + --hash=sha256:4d63ee447261e963af02642ffcb864e5a2ee4cbfd78080657a9880b8b1868e18 \ + --hash=sha256:570ef7cf892f0afbe5b2ee410c507ce12e15a5fa91017a0009f79f7d93a1268f \ + --hash=sha256:637f674226be46f6ba372fd29d9523dd977a291f66ab2a74fbeb5530bb3f445d \ + --hash=sha256:68a32389b06b82c2fdd68276148d7b9275b5f5cf13e5417e4252f6d1a34f72a2 \ + --hash=sha256:69375194457ad0fad3a839b9e29aa0b0ed53bb54db1bfb6c3ae43d111c31ce41 \ + --hash=sha256:6cb6cc968059db9c62cb35fbf70248f40994dfcd7aa10444bbf8b3faeb7c2d67 \ + --hash=sha256:710a26b3dc80c0e4febf04555de66f5fd17e9cf7170a7b08000601a10570bda6 \ + --hash=sha256:732896af21716b29ab3e988d4ce14bc5133733b85956316fb0c56355f398099b \ + --hash=sha256:75ee7cb1a14c617f34a51d11fa7524173e56551646828353c4af859c56b766e2 \ + --hash=sha256:76a896b2f195b57db25d6b44e7e03f221d32fe318d03ede41f8b4d9ba1bff53c \ + --hash=sha256:76c905ef940a4474a6289c71d53122a4f77766eef23c03cd57016ce19d0f7b42 \ + --hash=sha256:7a52040312b1a858b5e31ef28c2e865376a386c60c0e248370bbea2d3f3b760d \ + --hash=sha256:7ffa0db17717a8ffb127efd0c95a4362d996b892c2904db72428d5b52e1938a4 \ + --hash=sha256:81cb5ed4952aae6014bc9d0421dec7c5835c9c8c31cdf51910b708f548cf58e5 \ + --hash=sha256:834e0cfe17ba12f79963861e0f908556b2cedd52e1f75e6578801febcc6a9f49 \ + --hash=sha256:87ddffef1dbe5e669b5c2440b643d3fdd8622a348fe1983fad7a0f0ccb1cd67b \ + --hash=sha256:880ea32e5c774634f9fcd46504bf9f080a41ad855f4fef54f5380f5133d343c7 \ + --hash=sha256:8ca947601224119117f7c19c9cdf6b3ab54c5726ef1d906aa4a69dfb6dd58102 \ + --hash=sha256:90f73a5116ad1ba7174341ef3ea5c3150ddf20b024b98fb0c3b29034752c8aeb \ + --hash=sha256:92f8557cbb07415a4d6fa191f20fd9d2d9eb9c0b61d1b2f52a8926e43c6e9af7 \ + --hash=sha256:94e848a6b83da10898cbf1311a815f770acc9b6a3f2d646f330d57eb4e87592e \ + --hash=sha256:9c0da700bf58f6e0b65312d0a5e695179a71d0163957fa381bb3c1f72972537c \ + --hash=sha256:a11077e395f67ffc2c44ec2418cfebed032cd6da3022a94fc227b6faf8e2acb8 \ + --hash=sha256:aea348f053c645100612b333adc5983d87be69acdc6d77d3169c090d3b01dc35 \ + --hash=sha256:b11b39aea6be6764f84360fce6c82211a9db32a7c7de8fa6dd5397cf1d079c3b \ + --hash=sha256:c6c7c2408b7048082932cf4e641fa3b8ca848259212f51c8c59c45aa7ac18f14 \ + --hash=sha256:c6ec93afeb848a0845a18989da3beca3eec2c0f852322efe21af1931147d12cb \ + --hash=sha256:cacd81e2d4b6f89c9f8a5b69b86490152ff39afc58a95af002a398273e5ce589 \ + --hash=sha256:d402880b84df3bec6eab53cd0cf802cae6a2ef9537e70cf75e91618a3801c20c \ + --hash=sha256:d51fca85f9f7ad0b65b4b9fe800406d0d77017d7270d31ec3fb1cc07358fdea0 \ + --hash=sha256:d73f659398a0904e125280836ae6f88ba9b178b2fed6884f3b1f95b989d2c8da \ + --hash=sha256:d78ab28a03c854a873787a0a42254a0ccb3cb133c672f645c9f9c8f3ae9d0800 \ + --hash=sha256:da84c537cb8b97d153e9fb208c221c45605f73147bd4cadd23bdae915042aad6 \ + --hash=sha256:dbc4c3217eee163fa3984fd1567632b48d6dfd29216da3ded3d7b844a8014a66 \ + 
--hash=sha256:e12968fdfd5bb45ffdf6192a590bd8ddd3ba9e58360b29683c6bb71a7b41edca \ + --hash=sha256:e1fd23e9d01591bab45546c089ae89d926917a66dceb3abcf01f6105d927e2cb \ + --hash=sha256:e8134301d7e204c88ed7ab50028ba06c683000040ede1d617298611f9dc6240c \ + --hash=sha256:eb8b141bb00fa977d9122636b16aa67d37fd40a3d8b52dd837e536d64b9a4d06 \ + --hash=sha256:eca7e17a65f72a5133bdbec9ecf22401c62bcf4821361ef7811faee695799779 \ + --hash=sha256:f317576606de89da6b7e0861cf6061f6146ead3528acabff9236458a6ba467f8 \ + --hash=sha256:fd2a0fc506eccaaa7595b7e1418951f213cf8255be2600f1ea1b61e46a60c55f \ + --hash=sha256:fe41b41505a5a33aeaed2a613dccaeaa74e0e3ead6dd6fd3a118fb471644fd6c # via matplotlib -cycler==0.12.1 +cycler==0.12.1 \ + --hash=sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30 \ + --hash=sha256:88bb128f02ba341da8ef447245a9e138fae777f6a23943da4540077d3601eb1c # via matplotlib -etils[epath,epy]==1.8.0 +etils[epath,epy]==1.9.4 \ + --hash=sha256:4387e7a4911a3b5cc4b92b99a9211386d176b43bae1dac8e2fe345fc2cb95e4b \ + --hash=sha256:fad950414f0a1ca58c70c70915b0014f9953dd9bcf8aa951a0f75ff9becbeb24 # via -r build/requirements.in -execnet==2.1.1 +execnet==2.1.1 \ + --hash=sha256:26dee51f1b80cebd6d0ca8e74dd8745419761d3bef34163928cbebbdc4749fdc \ + --hash=sha256:5189b52c6121c24feae288166ab41b32549c7e2348652736540b9e6e7d4e72e3 # via pytest-xdist -flatbuffers==24.3.25 +filelock==3.16.0 \ + --hash=sha256:81de9eb8453c769b63369f87f11131a7ab04e367f8d97ad39dc230daa07e3bec \ + --hash=sha256:f6ed4c963184f4c84dd5557ce8fece759a3724b37b80c6c4f20a2f63a4dc6609 # via -r build/test-requirements.txt -fonttools==4.51.0 +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r build/test-requirements.txt +fonttools==4.53.1 \ + --hash=sha256:02569e9a810f9d11f4ae82c391ebc6fb5730d95a0657d24d754ed7763fb2d122 \ + --hash=sha256:0679a30b59d74b6242909945429dbddb08496935b82f91ea9bf6ad240ec23397 \ + --hash=sha256:10f5e6c3510b79ea27bb1ebfcc67048cde9ec67afa87c7dd7efa5c700491ac7f \ + --hash=sha256:2af40ae9cdcb204fc1d8f26b190aa16534fcd4f0df756268df674a270eab575d \ + --hash=sha256:32f029c095ad66c425b0ee85553d0dc326d45d7059dbc227330fc29b43e8ba60 \ + --hash=sha256:35250099b0cfb32d799fb5d6c651220a642fe2e3c7d2560490e6f1d3f9ae9169 \ + --hash=sha256:3b3c8ebafbee8d9002bd8f1195d09ed2bd9ff134ddec37ee8f6a6375e6a4f0e8 \ + --hash=sha256:4824c198f714ab5559c5be10fd1adf876712aa7989882a4ec887bf1ef3e00e31 \ + --hash=sha256:5ff7e5e9bad94e3a70c5cd2fa27f20b9bb9385e10cddab567b85ce5d306ea923 \ + --hash=sha256:651390c3b26b0c7d1f4407cad281ee7a5a85a31a110cbac5269de72a51551ba2 \ + --hash=sha256:6e08f572625a1ee682115223eabebc4c6a2035a6917eac6f60350aba297ccadb \ + --hash=sha256:6ed170b5e17da0264b9f6fae86073be3db15fa1bd74061c8331022bca6d09bab \ + --hash=sha256:73379d3ffdeecb376640cd8ed03e9d2d0e568c9d1a4e9b16504a834ebadc2dfb \ + --hash=sha256:75a157d8d26c06e64ace9df037ee93a4938a4606a38cb7ffaf6635e60e253b7a \ + --hash=sha256:791b31ebbc05197d7aa096bbc7bd76d591f05905d2fd908bf103af4488e60670 \ + --hash=sha256:7b6b35e52ddc8fb0db562133894e6ef5b4e54e1283dff606fda3eed938c36fc8 \ + --hash=sha256:84ec3fb43befb54be490147b4a922b5314e16372a643004f182babee9f9c3407 \ + --hash=sha256:8959a59de5af6d2bec27489e98ef25a397cfa1774b375d5787509c06659b3671 \ + --hash=sha256:9dfdae43b7996af46ff9da520998a32b105c7f098aeea06b2226b30e74fbba88 \ + --hash=sha256:9e6ceba2a01b448e36754983d376064730690401da1dd104ddb543519470a15f \ + 
--hash=sha256:9efd176f874cb6402e607e4cc9b4a9cd584d82fc34a4b0c811970b32ba62501f \ + --hash=sha256:a1c7c5aa18dd3b17995898b4a9b5929d69ef6ae2af5b96d585ff4005033d82f0 \ + --hash=sha256:aae7bd54187e8bf7fd69f8ab87b2885253d3575163ad4d669a262fe97f0136cb \ + --hash=sha256:b21952c092ffd827504de7e66b62aba26fdb5f9d1e435c52477e6486e9d128b2 \ + --hash=sha256:b96cd370a61f4d083c9c0053bf634279b094308d52fdc2dd9a22d8372fdd590d \ + --hash=sha256:becc5d7cb89c7b7afa8321b6bb3dbee0eec2b57855c90b3e9bf5fb816671fa7c \ + --hash=sha256:bee32ea8765e859670c4447b0817514ca79054463b6b79784b08a8df3a4d78e3 \ + --hash=sha256:c6e7170d675d12eac12ad1a981d90f118c06cf680b42a2d74c6c931e54b50719 \ + --hash=sha256:c818c058404eb2bba05e728d38049438afd649e3c409796723dfc17cd3f08749 \ + --hash=sha256:c8696544c964500aa9439efb6761947393b70b17ef4e82d73277413f291260a4 \ + --hash=sha256:c9cd19cf4fe0595ebdd1d4915882b9440c3a6d30b008f3cc7587c1da7b95be5f \ + --hash=sha256:d4d0096cb1ac7a77b3b41cd78c9b6bc4a400550e21dc7a92f2b5ab53ed74eb02 \ + --hash=sha256:d92d3c2a1b39631a6131c2fa25b5406855f97969b068e7e08413325bc0afba58 \ + --hash=sha256:da33440b1413bad53a8674393c5d29ce64d8c1a15ef8a77c642ffd900d07bfe1 \ + --hash=sha256:e013aae589c1c12505da64a7d8d023e584987e51e62006e1bb30d72f26522c41 \ + --hash=sha256:e128778a8e9bc11159ce5447f76766cefbd876f44bd79aff030287254e4752c4 \ + --hash=sha256:e54f1bba2f655924c1138bbc7fa91abd61f45c68bd65ab5ed985942712864bbb \ + --hash=sha256:e5b708073ea3d684235648786f5f6153a48dc8762cdfe5563c57e80787c29fbb \ + --hash=sha256:e8bf06b94694251861ba7fdeea15c8ec0967f84c3d4143ae9daf42bbc7717fe3 \ + --hash=sha256:f08df60fbd8d289152079a65da4e66a447efc1d5d5a4d3f299cdd39e3b2e4a7d \ + --hash=sha256:f1f8758a2ad110bd6432203a344269f445a2907dc24ef6bccfd0ac4e14e0d71d \ + --hash=sha256:f677ce218976496a587ab17140da141557beb91d2a5c1a14212c994093f2eae2 # via matplotlib -fsspec==2024.3.1 +fsspec==2024.9.0 \ + --hash=sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8 \ + --hash=sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b # via etils -hypothesis==6.100.1 +hypothesis==6.112.1 \ + --hash=sha256:93631b1498b20d2c205ed304cbd41d50e9c069d78a9c773c1324ca094c5e30ce \ + --hash=sha256:b070d7a1bb9bd84706c31885c9aeddc138e2b36a9c112a91984f49501c567856 # via -r build/test-requirements.txt -importlib-resources==6.4.0 +importlib-resources==6.4.5 \ + --hash=sha256:980862a1d16c9e147a59603677fa2aa5fd82b87f223b6cb870695bcfce830065 \ + --hash=sha256:ac29d5f956f01d5e4bb63102a5a19957f1b9175e45649977264a1416783bb717 # via etils -iniconfig==2.0.0 +iniconfig==2.0.0 \ + --hash=sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3 \ + --hash=sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374 # via pytest -kiwisolver==1.4.5 +kiwisolver==1.4.7 \ + --hash=sha256:073a36c8273647592ea332e816e75ef8da5c303236ec0167196793eb1e34657a \ + --hash=sha256:08471d4d86cbaec61f86b217dd938a83d85e03785f51121e791a6e6689a3be95 \ + --hash=sha256:0c18ec74c0472de033e1bebb2911c3c310eef5649133dd0bedf2a169a1b269e5 \ + --hash=sha256:0c6c43471bc764fad4bc99c5c2d6d16a676b1abf844ca7c8702bdae92df01ee0 \ + --hash=sha256:10849fb2c1ecbfae45a693c070e0320a91b35dd4bcf58172c023b994283a124d \ + --hash=sha256:18077b53dc3bb490e330669a99920c5e6a496889ae8c63b58fbc57c3d7f33a18 \ + --hash=sha256:18e0cca3e008e17fe9b164b55735a325140a5a35faad8de92dd80265cd5eb80b \ + --hash=sha256:22f499f6157236c19f4bbbd472fa55b063db77a16cd74d49afe28992dff8c258 \ + --hash=sha256:2a8781ac3edc42ea4b90bc23e7d37b665d89423818e26eb6df90698aa2287c95 \ + 
--hash=sha256:2e6039dcbe79a8e0f044f1c39db1986a1b8071051efba3ee4d74f5b365f5226e \ + --hash=sha256:34ea1de54beef1c104422d210c47c7d2a4999bdecf42c7b5718fbe59a4cac383 \ + --hash=sha256:3ab58c12a2cd0fc769089e6d38466c46d7f76aced0a1f54c77652446733d2d02 \ + --hash=sha256:3abc5b19d24af4b77d1598a585b8a719beb8569a71568b66f4ebe1fb0449460b \ + --hash=sha256:3bf1ed55088f214ba6427484c59553123fdd9b218a42bbc8c6496d6754b1e523 \ + --hash=sha256:3ce6b2b0231bda412463e152fc18335ba32faf4e8c23a754ad50ffa70e4091ee \ + --hash=sha256:3da53da805b71e41053dc670f9a820d1157aae77b6b944e08024d17bcd51ef88 \ + --hash=sha256:3f9362ecfca44c863569d3d3c033dbe8ba452ff8eed6f6b5806382741a1334bd \ + --hash=sha256:409afdfe1e2e90e6ee7fc896f3df9a7fec8e793e58bfa0d052c8a82f99c37abb \ + --hash=sha256:40fa14dbd66b8b8f470d5fc79c089a66185619d31645f9b0773b88b19f7223c4 \ + --hash=sha256:4322872d5772cae7369f8351da1edf255a604ea7087fe295411397d0cfd9655e \ + --hash=sha256:44756f9fd339de0fb6ee4f8c1696cfd19b2422e0d70b4cefc1cc7f1f64045a8c \ + --hash=sha256:46707a10836894b559e04b0fd143e343945c97fd170d69a2d26d640b4e297935 \ + --hash=sha256:48b571ecd8bae15702e4f22d3ff6a0f13e54d3d00cd25216d5e7f658242065ee \ + --hash=sha256:48be928f59a1f5c8207154f935334d374e79f2b5d212826307d072595ad76a2e \ + --hash=sha256:4bfa75a048c056a411f9705856abfc872558e33c055d80af6a380e3658766038 \ + --hash=sha256:4c00336b9dd5ad96d0a558fd18a8b6f711b7449acce4c157e7343ba92dd0cf3d \ + --hash=sha256:4c26ed10c4f6fa6ddb329a5120ba3b6db349ca192ae211e882970bfc9d91420b \ + --hash=sha256:4d05d81ecb47d11e7f8932bd8b61b720bf0b41199358f3f5e36d38e28f0532c5 \ + --hash=sha256:4e77f2126c3e0b0d055f44513ed349038ac180371ed9b52fe96a32aa071a5107 \ + --hash=sha256:5337ec7809bcd0f424c6b705ecf97941c46279cf5ed92311782c7c9c2026f07f \ + --hash=sha256:5360cc32706dab3931f738d3079652d20982511f7c0ac5711483e6eab08efff2 \ + --hash=sha256:58370b1ffbd35407444d57057b57da5d6549d2d854fa30249771775c63b5fe17 \ + --hash=sha256:58cb20602b18f86f83a5c87d3ee1c766a79c0d452f8def86d925e6c60fbf7bfb \ + --hash=sha256:599b5c873c63a1f6ed7eead644a8a380cfbdf5db91dcb6f85707aaab213b1674 \ + --hash=sha256:5b7dfa3b546da08a9f622bb6becdb14b3e24aaa30adba66749d38f3cc7ea9706 \ + --hash=sha256:5b9c3f4ee0b9a439d2415012bd1b1cc2df59e4d6a9939f4d669241d30b414327 \ + --hash=sha256:5d34eb8494bea691a1a450141ebb5385e4b69d38bb8403b5146ad279f4b30fa3 \ + --hash=sha256:5d5abf8f8ec1f4e22882273c423e16cae834c36856cac348cfbfa68e01c40f3a \ + --hash=sha256:5e3bc157fed2a4c02ec468de4ecd12a6e22818d4f09cde2c31ee3226ffbefab2 \ + --hash=sha256:612a10bdae23404a72941a0fc8fa2660c6ea1217c4ce0dbcab8a8f6543ea9e7f \ + --hash=sha256:657a05857bda581c3656bfc3b20e353c232e9193eb167766ad2dc58b56504948 \ + --hash=sha256:65e720d2ab2b53f1f72fb5da5fb477455905ce2c88aaa671ff0a447c2c80e8e3 \ + --hash=sha256:693902d433cf585133699972b6d7c42a8b9f8f826ebcaf0132ff55200afc599e \ + --hash=sha256:6af936f79086a89b3680a280c47ea90b4df7047b5bdf3aa5c524bbedddb9e545 \ + --hash=sha256:71bb308552200fb2c195e35ef05de12f0c878c07fc91c270eb3d6e41698c3bcc \ + --hash=sha256:764202cc7e70f767dab49e8df52c7455e8de0df5d858fa801a11aa0d882ccf3f \ + --hash=sha256:76c8094ac20ec259471ac53e774623eb62e6e1f56cd8690c67ce6ce4fcb05650 \ + --hash=sha256:78a42513018c41c2ffd262eb676442315cbfe3c44eed82385c2ed043bc63210a \ + --hash=sha256:79849239c39b5e1fd906556c474d9b0439ea6792b637511f3fe3a41158d89ca8 \ + --hash=sha256:7ab9ccab2b5bd5702ab0803676a580fffa2aa178c2badc5557a84cc943fcf750 \ + --hash=sha256:7bbfcb7165ce3d54a3dfbe731e470f65739c4c1f85bb1018ee912bae139e263b \ + 
--hash=sha256:7c06a4c7cf15ec739ce0e5971b26c93638730090add60e183530d70848ebdd34 \ + --hash=sha256:801fa7802e5cfabe3ab0c81a34c323a319b097dfb5004be950482d882f3d7225 \ + --hash=sha256:803b8e1459341c1bb56d1c5c010406d5edec8a0713a0945851290a7930679b51 \ + --hash=sha256:82a5c2f4b87c26bb1a0ef3d16b5c4753434633b83d365cc0ddf2770c93829e3c \ + --hash=sha256:84ec80df401cfee1457063732d90022f93951944b5b58975d34ab56bb150dfb3 \ + --hash=sha256:8705f17dfeb43139a692298cb6637ee2e59c0194538153e83e9ee0c75c2eddde \ + --hash=sha256:88a9ca9c710d598fd75ee5de59d5bda2684d9db36a9f50b6125eaea3969c2599 \ + --hash=sha256:88f17c5ffa8e9462fb79f62746428dd57b46eb931698e42e990ad63103f35e6c \ + --hash=sha256:8a3ec5aa8e38fc4c8af308917ce12c536f1c88452ce554027e55b22cbbfbff76 \ + --hash=sha256:8a9c83f75223d5e48b0bc9cb1bf2776cf01563e00ade8775ffe13b0b6e1af3a6 \ + --hash=sha256:8b01aac285f91ca889c800042c35ad3b239e704b150cfd3382adfc9dcc780e39 \ + --hash=sha256:8d53103597a252fb3ab8b5845af04c7a26d5e7ea8122303dd7a021176a87e8b9 \ + --hash=sha256:8e045731a5416357638d1700927529e2b8ab304811671f665b225f8bf8d8f933 \ + --hash=sha256:8f0ea6da6d393d8b2e187e6a5e3fb81f5862010a40c3945e2c6d12ae45cfb2ad \ + --hash=sha256:90da3b5f694b85231cf93586dad5e90e2d71b9428f9aad96952c99055582f520 \ + --hash=sha256:913983ad2deb14e66d83c28b632fd35ba2b825031f2fa4ca29675e665dfecbe1 \ + --hash=sha256:9242795d174daa40105c1d86aba618e8eab7bf96ba8c3ee614da8302a9f95503 \ + --hash=sha256:929e294c1ac1e9f615c62a4e4313ca1823ba37326c164ec720a803287c4c499b \ + --hash=sha256:933d4de052939d90afbe6e9d5273ae05fb836cc86c15b686edd4b3560cc0ee36 \ + --hash=sha256:942216596dc64ddb25adb215c3c783215b23626f8d84e8eff8d6d45c3f29f75a \ + --hash=sha256:94252291e3fe68001b1dd747b4c0b3be12582839b95ad4d1b641924d68fd4643 \ + --hash=sha256:9893ff81bd7107f7b685d3017cc6583daadb4fc26e4a888350df530e41980a60 \ + --hash=sha256:9e838bba3a3bac0fe06d849d29772eb1afb9745a59710762e4ba3f4cb8424483 \ + --hash=sha256:a0f64a48bb81af7450e641e3fe0b0394d7381e342805479178b3d335d60ca7cf \ + --hash=sha256:a17f6a29cf8935e587cc8a4dbfc8368c55edc645283db0ce9801016f83526c2d \ + --hash=sha256:a1ecf0ac1c518487d9d23b1cd7139a6a65bc460cd101ab01f1be82ecf09794b6 \ + --hash=sha256:a79ae34384df2b615eefca647a2873842ac3b596418032bef9a7283675962644 \ + --hash=sha256:a91b5f9f1205845d488c928e8570dcb62b893372f63b8b6e98b863ebd2368ff2 \ + --hash=sha256:aa0abdf853e09aff551db11fce173e2177d00786c688203f52c87ad7fcd91ef9 \ + --hash=sha256:ac542bf38a8a4be2dc6b15248d36315ccc65f0743f7b1a76688ffb6b5129a5c2 \ + --hash=sha256:ad42ba922c67c5f219097b28fae965e10045ddf145d2928bfac2eb2e17673640 \ + --hash=sha256:aeb3531b196ef6f11776c21674dba836aeea9d5bd1cf630f869e3d90b16cfade \ + --hash=sha256:b38ac83d5f04b15e515fd86f312479d950d05ce2368d5413d46c088dda7de90a \ + --hash=sha256:b7d755065e4e866a8086c9bdada157133ff466476a2ad7861828e17b6026e22c \ + --hash=sha256:bd3de6481f4ed8b734da5df134cd5a6a64fe32124fe83dde1e5b5f29fe30b1e6 \ + --hash=sha256:bfa1acfa0c54932d5607e19a2c24646fb4c1ae2694437789129cf099789a3b00 \ + --hash=sha256:c619b101e6de2222c1fcb0531e1b17bbffbe54294bfba43ea0d411d428618c27 \ + --hash=sha256:ce8be0466f4c0d585cdb6c1e2ed07232221df101a4c6f28821d2aa754ca2d9e2 \ + --hash=sha256:cf0438b42121a66a3a667de17e779330fc0f20b0d97d59d2f2121e182b0505e4 \ + --hash=sha256:cf8bcc23ceb5a1b624572a1623b9f79d2c3b337c8c455405ef231933a10da379 \ + --hash=sha256:d2b0e12a42fb4e72d509fc994713d099cbb15ebf1103545e8a45f14da2dfca54 \ + --hash=sha256:d83db7cde68459fc803052a55ace60bea2bae361fc3b7a6d5da07e11954e4b09 \ + 
--hash=sha256:dda56c24d869b1193fcc763f1284b9126550eaf84b88bbc7256e15028f19188a \ + --hash=sha256:dea0bf229319828467d7fca8c7c189780aa9ff679c94539eed7532ebe33ed37c \ + --hash=sha256:e1631290ee9271dffe3062d2634c3ecac02c83890ada077d225e081aca8aab89 \ + --hash=sha256:e28c7fea2196bf4c2f8d46a0415c77a1c480cc0724722f23d7410ffe9842c407 \ + --hash=sha256:e2e6c39bd7b9372b0be21456caab138e8e69cc0fc1190a9dfa92bd45a1e6e904 \ + --hash=sha256:e33e8fbd440c917106b237ef1a2f1449dfbb9b6f6e1ce17c94cd6a1e0d438376 \ + --hash=sha256:e8df2eb9b2bac43ef8b082e06f750350fbbaf2887534a5be97f6cf07b19d9583 \ + --hash=sha256:e968b84db54f9d42046cf154e02911e39c0435c9801681e3fc9ce8a3c4130278 \ + --hash=sha256:eb542fe7933aa09d8d8f9d9097ef37532a7df6497819d16efe4359890a2f417a \ + --hash=sha256:edcfc407e4eb17e037bca59be0e85a2031a2ac87e4fed26d3e9df88b4165f92d \ + --hash=sha256:eee3ea935c3d227d49b4eb85660ff631556841f6e567f0f7bda972df6c2c9935 \ + --hash=sha256:ef97b8df011141c9b0f6caf23b29379f87dd13183c978a30a3c546d2c47314cb \ + --hash=sha256:f106407dda69ae456dd1227966bf445b157ccc80ba0dff3802bb63f30b74e895 \ + --hash=sha256:f3160309af4396e0ed04db259c3ccbfdc3621b5559b5453075e5de555e1f3a1b \ + --hash=sha256:f32d6edbc638cde7652bd690c3e728b25332acbadd7cad670cc4a02558d9c417 \ + --hash=sha256:f37cfe618a117e50d8c240555331160d73d0411422b59b5ee217843d7b693608 \ + --hash=sha256:f4c9aee212bc89d4e13f58be11a56cc8036cabad119259d12ace14b34476fd07 \ + --hash=sha256:f4d742cb7af1c28303a51b7a27aaee540e71bb8e24f68c736f6f2ffc82f2bf05 \ + --hash=sha256:f5a8b53bdc0b3961f8b6125e198617c40aeed638b387913bf1ce78afb1b0be2a \ + --hash=sha256:f816dd2277f8d63d79f9c8473a79fe54047bc0467754962840782c575522224d \ + --hash=sha256:f9a9e8a507420fe35992ee9ecb302dab68550dedc0da9e2880dd88071c5fb052 # via matplotlib -markdown-it-py==3.0.0 +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb # via rich -matplotlib==3.8.3 +matplotlib==3.9.2 ; python_version >= "3.11" \ + --hash=sha256:039082812cacd6c6bec8e17a9c1e6baca230d4116d522e81e1f63a74d01d2e21 \ + --hash=sha256:03ba9c1299c920964e8d3857ba27173b4dbb51ca4bab47ffc2c2ba0eb5e2cbc5 \ + --hash=sha256:050598c2b29e0b9832cde72bcf97627bf00262adbc4a54e2b856426bb2ef0697 \ + --hash=sha256:18128cc08f0d3cfff10b76baa2f296fc28c4607368a8402de61bb3f2eb33c7d9 \ + --hash=sha256:1cd93b91ab47a3616b4d3c42b52f8363b88ca021e340804c6ab2536344fad9ca \ + --hash=sha256:1d94ff717eb2bd0b58fe66380bd8b14ac35f48a98e7c6765117fe67fb7684e64 \ + --hash=sha256:306c8dfc73239f0e72ac50e5a9cf19cc4e8e331dd0c54f5e69ca8758550f1e1e \ + --hash=sha256:37e51dd1c2db16ede9cfd7b5cabdfc818b2c6397c83f8b10e0e797501c963a03 \ + --hash=sha256:3fd595f34aa8a55b7fc8bf9ebea8aa665a84c82d275190a61118d33fbc82ccae \ + --hash=sha256:4876d7d40219e8ae8bb70f9263bcbe5714415acfdf781086601211335e24f8aa \ + --hash=sha256:5413401594cfaff0052f9d8b1aafc6d305b4bd7c4331dccd18f561ff7e1d3bd3 \ + --hash=sha256:5816b1e1fe8c192cbc013f8f3e3368ac56fbecf02fb41b8f8559303f24c5015e \ + --hash=sha256:65aacf95b62272d568044531e41de26285d54aec8cb859031f511f84bd8b495a \ + --hash=sha256:6758baae2ed64f2331d4fd19be38b7b4eae3ecec210049a26b6a4f3ae1c85dcc \ + --hash=sha256:6d1ce5ed2aefcdce11904fc5bbea7d9c21fff3d5f543841edf3dea84451a09ea \ + --hash=sha256:6d9f07a80deab4bb0b82858a9e9ad53d1382fd122be8cde11080f4e7dfedb38b \ + --hash=sha256:7741f26a58a240f43bee74965c4882b6c93df3e7eb3de160126d8c8f53a6ae6e \ + --hash=sha256:8912ef7c2362f7193b5819d17dae8629b34a95c58603d781329712ada83f9447 \ + 
--hash=sha256:909645cce2dc28b735674ce0931a4ac94e12f5b13f6bb0b5a5e65e7cea2c192b \ + --hash=sha256:96ab43906269ca64a6366934106fa01534454a69e471b7bf3d79083981aaab92 \ + --hash=sha256:9d78bbc0cbc891ad55b4f39a48c22182e9bdaea7fc0e5dbd364f49f729ca1bbb \ + --hash=sha256:ab68d50c06938ef28681073327795c5db99bb4666214d2d5f880ed11aeaded66 \ + --hash=sha256:ac43031375a65c3196bee99f6001e7fa5bdfb00ddf43379d3c0609bdca042df9 \ + --hash=sha256:ae82a14dab96fbfad7965403c643cafe6515e386de723e498cf3eeb1e0b70cc7 \ + --hash=sha256:b2696efdc08648536efd4e1601b5fd491fd47f4db97a5fbfd175549a7365c1b2 \ + --hash=sha256:b82c5045cebcecd8496a4d694d43f9cc84aeeb49fe2133e036b207abe73f4d30 \ + --hash=sha256:be0fc24a5e4531ae4d8e858a1a548c1fe33b176bb13eff7f9d0d38ce5112a27d \ + --hash=sha256:bf81de2926c2db243c9b2cbc3917619a0fc85796c6ba4e58f541df814bbf83c7 \ + --hash=sha256:c375cc72229614632c87355366bdf2570c2dac01ac66b8ad048d2dabadf2d0d4 \ + --hash=sha256:c797dac8bb9c7a3fd3382b16fe8f215b4cf0f22adccea36f1545a6d7be310b41 \ + --hash=sha256:cef2a73d06601437be399908cf13aee74e86932a5ccc6ccdf173408ebc5f6bb2 \ + --hash=sha256:d52a3b618cb1cbb769ce2ee1dcdb333c3ab6e823944e9a2d36e37253815f9556 \ + --hash=sha256:d719465db13267bcef19ea8954a971db03b9f48b4647e3860e4bc8e6ed86610f \ + --hash=sha256:d8dd059447824eec055e829258ab092b56bb0579fc3164fa09c64f3acd478772 \ + --hash=sha256:dbe196377a8248972f5cede786d4c5508ed5f5ca4a1e09b44bda889958b33f8c \ + --hash=sha256:e0830e188029c14e891fadd99702fd90d317df294c3298aad682739c5533721a \ + --hash=sha256:f053c40f94bc51bc03832a41b4f153d83f2062d88c72b5e79997072594e97e51 \ + --hash=sha256:f32c7410c7f246838a77d6d1eff0c0f87f3cb0e7c4247aebea71a6d5a68cab49 \ + --hash=sha256:f6ee45bc4245533111ced13f1f2cace1e7f89d1c793390392a80c139d6cf0e6c \ + --hash=sha256:f7c0410f181a531ec4e93bbc27692f2c71a15c2da16766f5ba9761e7ae518413 # via -r build/requirements.in -mdurl==0.1.2 +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba # via markdown-it-py -ml-dtypes==0.4.0 +ml-dtypes==0.5.0 \ + --hash=sha256:099e09edd54e676903b4538f3815b5ab96f5b119690514602d96bfdb67172cbe \ + --hash=sha256:2e7534392682c3098bc7341648c650864207169c654aed83143d7a19c67ae06f \ + --hash=sha256:3e7d3a380fe73a63c884f06136f8baa7a5249cc8e9fdec677997dd78549f8128 \ + --hash=sha256:54415257f00eb44fbcc807454efac3356f75644f1cbfc2d4e5522a72ae1dacab \ + --hash=sha256:5f2b59233a0dbb6a560b3137ed6125433289ccba2f8d9c3695a52423a369ed15 \ + --hash=sha256:60275f2b51b56834e840c4809fca840565f9bf8e9a73f6d8c94f5b5935701215 \ + --hash=sha256:76942f6aeb5c40766d5ea62386daa4148e6a54322aaf5b53eae9e7553240222f \ + --hash=sha256:7ee9c320bb0f9ffdf9f6fa6a696ef2e005d1f66438d6f1c1457338e00a02e8cf \ + --hash=sha256:8c32138975797e681eb175996d64356bcfa124bdbb6a70460b9768c2b35a6fa4 \ + --hash=sha256:968fede07d1f9b926a63df97d25ac656cac1a57ebd33701734eaf704bc55d8d8 \ + --hash=sha256:a03fc861b86cc586728e3d093ba37f0cc05e65330c3ebd7688e7bae8290f8859 \ + --hash=sha256:a38df8df61194aeaae1ab7579075779b4ad32cd1cffd012c28be227fa7f2a70a \ + --hash=sha256:a988bac6572630e1e9c2edd9b1277b4eefd1c86209e52b0d061b775ac33902ff \ + --hash=sha256:ab046f2ff789b1f11b2491909682c5d089934835f9a760fafc180e47dcb676b8 \ + --hash=sha256:afa08343069874a30812871d639f9c02b4158ace065601406a493a8511180c02 \ + --hash=sha256:c7a9152f5876fef565516aa5dd1dccd6fc298a5891b2467973905103eb5c7856 \ + --hash=sha256:cb5cc7b25acabd384f75bbd78892d0c724943f3e2e1986254665a1aa10982e07 \ + 
--hash=sha256:d3b3db9990c3840986a0e70524e122cfa32b91139c3653df76121ba7776e015f \ + --hash=sha256:d4b1a70a3e5219790d6b55b9507606fc4e02911d1497d16c18dd721eb7efe7d0 \ + --hash=sha256:dc74fd9995513d33eac63d64e436240f5494ec74d522a9f0920194942fc3d2d7 \ + --hash=sha256:e04fde367b2fe901b1d47234426fe8819909bd1dd862a5adb630f27789c20599 # via -r build/requirements.in -mpmath==1.3.0 +mpmath==1.4.0a1 \ + --hash=sha256:78884400f439f500fa76be0121a8f9598313d87664863a192e1185ddbd7ae97f \ + --hash=sha256:f8b7b5a3a1726ab6e8c898eb2157426b82c482ab1ab8ffed9f88bb9e07c6e9c1 # via -r build/test-requirements.txt -numpy==1.26.4 +numpy==2.1.1 ; python_version >= "3.13" \ + --hash=sha256:046356b19d7ad1890c751b99acad5e82dc4a02232013bd9a9a712fddf8eb60f5 \ + --hash=sha256:0b8cc2715a84b7c3b161f9ebbd942740aaed913584cae9cdc7f8ad5ad41943d0 \ + --hash=sha256:0d07841fd284718feffe7dd17a63a2e6c78679b2d386d3e82f44f0108c905550 \ + --hash=sha256:13cc11c00000848702322af4de0147ced365c81d66053a67c2e962a485b3717c \ + --hash=sha256:13ce49a34c44b6de5241f0b38b07e44c1b2dcacd9e36c30f9c2fcb1bb5135db7 \ + --hash=sha256:24c2ad697bd8593887b019817ddd9974a7f429c14a5469d7fad413f28340a6d2 \ + --hash=sha256:251105b7c42abe40e3a689881e1793370cc9724ad50d64b30b358bbb3a97553b \ + --hash=sha256:2ca4b53e1e0b279142113b8c5eb7d7a877e967c306edc34f3b58e9be12fda8df \ + --hash=sha256:3269c9eb8745e8d975980b3a7411a98976824e1fdef11f0aacf76147f662b15f \ + --hash=sha256:397bc5ce62d3fb73f304bec332171535c187e0643e176a6e9421a6e3eacef06d \ + --hash=sha256:3fc5eabfc720db95d68e6646e88f8b399bfedd235994016351b1d9e062c4b270 \ + --hash=sha256:50a95ca3560a6058d6ea91d4629a83a897ee27c00630aed9d933dff191f170cd \ + --hash=sha256:52ac2e48f5ad847cd43c4755520a2317f3380213493b9d8a4c5e37f3b87df504 \ + --hash=sha256:53e27293b3a2b661c03f79aa51c3987492bd4641ef933e366e0f9f6c9bf257ec \ + --hash=sha256:57eb525e7c2a8fdee02d731f647146ff54ea8c973364f3b850069ffb42799647 \ + --hash=sha256:5889dd24f03ca5a5b1e8a90a33b5a0846d8977565e4ae003a63d22ecddf6782f \ + --hash=sha256:59ca673ad11d4b84ceb385290ed0ebe60266e356641428c845b39cd9df6713ab \ + --hash=sha256:6435c48250c12f001920f0751fe50c0348f5f240852cfddc5e2f97e007544cbe \ + --hash=sha256:6e5a9cb2be39350ae6c8f79410744e80154df658d5bea06e06e0ac5bb75480d5 \ + --hash=sha256:7be6a07520b88214ea85d8ac8b7d6d8a1839b0b5cb87412ac9f49fa934eb15d5 \ + --hash=sha256:7c803b7934a7f59563db459292e6aa078bb38b7ab1446ca38dd138646a38203e \ + --hash=sha256:7dd86dfaf7c900c0bbdcb8b16e2f6ddf1eb1fe39c6c8cca6e94844ed3152a8fd \ + --hash=sha256:8661c94e3aad18e1ea17a11f60f843a4933ccaf1a25a7c6a9182af70610b2313 \ + --hash=sha256:8ae0fd135e0b157365ac7cc31fff27f07a5572bdfc38f9c2d43b2aff416cc8b0 \ + --hash=sha256:910b47a6d0635ec1bd53b88f86120a52bf56dcc27b51f18c7b4a2e2224c29f0f \ + --hash=sha256:913cc1d311060b1d409e609947fa1b9753701dac96e6581b58afc36b7ee35af6 \ + --hash=sha256:920b0911bb2e4414c50e55bd658baeb78281a47feeb064ab40c2b66ecba85553 \ + --hash=sha256:950802d17a33c07cba7fd7c3dcfa7d64705509206be1606f196d179e539111ed \ + --hash=sha256:981707f6b31b59c0c24bcda52e5605f9701cb46da4b86c2e8023656ad3e833cb \ + --hash=sha256:98ce7fb5b8063cfdd86596b9c762bf2b5e35a2cdd7e967494ab78a1fa7f8b86e \ + --hash=sha256:99f4a9ee60eed1385a86e82288971a51e71df052ed0b2900ed30bc840c0f2e39 \ + --hash=sha256:9a8e06c7a980869ea67bbf551283bbed2856915f0a792dc32dd0f9dd2fb56728 \ + --hash=sha256:ae8ce252404cdd4de56dcfce8b11eac3c594a9c16c231d081fb705cf23bd4d9e \ + --hash=sha256:afd9c680df4de71cd58582b51e88a61feed4abcc7530bcd3d48483f20fc76f2a \ + --hash=sha256:b49742cdb85f1f81e4dc1b39dcf328244f4d8d1ded95dea725b316bd2cf18c95 
\ + --hash=sha256:b5613cfeb1adfe791e8e681128f5f49f22f3fcaa942255a6124d58ca59d9528f \ + --hash=sha256:bab7c09454460a487e631ffc0c42057e3d8f2a9ddccd1e60c7bb8ed774992480 \ + --hash=sha256:c8a0e34993b510fc19b9a2ce7f31cb8e94ecf6e924a40c0c9dd4f62d0aac47d9 \ + --hash=sha256:caf5d284ddea7462c32b8d4a6b8af030b6c9fd5332afb70e7414d7fdded4bfd0 \ + --hash=sha256:cea427d1350f3fd0d2818ce7350095c1a2ee33e30961d2f0fef48576ddbbe90f \ + --hash=sha256:d0cf7d55b1051387807405b3898efafa862997b4cba8aa5dbe657be794afeafd \ + --hash=sha256:d10c39947a2d351d6d466b4ae83dad4c37cd6c3cdd6d5d0fa797da56f710a6ae \ + --hash=sha256:d2b9cd92c8f8e7b313b80e93cedc12c0112088541dcedd9197b5dee3738c1201 \ + --hash=sha256:d4c57b68c8ef5e1ebf47238e99bf27657511ec3f071c465f6b1bccbef12d4136 \ + --hash=sha256:d51fc141ddbe3f919e91a096ec739f49d686df8af254b2053ba21a910ae518bf \ + --hash=sha256:e097507396c0be4e547ff15b13dc3866f45f3680f789c1a1301b07dadd3fbc78 \ + --hash=sha256:e30356d530528a42eeba51420ae8bf6c6c09559051887196599d96ee5f536468 \ + --hash=sha256:e8d5f8a8e3bc87334f025194c6193e408903d21ebaeb10952264943a985066ca \ + --hash=sha256:e8dfa9e94fc127c40979c3eacbae1e61fda4fe71d84869cc129e2721973231ef \ + --hash=sha256:f212d4f46b67ff604d11fff7cc62d36b3e8714edf68e44e9760e19be38c03eb0 \ + --hash=sha256:f7506387e191fe8cdb267f912469a3cccc538ab108471291636a96a54e599556 \ + --hash=sha256:fac6e277a41163d27dfab5f4ec1f7a83fac94e170665a4a50191b545721c6521 \ + --hash=sha256:fcd8f556cdc8cfe35e70efb92463082b7f43dd7e547eb071ffc36abc0ca4699b # via # -r build/requirements.in # -r build/test-requirements.txt @@ -55,52 +426,315 @@ numpy==1.26.4 # ml-dtypes # opt-einsum # scipy -opt-einsum==3.3.0 +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 # via -r build/requirements.in -packaging==24.0 +packaging==24.1 \ + --hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \ + --hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124 # via # build # matplotlib # pytest -pillow==10.3.0 +pillow==10.4.0 \ + --hash=sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885 \ + --hash=sha256:030abdbe43ee02e0de642aee345efa443740aa4d828bfe8e2eb11922ea6a21ea \ + --hash=sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df \ + --hash=sha256:0755ffd4a0c6f267cccbae2e9903d95477ca2f77c4fcf3a3a09570001856c8a5 \ + --hash=sha256:0a9ec697746f268507404647e531e92889890a087e03681a3606d9b920fbee3c \ + --hash=sha256:0ae24a547e8b711ccaaf99c9ae3cd975470e1a30caa80a6aaee9a2f19c05701d \ + --hash=sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd \ + --hash=sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06 \ + --hash=sha256:1b5dea9831a90e9d0721ec417a80d4cbd7022093ac38a568db2dd78363b00908 \ + --hash=sha256:1d846aea995ad352d4bdcc847535bd56e0fd88d36829d2c90be880ef1ee4668a \ + --hash=sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be \ + --hash=sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0 \ + --hash=sha256:298478fe4f77a4408895605f3482b6cc6222c018b2ce565c2b6b9c354ac3229b \ + --hash=sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80 \ + --hash=sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a \ + --hash=sha256:32cda9e3d601a52baccb2856b8ea1fc213c90b340c542dcef77140dfa3278a9e \ + --hash=sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9 \ + 
--hash=sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696 \ + --hash=sha256:43efea75eb06b95d1631cb784aa40156177bf9dd5b4b03ff38979e048258bc6b \ + --hash=sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309 \ + --hash=sha256:4d9667937cfa347525b319ae34375c37b9ee6b525440f3ef48542fcf66f2731e \ + --hash=sha256:5161eef006d335e46895297f642341111945e2c1c899eb406882a6c61a4357ab \ + --hash=sha256:543f3dc61c18dafb755773efc89aae60d06b6596a63914107f75459cf984164d \ + --hash=sha256:551d3fd6e9dc15e4c1eb6fc4ba2b39c0c7933fa113b220057a34f4bb3268a060 \ + --hash=sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d \ + --hash=sha256:5b001114dd152cfd6b23befeb28d7aee43553e2402c9f159807bf55f33af8a8d \ + --hash=sha256:5b4815f2e65b30f5fbae9dfffa8636d992d49705723fe86a3661806e069352d4 \ + --hash=sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3 \ + --hash=sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6 \ + --hash=sha256:6209bb41dc692ddfee4942517c19ee81b86c864b626dbfca272ec0f7cff5d9fb \ + --hash=sha256:673655af3eadf4df6b5457033f086e90299fdd7a47983a13827acf7459c15d94 \ + --hash=sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b \ + --hash=sha256:7086cc1d5eebb91ad24ded9f58bec6c688e9f0ed7eb3dbbf1e4800280a896496 \ + --hash=sha256:73664fe514b34c8f02452ffb73b7a92c6774e39a647087f83d67f010eb9a0cf0 \ + --hash=sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319 \ + --hash=sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b \ + --hash=sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856 \ + --hash=sha256:7970285ab628a3779aecc35823296a7869f889b8329c16ad5a71e4901a3dc4ef \ + --hash=sha256:7a8d4bade9952ea9a77d0c3e49cbd8b2890a399422258a77f357b9cc9be8d680 \ + --hash=sha256:7c1ee6f42250df403c5f103cbd2768a28fe1a0ea1f0f03fe151c8741e1469c8b \ + --hash=sha256:7dfecdbad5c301d7b5bde160150b4db4c659cee2b69589705b6f8a0c509d9f42 \ + --hash=sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e \ + --hash=sha256:866b6942a92f56300012f5fbac71f2d610312ee65e22f1aa2609e491284e5597 \ + --hash=sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a \ + --hash=sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8 \ + --hash=sha256:8bc1a764ed8c957a2e9cacf97c8b2b053b70307cf2996aafd70e91a082e70df3 \ + --hash=sha256:8d4d5063501b6dd4024b8ac2f04962d661222d120381272deea52e3fc52d3736 \ + --hash=sha256:8f0aef4ef59694b12cadee839e2ba6afeab89c0f39a3adc02ed51d109117b8da \ + --hash=sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126 \ + --hash=sha256:950be4d8ba92aca4b2bb0741285a46bfae3ca699ef913ec8416c1b78eadd64cd \ + --hash=sha256:961a7293b2457b405967af9c77dcaa43cc1a8cd50d23c532e62d48ab6cdd56f5 \ + --hash=sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b \ + --hash=sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026 \ + --hash=sha256:a02364621fe369e06200d4a16558e056fe2805d3468350df3aef21e00d26214b \ + --hash=sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc \ + --hash=sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46 \ + --hash=sha256:b15e02e9bb4c21e39876698abf233c8c579127986f8207200bc8a8f6bb27acf2 \ + --hash=sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c \ + --hash=sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe \ + 
--hash=sha256:bcd5e41a859bf2e84fdc42f4edb7d9aba0a13d29a2abadccafad99de3feff984 \ + --hash=sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a \ + --hash=sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70 \ + --hash=sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca \ + --hash=sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b \ + --hash=sha256:cbed61494057c0f83b83eb3a310f0bf774b09513307c434d4366ed64f4128a91 \ + --hash=sha256:cfdd747216947628af7b259d274771d84db2268ca062dd5faf373639d00113a3 \ + --hash=sha256:d7480af14364494365e89d6fddc510a13e5a2c3584cb19ef65415ca57252fb84 \ + --hash=sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1 \ + --hash=sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5 \ + --hash=sha256:dfe91cb65544a1321e631e696759491ae04a2ea11d36715eca01ce07284738be \ + --hash=sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f \ + --hash=sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc \ + --hash=sha256:e553cad5179a66ba15bb18b353a19020e73a7921296a7979c4a2b7f6a5cd57f9 \ + --hash=sha256:e88d5e6ad0d026fba7bdab8c3f225a69f063f116462c49892b0149e21b6c0a0e \ + --hash=sha256:ecd85a8d3e79cd7158dec1c9e5808e821feea088e2f69a974db5edf84dc53141 \ + --hash=sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef \ + --hash=sha256:f5f0c3e969c8f12dd2bb7e0b15d5c468b51e5017e01e2e867335c81903046a22 \ + --hash=sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27 \ + --hash=sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e \ + --hash=sha256:ff337c552345e95702c5fde3158acb0625111017d0e5f24bf3acdb9cc16b90d1 # via # -r build/test-requirements.txt # matplotlib -pluggy==1.4.0 +pluggy==1.5.0 \ + --hash=sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1 \ + --hash=sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669 # via pytest -portpicker==1.6.0 +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa # via -r build/test-requirements.txt -psutil==5.9.8 +psutil==6.0.0 \ + --hash=sha256:02b69001f44cc73c1c5279d02b30a817e339ceb258ad75997325e0e6169d8b35 \ + --hash=sha256:1287c2b95f1c0a364d23bc6f2ea2365a8d4d9b726a3be7294296ff7ba97c17f0 \ + --hash=sha256:1e7c870afcb7d91fdea2b37c24aeb08f98b6d67257a5cb0a8bc3ac68d0f1a68c \ + --hash=sha256:21f1fb635deccd510f69f485b87433460a603919b45e2a324ad65b0cc74f8fb1 \ + --hash=sha256:33ea5e1c975250a720b3a6609c490db40dae5d83a4eb315170c4fe0d8b1f34b3 \ + --hash=sha256:34859b8d8f423b86e4385ff3665d3f4d94be3cdf48221fbe476e883514fdb71c \ + --hash=sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd \ + --hash=sha256:6ec7588fb3ddaec7344a825afe298db83fe01bfaaab39155fa84cf1c0d6b13c3 \ + --hash=sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0 \ + --hash=sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2 \ + --hash=sha256:a021da3e881cd935e64a3d0a20983bda0bb4cf80e4f74fa9bfcb1bc5785360c6 \ + --hash=sha256:a495580d6bae27291324fe60cea0b5a7c23fa36a7cd35035a16d93bdcf076b9d \ + --hash=sha256:a9a3dbfb4de4f18174528d87cc352d1f788b7496991cca33c6996f40c9e3c92c \ + --hash=sha256:c588a7e9b1173b6e866756dde596fd4cad94f9399daf99ad8c3258b3cb2b47a0 \ + --hash=sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132 \ + 
--hash=sha256:fc8c9510cde0146432bbdb433322861ee8c3efbf8589865c8bf8d21cb30c4d14 \ + --hash=sha256:ffe7fc9b6b36beadc8c322f84e1caff51e8703b88eee1da46d1e3a6ae11b4fd0 # via portpicker -pygments==2.17.2 +pygments==2.18.0 \ + --hash=sha256:786ff802f32e91311bff3889f6e9a86e81505fe99f2735bb6d60ae0c5004f199 \ + --hash=sha256:b8e6aca0523f3ab76fee51799c488e38782ac06eafcf95e7ba832985c8e7b13a # via rich -pyparsing==3.1.2 +pyparsing==3.2.0b1 \ + --hash=sha256:51e00c907f7b2ac2d2c35c4d431e944c525ddcfd58b09517f308f40d70e0ddca \ + --hash=sha256:ecf0805530839936196a802cd6d6d65ffa9328eebdc8ee5b8f4b358be5f16666 # via matplotlib -pyproject-hooks==1.0.0 +pyproject-hooks==1.1.0 \ + --hash=sha256:4b37730834edbd6bd37f26ece6b44802fb1c1ee2ece0e54ddff8bfc06db86965 \ + --hash=sha256:7ceeefe9aec63a1064c18d939bdc3adf2d8aa1988a510afec15151578b232aa2 # via build -pytest==8.1.1 +pytest==8.3.3 \ + --hash=sha256:70b98107bd648308a7952b06e6ca9a50bc660be218d53c257cc1fc94fda10181 \ + --hash=sha256:a6853c7375b2663155079443d2e45de913a911a11d669df02a50814944db57b2 # via pytest-xdist -pytest-xdist==3.5.0 +pytest-xdist==3.6.1 \ + --hash=sha256:9ed4adfb68a016610848639bb7e02c9352d5d9f03d04809919e2dafc3be4cca7 \ + --hash=sha256:ead156a4db231eec769737f57668ef58a2084a34b2e55c4a8fa20d861107300d # via -r build/test-requirements.txt -python-dateutil==2.9.0.post0 +python-dateutil==2.9.0.post0 \ + --hash=sha256:37dd54208da7e1cd875388217d5e00ebd4179249f90fb72437e91a35459a0ad3 \ + --hash=sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427 # via matplotlib -rich==13.7.1 +rich==13.8.1 \ + --hash=sha256:1760a3c0848469b97b558fc61c85233e3dafb69c7a071b4d60c38099d3cd4c06 \ + --hash=sha256:8260cda28e3db6bf04d2d1ef4dbc03ba80a824c88b0e7668a0f23126a424844a # via -r build/test-requirements.txt -scipy==1.13.1 +scipy==1.14.1 \ + --hash=sha256:0c2f95de3b04e26f5f3ad5bb05e74ba7f68b837133a4492414b3afd79dfe540e \ + --hash=sha256:1729560c906963fc8389f6aac023739ff3983e727b1a4d87696b7bf108316a79 \ + --hash=sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37 \ + --hash=sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5 \ + --hash=sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675 \ + --hash=sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d \ + --hash=sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f \ + --hash=sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310 \ + --hash=sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617 \ + --hash=sha256:4079b90df244709e675cdc8b93bfd8a395d59af40b72e339c2287c91860deb8e \ + --hash=sha256:5149e3fd2d686e42144a093b206aef01932a0059c2a33ddfa67f5f035bdfe13e \ + --hash=sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417 \ + --hash=sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d \ + --hash=sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94 \ + --hash=sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad \ + --hash=sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8 \ + --hash=sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0 \ + --hash=sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69 \ + --hash=sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066 \ + --hash=sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3 \ + --hash=sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5 \ + 
--hash=sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07 \ + --hash=sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2 \ + --hash=sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389 \ + --hash=sha256:b99722ea48b7ea25e8e015e8341ae74624f72e5f21fc2abd45f3a93266de4c5d \ + --hash=sha256:baff393942b550823bfce952bb62270ee17504d02a1801d7fd0719534dfb9c84 \ + --hash=sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2 \ + --hash=sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3 \ + --hash=sha256:e0cf28db0f24a38b2a0ca33a85a54852586e43cf6fd876365c86e0657cfe7d73 \ + --hash=sha256:e4f5a7c49323533f9103d4dacf4e4f07078f360743dec7f7596949149efeec06 \ + --hash=sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc \ + --hash=sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1 \ + --hash=sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2 # via -r build/requirements.in -six==1.16.0 +six==1.16.0 \ + --hash=sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926 \ + --hash=sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254 # via python-dateutil -sortedcontainers==2.4.0 +sortedcontainers==2.4.0 \ + --hash=sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88 \ + --hash=sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0 # via hypothesis -typing-extensions==4.11.0 +typing-extensions==4.12.2 \ + --hash=sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d \ + --hash=sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8 # via etils -wheel==0.43.0 +wheel==0.44.0 \ + --hash=sha256:2376a90c98cc337d18623527a97c31797bd02bad0033d41547043a1cbfbe448f \ + --hash=sha256:a29c3f2817e95ab89aa4660681ad547c0e9547f20e75b0562fe7723c9a2a9d49 # via -r build/test-requirements.txt -zipp==3.18.1 +zipp==3.20.2 \ + --hash=sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350 \ + --hash=sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29 # via etils -zstandard==0.22.0 +zstandard==0.23.0 \ + --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ + --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ + --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ + --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ + --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ + --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ + --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ + --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ + --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ + --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ + --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ + --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ + --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ + --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ + --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ + --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ + 
--hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ + --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ + --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ + --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ + --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ + --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ + --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ + --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ + --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ + --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ + --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ + --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ + --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ + --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ + --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ + --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ + --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ + --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ + --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ + --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ + --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ + --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ + --hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ + --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ + --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ + --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ + --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ + --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ + --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ + --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ + --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ + --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ + --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ + --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ + --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ + --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ + --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ + --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ + --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ + --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ + --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ + --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ + 
--hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ + --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ + --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ + --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ + --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ + --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ + --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ + --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ + --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ + --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ + --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ + --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ + --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ + --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ + --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ + --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ + --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ + --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ + --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ + --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ + --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ + --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ + --hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ + --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ + --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ + --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ + --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ + --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ + --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ + --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ + --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ + --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ + --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ + --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ + --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ + --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ + --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ + --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ + --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 # via -r build/requirements.in # The following packages are considered to be unsafe in a requirements file: -setuptools==69.2.0 - # via -r build/test-requirements.txt +setuptools==70.3.0 \ + --hash=sha256:f171bab1dfbc86b132997f26a119f6056a57950d058587841a0082e8830f9dc5 \ + 
--hash=sha256:fe384da74336c398e0d956d1cae0669bc02eed936cdb1d49b57de1990dc11ffc + # via + # -r build/requirements.in + # -r build/test-requirements.txt diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 5fc0afa326af..0bcc89f493ce 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -1,20 +1,23 @@ ################################################################################ -ARG BASE_DOCKER=ubuntu:20.04 -FROM $BASE_DOCKER as rt_build +FROM ubuntu:20.04 AS rocm_base ################################################################################ +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 + # Add target file to help determine which device(s) to build for -ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} # Install ROCM ARG ROCM_VERSION=6.0.0 -ARG CUSTOM_INSTALL ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} ENV ROCM_PATH=${ROCM_PATH} -COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL} -COPY setup.rocm.sh /setup.rocm.sh -RUN /setup.rocm.sh $ROCM_VERSION +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM # Set up paths ENV HCC_HOME=$ROCM_PATH/hcc @@ -25,12 +28,61 @@ ENV PATH="$ROCM_PATH/bin:${PATH}" ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" +# install pyenv and python build dependencies +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y \ + git \ + libssl-dev \ + libffi-dev \ + libreadline-dev \ + liblzma-dev # Install pyenv with different python versions -ARG PYTHON_VERSION=3.10.0 +ARG PYTHON_VERSION=3.10.14 RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH RUN pyenv install $PYTHON_VERSION -RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip && pip install numpy setuptools build wheel six auditwheel scipy pytest pytest-html pytest_html_merger pytest-reportlog pytest-json-report pytest-csv pytest-rerunfailures cloudpickle portpicker matplotlib absl-py flatbuffers hypothesis +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + eval "$(pyenv init -)" && \ + pyenv local ${PYTHON_VERSION} && \ + pip3 install --upgrade --force-reinstall setuptools pip && \ + pip3 install \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +################################################################################ +FROM rocm_base AS rt_build +################################################################################ + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="$PYTHON_VERSION" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + 
--mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/build_rocm.sh b/build/rocm/build_rocm.sh index 6374a2a18929..111998d35608 100755 --- a/build/rocm/build_rocm.sh +++ b/build/rocm/build_rocm.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # Copyright 2022 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,57 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Environment Var Notes -# XLA_CLONE_DIR - -# Specifies filepath to where XLA repo is cloned. -# NOTE:, if this is set then XLA repo is not cloned. Must clone repo before running this script. -# Also, if this is set then setting XLA_REPO and XLA_BRANCH have no effect. -# XLA_REPO -# XLA repo to clone from. Default is https://github.com/ROCmSoftwarePlatform/tensorflow-upstream -# XLA_BRANCH -# XLA branch in the XLA repo. Default is develop-upstream-jax -# +# NOTE(mrodden): ROCm JAX build and installs have moved to wheel based builds and installs, +# but some CI scripts still try to run this script. Nothing needs to be done here, +# but we print some debugging information for logs. set -eux python -V -#If XLA_REPO is not set, then use default -if [ ! -v XLA_REPO ]; then - XLA_REPO="https://github.com/openxla/xla.git" - XLA_BRANCH="main" -elif [ -z "$XLA_REPO" ]; then - XLA_REPO="https://github.com/openxla/xla.git" - XLA_BRANCH="main" -fi - -#If XLA_CLONE_PATH is not set, then use default path. -#Note, setting XLA_CLONE_PATH makes setting XLA_REPO and XLA_BRANCH a no-op -#Set this when XLA repository has been already clone. This is useful in CI -#environments and when doing local development -if [ ! 
-v XLA_CLONE_DIR ]; then - XLA_CLONE_DIR=/tmp/xla - rm -rf /tmp/xla || true - git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla -elif [ -z "$XLA_CLONE_DIR" ]; then - XLA_CLONE_DIR=/tmp/xla - rm -rf /tmp/xla || true - git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla -fi - - -#Export JAX_ROCM_VERSION so that it is appened in the wheel name -export JAXLIB_RELEASE=1 -rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1) -export JAX_ROCM_VERSION=${rocm_version//./} - -#Build and install wheel -python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} - -JAX_RELEASE=1 python -m build -pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA) - -#This is for CI to read without having to start the container again -if [ -v CI_RUN ]; then - pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1 > jax_version_installed - cat /opt/rocm/.info/version | cut -d "-" -f 1 > jax_rocm_version -fi +printf "Detected jaxlib version: %s\n" $(pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1) +printf "Detected ROCm version: %s\n" $(cat /opt/rocm/.info/version | cut -d "-" -f 1) diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm new file mode 100644 index 000000000000..caf303d45ff3 --- /dev/null +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -0,0 +1,9 @@ +FROM quay.io/pypa/manylinux_2_28_x86_64 + +ARG ROCM_VERSION=6.1.1 +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM + +RUN --mount=type=cache,target=/var/cache/dnf \ + --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM diff --git a/build/rocm/ci_build b/build/rocm/ci_build new file mode 100755 index 000000000000..aeb0201e27ed --- /dev/null +++ b/build/rocm/ci_build @@ -0,0 +1,318 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs to be compatible with Python 3.6. 
Please do not include these +# in any "upgrade" scripts + + +import argparse +import os +import subprocess +import sys + + +def image_by_name(name): + cmd = ["docker", "images", "-q", "-f", "reference=%s" % name] + out = subprocess.check_output(cmd) + image_id = out.decode("utf8").strip().split("\n")[0] or None + return image_id + + +def dist_wheels( + rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="", + compiler="gcc" +): + if xla_path: + xla_path = os.path.abspath(xla_path) + + # create manylinux image with requested ROCm installed + image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "") + + cmd = [ + "docker", + "build", + "-f", + "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm", + "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, + "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, + "--tag=%s" % image, + ".", + ] + + if not image_by_name(image): + _ = subprocess.run(cmd, check=True) + + # use image to build JAX/jaxlib wheels + os.makedirs("wheelhouse", exist_ok=True) + + pyver_string = ",".join(python_versions) + + container_xla_path = "/xla" + + bw_cmd = [ + "python3", + "/jax/build/rocm/tools/build_wheels.py", + "--rocm-version", + rocm_version, + "--python-versions", + pyver_string, + "--compiler", + compiler, + ] + + if xla_path: + bw_cmd.extend(["--xla-path", container_xla_path]) + + bw_cmd.append("/jax") + + cmd = ["docker", "run"] + + mounts = [ + "-v", + "./:/jax", + "-v", + "./wheelhouse:/wheelhouse", + ] + + if xla_path: + mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)]) + + cmd.extend(mounts) + + if os.isatty(sys.stdout.fileno()): + cmd.append("-it") + + # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID + cmd.extend( + [ + "--init", + "--rm", + image, + "bash", + "-c", + " ".join(bw_cmd), + ] + ) + + _ = subprocess.run(cmd, check=True) + + +def _fetch_jax_metadata(xla_path): + cmd = ["git", "rev-parse", "HEAD"] + jax_commit = subprocess.check_output(cmd) + xla_commit = "" + + if xla_path: + try: + xla_commit = subprocess.check_output(cmd, cwd=xla_path) + except Exception as ex: + LOG.warning("Exception while retrieving xla_commit: %s" % ex) + + cmd = ["python3", "setup.py", "-V"] + env = dict(os.environ) + env["JAX_RELEASE"] = "1" + + jax_version = subprocess.check_output(cmd, env=env) + + return { + "jax_version": jax_version.decode("utf8").strip(), + "jax_commit": jax_commit.decode("utf8").strip(), + "xla_commit": xla_commit.decode("utf8").strip(), + } + + +def dist_docker( + rocm_version, + python_versions, + xla_path, + rocm_build_job="", + rocm_build_num="", + tag="rocm/jax-dev", + dockerfile=None, + keep_image=True, +): + if not dockerfile: + dockerfile = "build/rocm/Dockerfile.ms" + + python_version = python_versions[0] + + md = _fetch_jax_metadata(xla_path) + + cmd = [ + "docker", + "build", + "-f", + dockerfile, + "--target", + "rt_build", + "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, + "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, + "--build-arg=PYTHON_VERSION=%s" % python_version, + "--build-arg=JAX_VERSION=%(jax_version)s" % md, + "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, + "--build-arg=XLA_COMMIT=%(xla_commit)s" % md, + "--tag=%s" % tag, + ] + + if not keep_image: + cmd.append("--rm") + + # context dir + cmd.append(".") + + subprocess.check_call(cmd) + + +def test(image_name): + """Run unit tests like CI would inside a JAX image.""" + + gpu_args 
= [ + "--device=/dev/kfd", + "--device=/dev/dri", + "--group-add", + "video", + "--cap-add=SYS_PTRACE", + "--security-opt", + "seccomp=unconfined", + "--shm-size", + "16G", + ] + + cmd = [ + "docker", + "run", + "-it", + "--rm", + ] + + # NOTE(mrodden): we need jax source dir for the unit test code only, + # JAX and jaxlib are already installed from wheels + mounts = [ + "-v", + "./:/jax", + ] + + cmd.extend(mounts) + cmd.extend(gpu_args) + + container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh" + cmd.append(image_name) + cmd.extend( + [ + "bash", + "-c", + container_cmd, + ] + ) + + subprocess.check_call(cmd) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + "--python-versions", + type=lambda x: x.split(","), + default="3.12", + help="Comma separated list of CPython versions to build wheels for", + ) + + p.add_argument( + "--rocm-version", + default="6.1.1", + help="ROCm version used for building wheels, testing, and installing into Docker image", + ) + + p.add_argument( + "--rocm-build-job", + default="", + help="ROCm build job for development ROCm builds", + ) + + p.add_argument( + "--rocm-build-num", + default="", + help="ROCm build number for development ROCm builds", + ) + + p.add_argument( + "--xla-source-dir", + help="Path to XLA source to use during jaxlib build, instead of builtin XLA", + ) + + p.add_argument( + "--compiler", + choices=["gcc", "clang"], + help="Compiler backend to use when compiling jax/jaxlib" + ) + + subp = p.add_subparsers(dest="action", required=True) + + dwp = subp.add_parser("dist_wheels") + + testp = subp.add_parser("test") + testp.add_argument("image_name") + + ddp = subp.add_parser("dist_docker") + ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") + ddp.add_argument("--keep-image", action="store_true") + ddp.add_argument("--image-tag", default="rocm/jax-dev") + + return p.parse_args() + + +def main(): + args = parse_args() + + if args.action == "dist_wheels": + dist_wheels( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + args.rocm_build_job, + args.rocm_build_num, + ) + + elif args.action == "test": + test(args.image_name) + + elif args.action == "dist_docker": + dist_wheels( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + args.rocm_build_job, + args.rocm_build_num, + args.compiler, + ) + dist_docker( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + rocm_build_job=args.rocm_build_job, + rocm_build_num=args.rocm_build_num, + tag=args.image_tag, + dockerfile=args.dockerfile, + keep_image=args.keep_image, + ) + + +if __name__ == "__main__": + main() diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 9084651bed4c..302a0449b19e 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # Copyright 2022 The JAX Authors. 
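The new ci_build entry point above drives everything through Docker: dist_wheels builds manylinux wheels into ./wheelhouse, test runs the CI suite inside an existing JAX image, and dist_docker chains a wheel build with the runtime image build. A usage sketch inferred from its argparse options, run from the JAX source tree root; the ROCm version, Python versions, and image tag below are illustrative placeholders, not values mandated by the script.

# Build manylinux wheels for two CPython versions against a released ROCm:
./build/rocm/ci_build --rocm-version 6.1.3 --python-versions 3.10,3.12 dist_wheels

# Run the unit tests inside a previously built JAX image:
./build/rocm/ci_build test rocm/jax-dev

# Build wheels, then the distributable runtime image, in one step:
./build/rocm/ci_build --rocm-version 6.1.3 --python-versions 3.12 --compiler gcc \
    dist_docker --dockerfile build/rocm/Dockerfile.ms --image-tag rocm/jax-dev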
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,12 +29,8 @@ # # ROCM_VERSION: ROCm repo version # -# ROCM_PATH: ROCM path in the docker container -# # Environment variables read by this script # WORKSPACE -# XLA_REPO -# XLA_BRANCH # XLA_CLONE_DIR # BUILD_TAG # @@ -44,75 +41,78 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/build_common.sh" CONTAINER_TYPE="rocm" -DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" +DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" KEEP_IMAGE="--rm" -KEEP_CONTAINER="--rm" -PYTHON_VERSION="3.10.0" -ROCM_VERSION="6.0.0" #Point to latest release +PYTHON_VERSION="3.10" +ROCM_VERSION="6.1.3" +ROCM_BUILD_JOB="" +ROCM_BUILD_NUM="" BASE_DOCKER="ubuntu:20.04" CUSTOM_INSTALL="" -#BASE_DOCKER="compute-artifactory.amd.com:5000/rocm-plus-docker/compute-rocm-rel-6.0:91-ubuntu-20.04-stg2" -#CUSTOM_INSTALL="custom_install_dummy.sh" -#ROCM_PATH="/opt/rocm-5.6.0" +JAX_USE_CLANG="" POSITIONAL_ARGS=() RUNTIME_FLAG=1 while [[ $# -gt 0 ]]; do - case $1 in - --py_version) - PYTHON_VERSION="$2" - shift 2 - ;; - --dockerfile) - DOCKERFILE_PATH="$2" - DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") - shift 2 - ;; - --keep_image) - KEEP_IMAGE="" - shift 1 - ;; - --runtime) - RUNTIME_FLAG=1 - shift 1 - ;; - --keep_container) - KEEP_CONTAINER="" - shift 1 - ;; - --rocm_version) - ROCM_VERSION="$2" - shift 2 - ;; - #--rocm_path) - # ROCM_PATH="$2" - # shift 2 - # ;; - - *) - POSITIONAL_ARGS+=("$1") - shift - ;; - esac + case $1 in + --py_version) + PYTHON_VERSION="$2" + shift 2 + ;; + --dockerfile) + DOCKERFILE_PATH="$2" + DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") + shift 2 + ;; + --keep_image) + KEEP_IMAGE="" + shift 1 + ;; + --runtime) + RUNTIME_FLAG=1 + shift 1 + ;; + --keep_container) + KEEP_CONTAINER="" + shift 1 + ;; + --rocm_version) + ROCM_VERSION="$2" + shift 2 + ;; + --rocm_job) + ROCM_BUILD_JOB="$2" + shift 2 + ;; + --rocm_build) + ROCM_BUILD_NUM="$2" + shift 2 + ;; + --use_clang) + JAX_USE_CLANG="$2" + shift 2 + ;; + *) + POSITIONAL_ARGS+=("$1") + shift + ;; + esac done if [[ ! -f "${DOCKERFILE_PATH}" ]]; then - die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" + die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" fi -ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video \ - --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G" - # Helper function to traverse directories up until given file is found. function upsearch (){ - test / == "$PWD" && return || \ - test -e "$1" && echo "$PWD" && return || \ - cd .. && upsearch "$1" + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" } -# Set up WORKSPACE. +# Set up WORKSPACE. WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" BUILD_TAG="${BUILD_TAG:-jax}" @@ -126,6 +126,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g') # Convert to all lower-case, as per requirement of Docker image names DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]') + # Print arguments. echo "WORKSPACE: ${WORKSPACE}" echo "COMMAND: ${POSITIONAL_ARGS[*]}" @@ -135,55 +136,34 @@ echo "" echo "Building container (${DOCKER_IMG_NAME})..." echo "Python Version (${PYTHON_VERSION})" -if [[ "${RUNTIME_FLAG}" -eq 1 ]]; then - echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." 
- docker build --target rt_build --tag ${DOCKER_IMG_NAME} \ - --build-arg PYTHON_VERSION=$PYTHON_VERSION --build-arg ROCM_VERSION=$ROCM_VERSION \ - --build-arg CUSTOM_INSTALL=$CUSTOM_INSTALL \ - --build-arg BASE_DOCKER=$BASE_DOCKER \ - -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" -else - echo "Building (CI) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." - docker build --target ci_build --tag ${DOCKER_IMG_NAME} \ - --build-arg PYTHON_VERSION=$PYTHON_VERSION \ - --build-arg BASE_DOCKER=$BASE_DOCKER \ - -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" -fi +echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." -# Check docker build status -if [[ $? != "0" ]]; then - die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" -fi - -# Run the command inside the container. -echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..." - -export XLA_REPO="${XLA_REPO:-}" -export XLA_BRANCH="${XLA_BRANCH:-}" export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}" -export JAX_RENAME_WHL="${XLA_CLONE_DIR:-}" -if [ ! -z ${XLA_CLONE_DIR} ]; then - ROCM_EXTRA_PARAMS=${ROCM_EXTRA_PARAMS}" -v ${XLA_CLONE_DIR}:${XLA_CLONE_DIR}" +# default to gcc +JAX_COMPILER="gcc" +if [ -n "$JAX_USE_CLANG" ]; then + JAX_COMPILER="clang" fi -docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \ - -v ${WORKSPACE}:/workspace \ - -w /workspace \ - -e XLA_REPO=${XLA_REPO} \ - -e XLA_BRANCH=${XLA_BRANCH} \ - -e XLA_CLONE_DIR=${XLA_CLONE_DIR} \ - -e PYTHON_VERSION=$PYTHON_VERSION \ - -e CI_RUN=1 \ - ${ROCM_EXTRA_PARAMS} \ - "${DOCKER_IMG_NAME}" \ - ${POSITIONAL_ARGS[@]} - -if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then - echo "Committing the docker container as ${DOCKER_IMG_NAME}" - docker stop ${DOCKER_IMG_NAME} - docker commit ${DOCKER_IMG_NAME} ${DOCKER_IMG_NAME} - docker rm ${DOCKER_IMG_NAME} # remove this temp container +# ci_build.sh is mostly a compatibility wrapper for ci_build + +# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image", +# which is the ROCm image that is shipped for users to use (i.e. distributable). +./build/rocm/ci_build \ + --rocm-version $ROCM_VERSION \ + --python-versions $PYTHON_VERSION \ + --xla-source-dir=$XLA_CLONE_DIR \ + --rocm-build-job=$ROCM_BUILD_JOB \ + --rocm-build-num=$ROCM_BUILD_NUM \ + --compiler=$JAX_COMPILER \ + dist_docker \ + --dockerfile $DOCKERFILE_PATH \ + --image-tag $DOCKER_IMG_NAME + +# Check build status +if [[ $? != "0" ]]; then + die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" fi echo "Jax-ROCm build was successful!" diff --git a/build/rocm/dev_build_rocm.py b/build/rocm/dev_build_rocm.py new file mode 100755 index 000000000000..2be64152f667 --- /dev/null +++ b/build/rocm/dev_build_rocm.py @@ -0,0 +1,165 @@ +# !/usr/bin/env python3 +# +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
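For the reworked ci_build.sh wrapper shown above, which now simply forwards to ci_build and always runs the dist_docker flow, a sketch of the expected invocations; the ROCm versions, job name, and build number are placeholders for whatever the CI environment provides.

# Release ROCm build with the default gcc toolchain:
./build/rocm/ci_build.sh --py_version 3.10 --rocm_version 6.1.3

# Development ROCm build from an internal CI job, compiled with clang
# (any non-empty --use_clang value selects clang):
./build/rocm/ci_build.sh --py_version 3.12 --rocm_version 6.2.0 \
    --rocm_job my-rocm-ci-job --rocm_build 12345 --use_clang true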
+ +# NOTE(ruturaj4): This script automates the build process for JAX and XLA on ROCm, +# allowing for optional uninstallation of existing packages, and custom paths for ROCm and XLA repositories. + +import argparse +import os +import shutil +import subprocess +import sys + + +def get_rocm_version(): + try: + version = subprocess.check_output( + "cat /opt/rocm/.info/version | cut -d '-' -f 1", shell=True + ) + return version.decode("utf-8").strip() + except subprocess.CalledProcessError as e: + print(f"Error fetching ROCm version: {e}") + return None + + +def get_rocm_target(): + try: + target_info = subprocess.check_output( + "rocminfo | grep gfx | head -n 1", shell=True + ) + target = target_info.decode("utf-8").split()[1] + return target + except subprocess.CalledProcessError as e: + print(f"Error fetching ROCm target: {e}") + return None + + +def uninstall_existing_packages(packages): + cmd = ["python3", "-m", "pip", "uninstall", "-y"] + cmd.extend(packages) + + try: + subprocess.run(cmd, check=True) + print(f"Successfully uninstalled {packages}") + except subprocess.CalledProcessError as e: + print(f"Failed to uninstall {packages}: {e}") + + +def clean_dist_directory(): + try: + shutil.rmtree("dist") + print("Cleaned dist directory.") + except FileNotFoundError: + print("dist directory not found, skipping cleanup.") + except Exception as e: + print(f"Failed to clean dist directory: {e}") + sys.exit(1) + + +def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path): + bazel_options = ( + f"--bazel_options=--override_repository=xla={xla_path}" if xla_path else "" + ) + clang_option = f"--clang_path={clang_path}" if clang_path else "" + build_command = [ + "python3", + "./build/build.py", + "--enable_rocm", + "--build_gpu_plugin", + "--gpu_plugin_rocm_version=60", + f"--use_clang={str(use_clang).lower()}", + f"--rocm_amdgpu_targets={rocm_target}", + f"--rocm_path=/opt/rocm-{rocm_version}/", + bazel_options, + ] + + if clang_option: + build_command.append(clang_option) + + print("Executing build command:") + print(" ".join(build_command)) + + try: + subprocess.run(build_command, check=True) + print("Build completed successfully.") + except subprocess.CalledProcessError as e: + print(f"Build failed: {e}") + sys.exit(1) + + +def install_wheel(): + try: + subprocess.run( + ["python3", "-m", "pip", "install", "dist/*.whl"], check=True, shell=True + ) + print("Packages installed successfully.") + except subprocess.CalledProcessError as e: + print(f"Failed to install packages: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser(description="Script to build JAX and XLA on ROCm.") + parser.add_argument( + "--clang-path", type=str, default="", help="Specify the Clang compiler path" + ) + parser.add_argument( + "--skip-uninstall", + action="store_true", + help="Skip uninstall of old versions during package install", + ) + parser.add_argument( + "--use-clang", default="false", help="Use Clang compiler if set" + ) + parser.add_argument( + "--xla-path", type=str, default="", help="Specify the XLA repository path" + ) + + args = parser.parse_args() + + if args.xla_path: + args.xla_path = os.path.abspath(args.xla_path) + print(f"Converted XLA path to absolute: {args.xla_path}") + + rocm_version = get_rocm_version() + if not rocm_version: + print("Could not determine ROCm version. Exiting.") + sys.exit(1) + + rocm_target = get_rocm_target() + if not rocm_target: + print("Could not determine ROCm target. 
Exiting.") + sys.exit(1) + + if not args.skip_uninstall: + print("Uninstalling existing packages...") + packages = ["jax", "jaxlib", "jax-rocm60-pjrt", "jax-rocm60-plugin"] + uninstall_existing_packages(packages) + + clean_dist_directory() + + print( + f"Building JAX and XLA with ROCm version: {rocm_version}, Target: {rocm_target}" + ) + build_jax_xla( + args.xla_path, rocm_version, rocm_target, args.use_clang, args.clang_path + ) + + install_wheel() + + +if __name__ == "__main__": + main() diff --git a/build/rocm/docker/Dockerfile.jax-ubu22 b/build/rocm/docker/Dockerfile.jax-ubu22 new file mode 100644 index 000000000000..ba64efbbc682 --- /dev/null +++ b/build/rocm/docker/Dockerfile.jax-ubu22 @@ -0,0 +1,64 @@ +FROM ubuntu:22.04 + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 + +# Add target file to help determine which device(s) to build for +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} + +# Install ROCM +ARG ROCM_VERSION=6.0.0 +ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} +ENV ROCM_PATH=${ROCM_PATH} +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +# Set up paths +ENV HCC_HOME=$ROCM_PATH/hcc +ENV HIP_PATH=$ROCM_PATH/ +ENV OPENCL_ROOT=$ROCM_PATH/opencl +ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +ENV PATH="$ROCM_PATH/bin:${PATH}" +ENV PATH="$OPENCL_ROOT/bin:${PATH}" +ENV PATH="/root/bin:/root/.local/bin:$PATH" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + pip3 install --upgrade --force-reinstall setuptools pip && \ + pip3 install \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="3.10" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Dockerfile.jax-ubu24 b/build/rocm/docker/Dockerfile.jax-ubu24 new file mode 100644 index 000000000000..44c59b1b7e6b --- /dev/null +++ b/build/rocm/docker/Dockerfile.jax-ubu24 @@ -0,0 +1,63 @@ +FROM ubuntu:24.04 + +RUN --mount=type=cache,target=/var/cache/apt \ + apt-get update && apt-get install -y python3 python-is-python3 python3-pip + +# Add target file to help determine which device(s) to build for +ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" +ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} + +# Install ROCM +ARG ROCM_VERSION=6.2.0 +ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} +ENV ROCM_PATH=${ROCM_PATH} +ARG ROCM_BUILD_JOB +ARG ROCM_BUILD_NUM +RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + --mount=type=cache,target=/var/cache/apt \ + python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM + +# Set up paths 
+ENV HCC_HOME=$ROCM_PATH/hcc +ENV HIP_PATH=$ROCM_PATH/ +ENV OPENCL_ROOT=$ROCM_PATH/opencl +ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +ENV PATH="$ROCM_PATH/bin:${PATH}" +ENV PATH="$OPENCL_ROOT/bin:${PATH}" +ENV PATH="/root/bin:/root/.local/bin:$PATH" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + pip3 install --break-system-packages \ + "numpy<2" \ + build \ + wheel \ + six \ + auditwheel \ + scipy \ + pytest \ + pytest-html \ + pytest_html_merger \ + pytest-reportlog \ + pytest-rerunfailures \ + pytest-json-report \ + pytest-csv \ + cloudpickle \ + portpicker \ + matplotlib \ + absl-py \ + flatbuffers \ + hypothesis + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="3.12" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \ + --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/docker/Makefile b/build/rocm/docker/Makefile new file mode 100644 index 000000000000..7fb38a936a64 --- /dev/null +++ b/build/rocm/docker/Makefile @@ -0,0 +1,20 @@ +.PHONY: all clean + +all: .docker-jax-ubu22 .docker-jax-ubu24 + +clean: clean-jax-ubu22 clean-jax-ubu24 + +ROCM_VERSION = 6.2.0 + +.docker-% : build/rocm/docker/Dockerfile.% + docker build -f $< --tag $(*F) --progress plain \ + --build-arg=ROCM_VERSION=${ROCM_VERSION} \ + --build-arg=JAX_VERSION=$(shell python setup.py -V) \ + --build-arg=JAX_COMMIT=$(shell git rev-parse HEAD) \ + . + @touch $@ + + +clean-%: + -docker rmi $(*F) + @rm -f .docker-$(*F) diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py index 4e7660ca1f15..e1fa26c72872 100755 --- a/build/rocm/run_single_gpu.py +++ b/build/rocm/run_single_gpu.py @@ -169,7 +169,7 @@ def run_parallel(all_testmodules, p, c): def find_num_gpus(): - cmd = ["lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"] + cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"] _, _, stdout = run_shell_command(cmd, shell=True) return int(stdout) diff --git a/build/rocm/setup.rocm.sh b/build/rocm/setup.rocm.sh index 1ade67b17f6e..35c8f4c5166c 100755 --- a/build/rocm/setup.rocm.sh +++ b/build/rocm/setup.rocm.sh @@ -25,7 +25,7 @@ ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/ if [ ! -f "/${CUSTOM_INSTALL}" ]; then # Add rocm repository chmod 1777 /tmp - DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update + DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common DEBIAN_FRONTEND=noninteractive apt-get clean all wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -; diff --git a/build/rocm/tools/blacken.sh b/build/rocm/tools/blacken.sh new file mode 100644 index 000000000000..7b61cbdb9e10 --- /dev/null +++ b/build/rocm/tools/blacken.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +black -t py36 build/rocm/ci_build build/rocm/tools/*.py diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py new file mode 100644 index 000000000000..b6dd1256e2f5 --- /dev/null +++ b/build/rocm/tools/build_wheels.py @@ -0,0 +1,261 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +from collections import deque +import fcntl +import logging +import os +import re +import select +import subprocess +import shutil +import sys + + +LOG = logging.getLogger(__name__) + + +GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" + + +def build_rocm_path(rocm_version_str): + path = "/opt/rocm-%s" % rocm_version_str + if os.path.exists(path): + return path + else: + return os.path.realpath("/opt/rocm") + + +def update_rocm_targets(rocm_path, targets): + target_fp = os.path.join(rocm_path, "bin/target.lst") + version_fp = os.path.join(rocm_path, ".info/version") + with open(target_fp, "w") as fd: + fd.write("%s\n" % targets) + + # mimic touch + open(version_fp, "a").close() + + +def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"): + use_clang = "true" if compiler == "clang" else "false" + cmd = [ + "python", + "build/build.py", + "--enable_rocm", + "--build_gpu_plugin", + "--gpu_plugin_rocm_version=60", + "--rocm_path=%s" % rocm_path, + "--use_clang=%s" % use_clang, + ] + + if xla_path: + cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path) + + cpy = to_cpy_ver(python_version) + py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) + + env = dict(os.environ) + env["JAX_RELEASE"] = str(1) + env["JAXLIB_RELEASE"] = str(1) + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) + pattern = re.compile("Output wheel: (.+)\n") + + _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr") + + +def build_jax_wheel(jax_path, python_version): + cmd = [ + "python", + "-m", + "build", + ] + + cpy = to_cpy_ver(python_version) + py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) + + env = dict(os.environ) + env["JAX_RELEASE"] = str(1) + env["JAXLIB_RELEASE"] = str(1) + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) + pattern = re.compile(r"Successfully built jax-.+ and (jax-.+\.whl)\n") + + _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") + + +def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None): + + buf = deque(maxlen=20000) + + if capture == "stderr": + p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE) + redir = sys.stderr + cap_fd = p.stderr + else: + p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE) + redir = sys.stdout + cap_fd = p.stdout + + flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL) + fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + eof = False + while not eof: + r, _, _ = select.select([cap_fd], [], []) + for fd in r: + dat = fd.read(512) + if dat is None: + continue + elif dat: + t = dat.decode("utf8") + redir.write(t) + buf.extend(t) + else: + eof = True + + # wait 
and drain pipes + _, _ = p.communicate() + + if p.returncode != 0: + raise Exception( + "Child process exited with nonzero result: rc=%d" % p.returncode + ) + + text = "".join(buf) + + matches = pattern.findall(text) + + if not matches: + LOG.error("No wheel name found in output: %r" % text) + raise Exception("No wheel name found in output") + + wheels = [] + for match in matches: + LOG.info("Found built wheel: %r" % match) + wheels.append(match) + + return wheels + + +def to_cpy_ver(python_version): + tup = python_version.split(".") + return "cp%d%d" % (int(tup[0]), int(tup[1])) + + +def fix_wheel(path, jax_path): + # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 + # so use one of the CPythons in /opt to run + env = dict(os.environ) + py_bin = "/opt/python/cp310-cp310/bin" + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + cmd = ["pip", "install", "auditwheel>=6"] + subprocess.run(cmd, check=True, env=env) + + fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") + cmd = ["python", fixwheel_path, path] + subprocess.run(cmd, check=True, env=env) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + "--rocm-version", default="6.1.1", help="ROCM Version to build JAX against" + ) + p.add_argument( + "--python-versions", + default=["3.10.19,3.12"], + help="Comma separated CPython versions that wheels will be built and output for", + ) + p.add_argument( + "--xla-path", + type=str, + default=None, + help="Optional directory where XLA source is located to use instead of JAX builtin XLA", + ) + p.add_argument( + "--compiler", + type=str, + default="gcc", + help="Compiler backend to use when compiling jax/jaxlib", + ) + + p.add_argument("jax_path", help="Directory where JAX source directory is located") + + return p.parse_args() + + +def find_wheels(path): + wheels = [] + + for f in os.listdir(path): + if f.endswith(".whl"): + wheels.append(os.path.join(path, f)) + + LOG.info("Found wheels: %r" % wheels) + return wheels + + +def main(): + args = parse_args() + python_versions = args.python_versions.split(",") + + print("ROCM_VERSION=%s" % args.rocm_version) + print("PYTHON_VERSIONS=%r" % python_versions) + print("JAX_PATH=%s" % args.jax_path) + print("XLA_PATH=%s" % args.xla_path) + + rocm_path = build_rocm_path(args.rocm_version) + + update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) + + for py in python_versions: + build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler) + wheel_paths = find_wheels(os.path.join(args.jax_path, "dist")) + for wheel_path in wheel_paths: + # skip jax wheel since it is non-platform + if not os.path.basename(wheel_path).startswith("jax-"): + fix_wheel(wheel_path, args.jax_path) + + # build JAX wheel for completeness + build_jax_wheel(args.jax_path, python_versions[-1]) + wheels = find_wheels(os.path.join(args.jax_path, "dist")) + + # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will + # do nothing, and in fact will throw an Exception. 
we just need to copy it + # along with the jaxlib and plugin ones + + # copy jax wheel(s) to wheelhouse + wheelhouse_dir = "/wheelhouse/" + for whl in wheels: + if os.path.basename(whl).startswith("jax-"): + LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) + shutil.copy(whl, wheelhouse_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py new file mode 100644 index 000000000000..ea77162728d5 --- /dev/null +++ b/build/rocm/tools/fixwheel.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import logging +import os +from pprint import pprint +import subprocess + +from auditwheel.lddtree import lddtree +from auditwheel.wheeltools import InWheelCtx +from auditwheel.elfutils import elf_file_filter +from auditwheel.policy import WheelPolicies +from auditwheel.wheel_abi import analyze_wheel_abi + + +LOG = logging.getLogger(__name__) + + +def tree(path): + + with InWheelCtx(path) as ctx: + for sofile, fd in elf_file_filter(ctx.iter_files()): + + LOG.info("found SO file: %s" % sofile) + elftree = lddtree(sofile) + + print(elftree) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("wheel_path") + return p.parse_args() + + +def parse_wheel_name(path): + wheel_name = os.path.basename(path) + return wheel_name[:-4].split("-") + + +def fix_wheel(path): + tup = parse_wheel_name(path) + plat_tag = tup[4] + if "manylinux2014" in plat_tag: + # strip any manylinux tags from the current wheel first + from wheel.cli import tags + + plat_mod_str = "linux_x86_64" + new_wheel = tags.tags( + path, + python_tags=None, + abi_tags=None, + platform_tags=plat_mod_str, + build_tag=None, + ) + new_path = os.path.join(os.path.dirname(path), new_wheel) + LOG.info("Stripped broken tags and created new wheel at %r" % new_path) + path = new_path + + # build excludes, using auditwheels lddtree to find them + wheel_pol = WheelPolicies() + exclude = frozenset() + abi = analyze_wheel_abi(wheel_pol, path, exclude) + + plat = "manylinux_2_28_x86_64" + ext_libs = abi.external_refs.get(plat, {}).get("libs") + exclude = list(ext_libs.keys()) + + # call auditwheel repair with excludes + cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + + for ex in exclude: + cmd.append("--exclude") + cmd.append(ex) + + cmd.append(path) + + LOG.info("running %r" % cmd) + + rc = subprocess.run(cmd, check=True) + + +def main(): + args = parse_args() + path = args.wheel_path + fix_wheel(path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py new file mode 100644 index 000000000000..5334bf40ece7 --- /dev/null +++ b/build/rocm/tools/get_rocm.py @@ 
-0,0 +1,359 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import json +import logging +import os +import sys +import subprocess +import urllib.request + + +LOG = logging.getLogger(__name__) + + +def latest_rocm(): + dat = urllib.request.urlopen( + "https://api.github.com/repos/rocm/rocm/releases/latest" + ).read() + rd = json.loads(dat) + _, ver_str = rd["tag_name"].split("-") + return ver_str + + +def os_release_meta(): + try: + os_rel = open("/etc/os-release").read() + + kvs = {} + for line in os_rel.split("\n"): + if line.strip(): + k, v = line.strip().split("=", 1) + v = v.strip('"') + kvs[k] = v + + return kvs + except OSError: + pass + + +class System(object): + + def __init__(self, pkgbin, rocm_package_list): + self.pkgbin = pkgbin + self.rocm_package_list = rocm_package_list + + def install_packages(self, package_specs): + cmd = [ + self.pkgbin, + "install", + "-y", + ] + cmd.extend(package_specs) + + env = dict(os.environ) + if self.pkgbin == "apt": + env["DEBIAN_FRONTEND"] = "noninteractive" + + LOG.info("Running %r" % cmd) + subprocess.check_call(cmd, env=env) + + def install_rocm(self): + self.install_packages(self.rocm_package_list) + + +UBUNTU = System( + pkgbin="apt", + rocm_package_list=[ + "rocm-dev", + "rocm-libs", + ], +) + + +RHEL8 = System( + pkgbin="dnf", + rocm_package_list=[ + "libdrm-amdgpu", + "rocm-dev", + "rocm-ml-sdk", + "miopen-hip ", + "miopen-hip-devel", + "rocblas", + "rocblas-devel", + "rocsolver-devel", + "rocrand-devel", + "rocfft-devel", + "hipfft-devel", + "hipblas-devel", + "rocprim-devel", + "hipcub-devel", + "rccl-devel", + "hipsparse-devel", + "hipsolver-devel", + ], +) + + +def parse_version(version_str): + if isinstance(version_str, str): + parts = version_str.split(".") + rv = type("Version", (), {})() + rv.major = int(parts[0].strip()) + rv.minor = int(parts[1].strip()) + rv.rev = None + + if len(parts) > 2: + rv.rev = int(parts[2].strip()) + + else: + rv = version_str + + return rv + + +def get_system(): + md = os_release_meta() + + if md["ID"] == "ubuntu": + return UBUNTU + + if md["ID"] in ["almalinux", "rhel", "fedora", "centos"]: + if md["PLATFORM_ID"] == "platform:el8": + return RHEL8 + + raise Exception("No system for %r" % md) + + +def _setup_internal_repo(system, rocm_version, job_name, build_num): + # wget is required by amdgpu-repo + system.install_packages(["wget"]) + + install_amdgpu_installer_internal(rocm_version) + + amdgpu_build = ( + urllib.request.urlopen( + "http://rocm-ci.amd.com/job/%s/%s/artifact/amdgpu_kernel_info.txt" + % (job_name, build_num) + ) + .read() + .decode("utf8") + .strip() + ) + + cmd = [ + "amdgpu-repo", + "--amdgpu-build=%s" % amdgpu_build, + "--rocm-build=%s/%s" % (job_name, build_num), + ] + LOG.info("Running %r" % cmd) + subprocess.check_call(cmd) + + cmd 
= [ + "amdgpu-install", + "--no-dkms", + "--usecase=rocm", + "-y", + ] + + env = dict(os.environ) + if system.pkgbin == "apt": + env["DEBIAN_FRONTEND"] = "noninteractive" + + LOG.info("Running %r" % cmd) + subprocess.check_call(cmd, env=env) + + +def install_rocm(rocm_version, job_name=None, build_num=None): + s = get_system() + + if job_name and build_num: + _setup_internal_repo(s, rocm_version, job_name, build_num) + else: + if s == RHEL8: + setup_repos_el8(rocm_version) + elif s == UBUNTU: + setup_repos_ubuntu(rocm_version) + else: + raise Exception("Platform not supported") + + s.install_rocm() + + +def install_amdgpu_installer_internal(rocm_version): + """ + Download and install the "amdgpu-installer" package from internal builds + on the current system. + """ + md = os_release_meta() + url, fn = _build_installer_url(rocm_version, md) + + try: + # download installer + LOG.info("Downloading from %s", url) + urllib.request.urlretrieve(url, filename=fn) + + system = get_system() + + cmd = [system.pkgbin, "install", "-y", "./%s" % fn] + subprocess.check_call(cmd) + finally: + try: + os.remove(fn) + except FileNotFoundError: + pass + + +def _build_installer_url(rocm_version, metadata): + md = metadata + + rv = parse_version(rocm_version) + + base_url = "http://artifactory-cdn.amd.com/artifactory/list" + + if md["ID"] == "ubuntu": + fmt = "amdgpu-install-internal_%(rocm_major)s.%(rocm_minor)s-%(os_version)s-1_all.deb" + package_name = fmt % { + "rocm_major": rv.major, + "rocm_minor": rv.minor, + "os_version": md["VERSION_ID"], + } + + url = "%s/amdgpu-deb/%s" % (base_url, package_name) + elif md.get("PLATFORM_ID") == "platform:el8": + fmt = "amdgpu-install-internal-%(rocm_major)s.%(rocm_minor)s_%(os_version)s-1.noarch.rpm" + package_name = fmt % { + "rocm_major": rv.major, + "rocm_minor": rv.minor, + "os_version": "8", + } + + url = "%s/amdgpu-rpm/rhel/%s" % (base_url, package_name) + else: + raise Exception("Platform not supported: %r" % md) + + return url, package_name + + +APT_RADEON_PIN_CONTENT = """ +Package: * +Pin: release o=repo.radeon.com +Pin-Priority: 600 +""" + + +def setup_repos_ubuntu(rocm_version_str): + + rv = parse_version(rocm_version_str) + + # if X.Y.0 -> repo url version should be X.Y + if rv.rev == 0: + rocm_version_str = "%d.%d" % (rv.major, rv.minor) + + s = get_system() + s.install_packages(["wget", "sudo", "gnupg"]) + + md = os_release_meta() + codename = md["VERSION_CODENAME"] + + keyadd = "wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | sudo apt-key add -" + subprocess.check_call(keyadd, shell=True) + + with open("/etc/apt/sources.list.d/amdgpu.list", "w") as fd: + fd.write( + ("deb [arch=amd64] " "https://repo.radeon.com/amdgpu/%s/ubuntu %s main\n") + % (rocm_version_str, codename) + ) + + with open("/etc/apt/sources.list.d/rocm.list", "w") as fd: + fd.write( + ("deb [arch=amd64] " "https://repo.radeon.com/rocm/apt/%s %s main\n") + % (rocm_version_str, codename) + ) + + # on ubuntu 22 or greater, debian community rocm packages + # conflict with repo.radeon.com packages + with open("/etc/apt/preferences.d/rocm-pin-600", "w") as fd: + fd.write(APT_RADEON_PIN_CONTENT) + + # update indexes + subprocess.check_call(["apt-get", "update"]) + + +def setup_repos_el8(rocm_version_str): + + with open("/etc/yum.repos.d/rocm.repo", "w") as rfd: + rfd.write( + """ +[ROCm] +name=ROCm +baseurl=http://repo.radeon.com/rocm/rhel8/%s/main +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +""" + % rocm_version_str + ) + + with 
open("/etc/yum.repos.d/amdgpu.repo", "w") as afd: + afd.write( + """ +[amdgpu] +name=amdgpu +baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/ +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +""" + ) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--rocm-version", help="ROCm version to install", default="latest") + p.add_argument("--job-name", default=None) + p.add_argument("--build-num", default=None) + return p.parse_args() + + +def main(): + args = parse_args() + if args.rocm_version == "latest": + try: + rocm_version = latest_rocm() + print("Latest ROCm release: %s" % rocm_version) + except Exception: + print( + "Latest ROCm lookup failed. Please use '--rocm-version' to specify a version instead.", + file=sys.stderr, + ) + sys.exit(-1) + else: + rocm_version = args.rocm_version + + install_rocm(rocm_version, job_name=args.job_name, build_num=args.build_num) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/libc.py b/build/rocm/tools/libc.py new file mode 100644 index 000000000000..1cd16b04cd14 --- /dev/null +++ b/build/rocm/tools/libc.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import os +import sys + + +def get_libc_version(): + """ + Detect and return glibc version that the current Python is linked against. + + This mimics the detection behavior of the 'wheel' and 'auditwheel' projects, + but without any PyPy or libmusl support. + """ + + try: + version_str = os.confstr("CS_GNU_LIBC_VERSION") + return version_str + except Exception: + print("WARN: lookup by confstr failed", file=sys.stderr) + pass + + try: + import ctypes + except ImportError: + return None + + pn = ctypes.CDLL(None) + print(dir(pn)) + + try: + gnu_get_libc_version = pn.gnu_get_libc_version + except AttributeError: + return None + + gnu_get_libc_version.restype = ctypes.c_char_p + version_str = gnu_get_libc_version() + + return version_str + + +if __name__ == "__main__": + print(get_libc_version()) diff --git a/build/rocm/tools/symbols.py b/build/rocm/tools/symbols.py new file mode 100644 index 000000000000..2982bb187c9e --- /dev/null +++ b/build/rocm/tools/symbols.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import pprint +import re +import sys +import subprocess + +""" +Utility for examining GLIBC versioned symbols +for an object file (shared object or ELF binary) +""" + + +def main(): + sofile = sys.argv[1] + + s = highest_for_file(sofile) + + print("%s: %r" % (sofile, s)) + + +def highest_for_file(sofile): + output = subprocess.check_output(["objdump", "-T", sofile]) + + r = re.compile(r"\(GLIBC_(.*)\)") + versions = {} + + for line in output.decode("utf-8").split("\n"): + line = line.strip() + match = r.search(line) + if match: + version_str = match.group(1) + count = versions.get(version_str, 0) + versions[version_str] = count + 1 + + vtups = list(map(lambda x: parse(x), versions.keys())) + s = sorted(vtups) + + return s[-1] + + +def parse(version_str): + return tuple(map(int, version_str.split("."))) + + +if __name__ == "__main__": + main() diff --git a/build/test-requirements.txt b/build/test-requirements.txt index 4f9d19e76ba2..bec6afce1853 100644 --- a/build/test-requirements.txt +++ b/build/test-requirements.txt @@ -7,9 +7,10 @@ flatbuffers hypothesis mpmath>=1.3 numpy>=1.22 -pillow>=9.1.0 +pillow>=10.4.0 portpicker pytest-xdist wheel rich -setuptools +# TODO(ybaturina): remove setuptools version +setuptools<71.0.0 diff --git a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb index 1278bd01c91f..edaa71b93e85 100644 --- a/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb +++ b/cloud_tpu_colabs/JAX_NeurIPS_2020_demo.ipynb @@ -38,7 +38,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key, subkey = random.split(key)\n", "x = random.normal(key, (5000, 5000))\n", "\n", @@ -189,7 +189,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, ())\n", "\n", "print(grad(f)(x))\n", @@ -261,7 +261,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, (5000, 5000))" ] }, @@ -451,7 +451,7 @@ "id": "jC-KIMQ1q-lK" }, "source": [ - "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." + "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." 
] }, { @@ -510,8 +510,8 @@ "outputs": [], "source": [ "image_partitions = P(1, 1, 4, 2)\n", - "sharded_conv = sharded_jit(conv, \n", - " in_parts=(image_partitions, None), \n", + "sharded_conv = sharded_jit(conv,\n", + " in_parts=(image_partitions, None),\n", " out_parts=image_partitions)\n", "\n", "sharded_conv(image, kernel)" diff --git a/cloud_tpu_colabs/JAX_demo.ipynb b/cloud_tpu_colabs/JAX_demo.ipynb index 6a6993f44ed2..d7ba5ed334f4 100644 --- a/cloud_tpu_colabs/JAX_demo.ipynb +++ b/cloud_tpu_colabs/JAX_demo.ipynb @@ -27,7 +27,7 @@ "import jax.numpy as jnp\n", "from jax import random\n", "\n", - "key = random.PRNGKey(0)" + "key = random.key(0)" ] }, { @@ -194,7 +194,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, ())\n", "\n", "print(grad(f)(x))\n", @@ -246,7 +246,7 @@ "\n", "layer_sizes = [5, 2, 3]\n", "\n", - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "key, *keys = random.split(key, len(layer_sizes))\n", "params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", "\n", @@ -351,7 +351,7 @@ }, "outputs": [], "source": [ - "key = random.PRNGKey(0)\n", + "key = random.key(0)\n", "x = random.normal(key, (5000, 5000))" ] }, @@ -754,7 +754,7 @@ }, "outputs": [], "source": [ - "keys = random.split(random.PRNGKey(0), 8)\n", + "keys = random.split(random.key(0), 8)\n", "mats = pmap(lambda key: random.normal(key, (5000, 5000)))(keys)\n", "result = pmap(jnp.dot)(mats, mats)\n", "print(pmap(jnp.mean)(result))" @@ -837,7 +837,7 @@ "id": "f-FBsWeo1AXE" }, "source": [ - "" + "" ] }, { @@ -847,7 +847,7 @@ "id": "jC-KIMQ1q-lK" }, "source": [ - "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." + "For more, see the [`pmap` cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)." 
] }, { @@ -877,7 +877,7 @@ " def g(z):\n", " return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()\n", " return grad(lambda w: jnp.sum(g(w)))(x)\n", - " \n", + "\n", "f(x)" ] }, @@ -950,17 +950,6 @@ "per_example_hess = pmap(input_hess) # pmap!\n", "per_example_hess(inputs)" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "u3ggM_WYZ8QC" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb index 1777d3d1ef79..84abf865851a 100644 --- a/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb +++ b/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb @@ -366,7 +366,7 @@ "\n", "# set some initial conditions for each replicate\n", "ys = jnp.zeros((N_dev, N, 3))\n", - "state0 = jr.uniform(jr.PRNGKey(1), \n", + "state0 = jr.uniform(jr.key(1), \n", " minval=-1., maxval=1.,\n", " shape=(N_dev, 3))\n", "state0 = state0 * jnp.array([18,18,1]) + jnp.array((0.,0.,10.))\n", diff --git a/cloud_tpu_colabs/Pmap_Cookbook.ipynb b/cloud_tpu_colabs/Pmap_Cookbook.ipynb index 4f4ba8c165a3..ea126ac4f1e7 100644 --- a/cloud_tpu_colabs/Pmap_Cookbook.ipynb +++ b/cloud_tpu_colabs/Pmap_Cookbook.ipynb @@ -15,13 +15,13 @@ "id": "sk-3cPGIBTq8" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)\n", "\n", "This notebook is an introduction to writing single-program multiple-data (SPMD) programs in JAX, and executing them synchronously in parallel on multiple devices, such as multiple GPUs or multiple TPU cores. The SPMD model is useful for computations like training neural networks with synchronous gradient descent algorithms, and can be used for data-parallel as well as model-parallel computations.\n", "\n", "**Note:** To run this notebook with any parallelism, you'll need multiple XLA devices available, e.g. with a multi-GPU machine, a Colab TPU, a Google Cloud TPU or a Kaggle TPU VM.\n", "\n", - "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/google/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." + "The code in this notebook is simple. For an example of how to use these tools to do data-parallel neural network training, check out [the SPMD MNIST example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py) or the much more capable [Trax library](https://github.com/google/trax/)." 
] }, { @@ -263,7 +263,7 @@ "from jax import random\n", "\n", "# create 8 random keys\n", - "keys = random.split(random.PRNGKey(0), 8)\n", + "keys = random.split(random.key(0), 8)\n", "# create a 5000 x 6000 matrix on each device by mapping over keys\n", "mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)\n", "# the stack of matrices is represented logically as a single array\n", diff --git a/cloud_tpu_colabs/README.md b/cloud_tpu_colabs/README.md index 4a795f718c84..db3dc5f30814 100644 --- a/cloud_tpu_colabs/README.md +++ b/cloud_tpu_colabs/README.md @@ -13,25 +13,25 @@ VM](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). The following notebooks showcase how to use and what you can do with Cloud TPUs on Colab: -### [Pmap Cookbook](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) +### [Pmap Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb) A guide to getting started with `pmap`, a transform for easily distributing SPMD computations across devices. -### [Lorentz ODE Solver](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) +### [Lorentz ODE Solver](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Lorentz_ODE_Solver.ipynb) Contributed by Alex Alemi (alexalemi@) Solve and plot parallel ODE solutions with `pmap`. - + -### [Wave Equation](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) +### [Wave Equation](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Wave_Equation.ipynb) Contributed by Stephan Hoyer (shoyer@) Solve the wave equation with `pmap`, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU. -![](https://raw.githubusercontent.com/google/jax/main/cloud_tpu_colabs/images/wave_movie.gif) +![](https://raw.githubusercontent.com/jax-ml/jax/main/cloud_tpu_colabs/images/wave_movie.gif) -### [JAX Demo](https://colab.research.google.com/github/google/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) +### [JAX Demo](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/JAX_demo.ipynb) An overview of JAX presented at the [Program Transformations for ML workshop at NeurIPS 2019](https://program-transformations.github.io/) and the [Compilers for ML workshop at CGO 2020](https://www.c4ml.org/). Covers basic numpy usage, `grad`, `jit`, `vmap`, and `pmap`. ## Performance notes @@ -53,7 +53,7 @@ By default\*, matrix multiplication in JAX on TPUs [uses bfloat16](https://cloud JAX also adds the `bfloat16` dtype, which you can use to explicitly cast arrays to bfloat16, e.g., `jax.numpy.array(x, dtype=jax.numpy.bfloat16)`. -\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/google/jax/issues/2161) if it affects you! +\* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on [this issue](https://github.com/jax-ml/jax/issues/2161) if it affects you! ## Running JAX on a Cloud TPU VM @@ -65,8 +65,8 @@ documentation](https://cloud.google.com/tpu/docs/jax-quickstart-tpu-vm). If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU VM), please email , or if you are a [TRC](https://sites.research.google/trc/) member. 
You can also [file a -JAX issue](https://github.com/google/jax/issues) or [ask a discussion -question](https://github.com/google/jax/discussions) for any issues with these +JAX issue](https://github.com/jax-ml/jax/issues) or [ask a discussion +question](https://github.com/jax-ml/jax/discussions) for any issues with these notebooks or using JAX in general. If you have any other questions or comments regarding JAX on Cloud TPUs, please diff --git a/cloud_tpu_colabs/Wave_Equation.ipynb b/cloud_tpu_colabs/Wave_Equation.ipynb index 0591739191e0..16f675a76140 100644 --- a/cloud_tpu_colabs/Wave_Equation.ipynb +++ b/cloud_tpu_colabs/Wave_Equation.ipynb @@ -67,7 +67,6 @@ "source": [ "from functools import partial\n", "import jax\n", - "from jax import jit, pmap\n", "from jax import lax\n", "from jax import tree_util\n", "import jax.numpy as jnp\n", diff --git a/docs/Custom_Operation_for_GPUs.md b/docs/Custom_Operation_for_GPUs.md index 28a81715428e..8490bd489608 100644 --- a/docs/Custom_Operation_for_GPUs.md +++ b/docs/Custom_Operation_for_GPUs.md @@ -306,7 +306,7 @@ per_core_batch_size=4 seq_len=512 emb_dim=512 x = jax.random.normal( - jax.random.PRNGKey(0), + jax.random.key(0), shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), dtype=jnp.bfloat16, ) diff --git a/docs/Custom_Operation_for_GPUs.py b/docs/Custom_Operation_for_GPUs.py index 4c0b4b6f7b38..31a00c49071e 100644 --- a/docs/Custom_Operation_for_GPUs.py +++ b/docs/Custom_Operation_for_GPUs.py @@ -479,7 +479,7 @@ def custom_p_rms_norm_bwd(eps, res, g): emb_dim = 512 assert jax.local_device_count() > 1, "Only 1 GPU, the example work, but it is this really what you want?" x = jax.random.normal( - jax.random.PRNGKey(0), + jax.random.key(0), shape=(jax.local_device_count() * per_core_batch_size, seq_len, emb_dim), dtype=jnp.float16, ) diff --git a/docs/_static/pallas/distributed/all_gather.svg b/docs/_static/pallas/distributed/all_gather.svg new file mode 100644 index 000000000000..5bbf6f70cf8f --- /dev/null +++ b/docs/_static/pallas/distributed/all_gather.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/race_condition.svg b/docs/_static/pallas/distributed/race_condition.svg new file mode 100644 index 000000000000..e4f981186dab --- /dev/null +++ b/docs/_static/pallas/distributed/race_condition.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/rdma_recv.svg b/docs/_static/pallas/distributed/rdma_recv.svg new file mode 100644 index 000000000000..d49ba5eb8541 --- /dev/null +++ b/docs/_static/pallas/distributed/rdma_recv.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/rdma_send.svg b/docs/_static/pallas/distributed/rdma_send.svg new file mode 100644 index 000000000000..579ba1323667 --- /dev/null +++ b/docs/_static/pallas/distributed/rdma_send.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/rdma_start.svg b/docs/_static/pallas/distributed/rdma_start.svg new file mode 100644 index 000000000000..f37bde6e83e1 --- /dev/null +++ b/docs/_static/pallas/distributed/rdma_start.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_scatter_1.svg b/docs/_static/pallas/distributed/reduce_scatter_1.svg new file mode 100644 index 000000000000..c66df4acf8a5 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_scatter_1.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git 
a/docs/_static/pallas/distributed/reduce_scatter_2.svg b/docs/_static/pallas/distributed/reduce_scatter_2.svg new file mode 100644 index 000000000000..bb4ae3496297 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_scatter_2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_sum_1.svg b/docs/_static/pallas/distributed/reduce_sum_1.svg new file mode 100644 index 000000000000..9a527aff6a2e --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_sum_1.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/distributed/reduce_sum_2.svg b/docs/_static/pallas/distributed/reduce_sum_2.svg new file mode 100644 index 000000000000..61685cf41863 --- /dev/null +++ b/docs/_static/pallas/distributed/reduce_sum_2.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/block_coo.svg b/docs/_static/pallas/sparse/block_coo.svg new file mode 100644 index 000000000000..474dfcb64d7a --- /dev/null +++ b/docs/_static/pallas/sparse/block_coo.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/prefetch_map.svg b/docs/_static/pallas/sparse/prefetch_map.svg new file mode 100644 index 000000000000..08fdd2c1cf39 --- /dev/null +++ b/docs/_static/pallas/sparse/prefetch_map.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/pallas/sparse/sparse_matmul.svg b/docs/_static/pallas/sparse/sparse_matmul.svg new file mode 100644 index 000000000000..06a24317cfe1 --- /dev/null +++ b/docs/_static/pallas/sparse/sparse_matmul.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/docs/_static/style.css b/docs/_static/style.css index 7a5c647052f0..296912ace2c8 100644 --- a/docs/_static/style.css +++ b/docs/_static/style.css @@ -20,6 +20,15 @@ background-color: rgba(171, 0, 182, var(--block-bg-opacity)); } +.ecosystem-grid { + font-size: smaller; +} + +.ecosystem-grid ul { + list-style-type: none; + padding-inline-start: 0.5em; +} + div.red-background pre { background-color: rgba(244, 204, 204, var(--block-bg-opacity)); } diff --git a/docs/_tutorials/advanced-debugging.md b/docs/_tutorials/advanced-debugging.md index 56188e0958fa..d4462feaf829 100644 --- a/docs/_tutorials/advanced-debugging.md +++ b/docs/_tutorials/advanced-debugging.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/_tutorials/index.rst b/docs/_tutorials/index.rst index 5b3d690d5e96..0e5a6a16dcfc 100644 --- a/docs/_tutorials/index.rst +++ b/docs/_tutorials/index.rst @@ -38,10 +38,7 @@ JAX 201 :maxdepth: 1 parallelism - advanced-autodiff - gradient-checkpointing advanced-debugging - external-callbacks profiling-and-performance JAX 301 @@ -50,6 +47,4 @@ JAX 301 .. 
toctree:: :maxdepth: 1 - jax-primitives - jaxpr advanced-compilation diff --git a/docs/_tutorials/advanced-autodiff.md b/docs/advanced-autodiff.md similarity index 99% rename from docs/_tutorials/advanced-autodiff.md rename to docs/advanced-autodiff.md index da95f96d8b25..023dc8040954 100644 --- a/docs/_tutorials/advanced-autodiff.md +++ b/docs/advanced-autodiff.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -77,7 +77,7 @@ def meta_loss_fn(params, data): meta_grads = jax.grad(meta_loss_fn)(params, data) ``` - +(stopping-gradients)= ### Stopping gradients Autodiff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, you might want some additional control: for instance, you might want to avoid backpropagating gradients through some subset of the computational graph. @@ -315,7 +315,7 @@ print("jacrev result, with shape", J.shape) print(J) ``` -These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. +These two functions compute the same values (up to machine numerics), but differ in their implementation: {func}`jax.jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while {func}`jax.jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, {func}`jax.jacfwd` probably has an edge over {func}`jax.jacrev`. You can also use {func}`jax.jacfwd` and {func}`jax.jacrev` with container types: @@ -571,7 +571,7 @@ print("Naive full Hessian materialization") ### Jacobian-Matrix and Matrix-Jacobian products -Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: +Now that you have {func}`jax.jvp` and {func}`jax.vjp` transformations that give you functions to push-forward or pull-back single vectors at a time, you can use JAX's {func}`jax.vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, you can use that to write fast matrix-Jacobian and Jacobian-matrix products: ```{code-cell} # Isolate the function from the weight matrix to the predictions @@ -590,7 +590,7 @@ def vmap_mjp(f, x, M): outs, = vmap(vjp_fun)(M) return outs -key = random.PRNGKey(0) +key = random.key(0) num_covecs = 128 U = random.normal(key, (num_covecs,) + y.shape) @@ -640,7 +640,7 @@ def our_jacrev(f): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. + # entries in the Jacobian at once. 
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun @@ -654,7 +654,7 @@ from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun @@ -714,7 +714,7 @@ Here's a check: ```{code-cell} def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) @@ -768,7 +768,7 @@ Here's a check of the VJP rules: ```{code-cell} def check(seed): - key = random.PRNGKey(seed) + key = random.key(seed) # random coeffs for u and v key, subkey = random.split(key) diff --git a/docs/advanced_guide.rst b/docs/advanced_guide.rst index 5fe6c03ee059..85ed315c98e5 100644 --- a/docs/advanced_guide.rst +++ b/docs/advanced_guide.rst @@ -1,28 +1,23 @@ .. _advanced_guide: -Advanced Tutorials -================== -This section contains examples and tutorials on more advanced topics, such as Multi Core computation, Custom operations, and more in depth applications +Advanced guides +=============== -.. toctree:: - :caption: Examples - :maxdepth: 1 - - notebooks/neural_network_with_tfds_data - notebooks/Neural_Network_and_Data_Loading - notebooks/vmapped_log_probs +This section contains examples and tutorials on more advanced topics, +such as multi-core computation, automatic differentiation, and custom +operations. .. toctree:: - :caption: Parallel Computation + :caption: Parallel computation :maxdepth: 1 - multi_process notebooks/Distributed_arrays_and_automatic_parallelization notebooks/shard_map + multi_process distributed_data_loading .. toctree:: - :caption: Automatic Differentiation + :caption: Automatic differentiation :maxdepth: 1 notebooks/autodiff_cookbook @@ -30,15 +25,8 @@ This section contains examples and tutorials on more advanced topics, such as Mu notebooks/autodiff_remat .. toctree:: - :caption: JAX Internals - :maxdepth: 1 - - notebooks/How_JAX_primitives_work - notebooks/Writing_custom_interpreters_in_Jax - Custom_Operation_for_GPUs - -.. toctree:: - :caption: Deep Dives + :caption: Deep dives :maxdepth: 1 notebooks/convolutions + xla_flags diff --git a/docs/autodidax.ipynb b/docs/autodidax.ipynb index ed242ecc5710..9a956670ceea 100644 --- a/docs/autodidax.ipynb +++ b/docs/autodidax.ipynb @@ -27,7 +27,7 @@ "metadata": {}, "source": [ "[![Open in\n", - "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb)" + "Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb)" ] }, { @@ -1781,7 +1781,7 @@ "metadata": {}, "source": [ "This is precisely the issue that\n", - "[omnistaging](https://github.com/google/jax/pull/3370) fixed.\n", + "[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed.\n", "We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always\n", "applied, regardless of whether any inputs to `bind` are boxed in corresponding\n", "`JaxprTracer` instances. 
We can achieve this by employing the `dynamic_trace`\n", diff --git a/docs/autodidax.md b/docs/autodidax.md index 0551b9905db3..937e1012a230 100644 --- a/docs/autodidax.md +++ b/docs/autodidax.md @@ -6,7 +6,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -33,7 +33,7 @@ limitations under the License. ``` [![Open in -Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) +Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb) +++ @@ -1399,7 +1399,7 @@ print(jaxpr) ``` This is precisely the issue that -[omnistaging](https://github.com/google/jax/pull/3370) fixed. +[omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed. We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always applied, regardless of whether any inputs to `bind` are boxed in corresponding `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` diff --git a/docs/autodidax.py b/docs/autodidax.py index b09534381c69..c10e6365e62d 100644 --- a/docs/autodidax.py +++ b/docs/autodidax.py @@ -20,14 +20,14 @@ # extension: .py # format_name: light # format_version: '1.5' -# jupytext_version: 1.16.1 +# jupytext_version: 1.16.4 # kernelspec: # display_name: Python 3 # name: python3 # --- # [![Open in -# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/autodidax.ipynb) +# Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/autodidax.ipynb) # # Autodidax: JAX core from scratch # @@ -1396,7 +1396,7 @@ def pp_params(params: dict[str, Any]) -> PPrint: # This is precisely the issue that -# [omnistaging](https://github.com/google/jax/pull/3370) fixed. +# [omnistaging](https://github.com/jax-ml/jax/pull/3370) fixed. # We want to ensure that the `JaxprTrace` started by `make_jaxpr` is always # applied, regardless of whether any inputs to `bind` are boxed in corresponding # `JaxprTracer` instances. We can achieve this by employing the `dynamic_trace` diff --git a/docs/automatic-differentiation.md b/docs/automatic-differentiation.md index cc4a19aaba64..07af05e3d973 100644 --- a/docs/automatic-differentiation.md +++ b/docs/automatic-differentiation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/automatic-vectorization.md b/docs/automatic-vectorization.md index 7559155e2e9e..032d1c56f27a 100644 --- a/docs/automatic-vectorization.md +++ b/docs/automatic-vectorization.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/beginner_guide.rst b/docs/beginner_guide.rst index 204659ec2cb9..783d3b49ae52 100644 --- a/docs/beginner_guide.rst +++ b/docs/beginner_guide.rst @@ -52,4 +52,4 @@ questions answered are: .. _Flax: https://flax.readthedocs.io/ .. _Haiku: https://dm-haiku.readthedocs.io/ .. _JAX on StackOverflow: https://stackoverflow.com/questions/tagged/jax -.. 
_JAX GitHub discussions: https://github.com/google/jax/discussions \ No newline at end of file +.. _JAX GitHub discussions: https://github.com/jax-ml/jax/discussions \ No newline at end of file diff --git a/docs/building_on_jax.md b/docs/building_on_jax.md index e0a4404911a7..9416b16cde10 100644 --- a/docs/building_on_jax.md +++ b/docs/building_on_jax.md @@ -11,7 +11,7 @@ and how it's used for computational speedup in other libraries. Below are examples of how JAX's features can be used to define accelerated computation across numerous domains and software packages. -## Gradient Computation +## Gradient computation Easy gradient calculation is a key feature of JAX. In the [JaxOpt library](https://github.com/google/jaxopt) value and grad is directly utilized for users in multiple optimization algorithms in [its source code](https://github.com/google/jaxopt/blob/main/jaxopt/_src/base.py#LL87C30-L87C44). @@ -19,7 +19,7 @@ Similarly the same Dynamax Optax pairing mentioned above is an example of gradients enabling estimation methods that were challenging historically [Maximum Likelihood Expectation using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html). -## Computational Speedup on a Single Core across Multiple Devices +## Computational speedup on a single core across multiple devices Models defined in JAX can then be compiled to enable single computation speedup through JIT compiling. The same compiled code can then be sent to a CPU device, to a GPU or TPU device for additional speedup, @@ -28,7 +28,7 @@ This allows for a smooth workflow from development into production. In Dynamax the computationally expensive portion of a Linear State Space Model solver has been [jitted](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/models.py#L579). A more complex example comes from PyTensor which compiles a JAX function dynamically and then [jits the constructed function](https://github.com/pymc-devs/pytensor/blob/main/pytensor/link/jax/linker.py#L64). -## Single and Multi Computer Speedup Using Parallelization +## Single and multi computer speedup using parallelization Another benefit of JAX is the simplicity of parallelizing computation using `pmap` and `vmap` function calls or decorators. In Dynamax state space models are parallelized with a [VMAP decorator](https://github.com/probml/dynamax/blob/main/dynamax/linear_gaussian_ssm/parallel_inference.py#L89) @@ -43,7 +43,7 @@ such as Neural Networks or State Space models or others, or provide specific functionality such as optimization. Here are more specific examples of each pattern. -### Direct Usage +### Direct usage Jax can be directly imported and utilized to build models “from scratch” as shown across this website, for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html). @@ -51,7 +51,7 @@ This may be the best option if you are unable to find prebuilt code for your particular challenge, or if you're looking to reduce the number of dependencies in your codebase. -### Composable Domain Specific Libraries with JAX exposed +### Composable domain specific libraries with JAX exposed Another common approach are packages that provide prebuilt functionality, whether it be model definition, or computation of some type. 
Combinations of these packages can then be mixed and matched for a full @@ -68,7 +68,7 @@ With Dynamax parameters can be estimated using [Maximum Likelihood using Optax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_learning.html) or full Bayesian Posterior can be estimating using [MCMC from Blackjax](https://probml.github.io/dynamax/notebooks/linear_gaussian_ssm/lgssm_hmc.html) -### JAX Totally Hidden from Users +### JAX totally hidden from users Other libraries opt to completely wrap JAX in their model specific API. An example is PyMC and [Pytensor](https://github.com/pymc-devs/pytensor), in which a user may never “see” JAX directly diff --git a/docs/conf.py b/docs/conf.py index 15941e2faa5b..d57420dec881 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -85,6 +85,7 @@ def _do_not_evaluate_in_jax( ] intersphinx_mapping = { + 'array_api': ('https://data-apis.org/array-api/2023.12/', None), 'python': ('https://docs.python.org/3/', None), 'numpy': ('https://numpy.org/doc/stable/', None), 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), @@ -132,6 +133,8 @@ def _do_not_evaluate_in_jax( 'notebooks/*.md', 'pallas/quickstart.md', 'pallas/tpu/pipelining.md', + 'pallas/tpu/distributed.md', + 'pallas/tpu/sparse.md', 'pallas/tpu/matmul.md', 'jep/9407-type-promotion.md', 'autodidax.md', @@ -165,7 +168,7 @@ def _do_not_evaluate_in_jax( # documentation. html_theme_options = { 'show_toc_level': 2, - 'repository_url': 'https://github.com/google/jax', + 'repository_url': 'https://github.com/jax-ml/jax', 'use_repository_button': True, # add a "link to repository" button 'navigation_with_keys': False, } @@ -221,6 +224,8 @@ def _do_not_evaluate_in_jax( # Requires accelerators 'pallas/quickstart.*', 'pallas/tpu/pipelining.*', + 'pallas/tpu/distributed.*', + 'pallas/tpu/sparse.*', 'pallas/tpu/matmul.*', 'sharded-computation.*', 'distributed_data_loading.*' @@ -340,7 +345,7 @@ def linkcode_resolve(domain, info): return None filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__)) lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else "" - return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}" + return f"https://github.com/jax-ml/jax/blob/main/jax/{filename}{lines}" # Generate redirects from deleted files to new sources rediraffe_redirects = { @@ -355,4 +360,6 @@ def linkcode_resolve(domain, info): 'jax-101/07-state.md': 'stateful-computations.md', 'jax-101/08-pjit.rst': 'sharded-computation.md', 'jax-101/index.rst': 'tutorials.rst', + 'notebooks/external_callbacks.md': 'external-callbacks.md', + 'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md', } diff --git a/docs/contributing.md b/docs/contributing.md index 4aecf7153a03..99d78453c436 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -5,22 +5,22 @@ Everyone can contribute to JAX, and we value everyone's contributions. 
There are several ways to contribute, including: -- Answering questions on JAX's [discussions page](https://github.com/google/jax/discussions) +- Answering questions on JAX's [discussions page](https://github.com/jax-ml/jax/discussions) - Improving or expanding JAX's [documentation](http://jax.readthedocs.io/) -- Contributing to JAX's [code-base](http://github.com/google/jax/) -- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/google/jax#neural-network-libraries) +- Contributing to JAX's [code-base](http://github.com/jax-ml/jax/) +- Contributing in any of the above ways to the broader ecosystem of [libraries built on JAX](https://github.com/jax-ml/jax#neural-network-libraries) The JAX project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ## Ways to contribute We welcome pull requests, in particular for those issues marked with -[contributions welcome](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or -[good first issue](https://github.com/google/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). +[contributions welcome](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22contributions+welcome%22) or +[good first issue](https://github.com/jax-ml/jax/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22). For other proposals, we ask that you first open a GitHub -[Issue](https://github.com/google/jax/issues/new/choose) or -[Discussion](https://github.com/google/jax/discussions) +[Issue](https://github.com/jax-ml/jax/issues/new/choose) or +[Discussion](https://github.com/jax-ml/jax/discussions) to seek feedback on your planned contribution. ## Contributing code using pull requests @@ -33,7 +33,7 @@ Follow these steps to contribute code: For more information, see the Pull Request Checklist below. 2. Fork the JAX repository by clicking the **Fork** button on the - [repository page](http://www.github.com/google/jax). This creates + [repository page](http://www.github.com/jax-ml/jax). This creates a copy of the JAX repository in your own account. 3. Install Python >= 3.10 locally in order to run tests. @@ -52,7 +52,7 @@ Follow these steps to contribute code: changes. ```bash - git remote add upstream https://www.github.com/google/jax + git remote add upstream https://www.github.com/jax-ml/jax ``` 6. Create a branch where you will develop from: @@ -162,7 +162,7 @@ possible. The `git rebase -i` command might be useful to this end. (linting-and-type-checking)= -### Linting and Type-checking +### Linting and type-checking JAX uses [mypy](https://mypy.readthedocs.io/) and [ruff](https://docs.astral.sh/ruff/) to statically test code quality; the @@ -186,7 +186,7 @@ fix the issues you can push new commits to your branch. ### Restricted test suite -Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This +Once your PR has been reviewed, a JAX maintainer will mark it as `pull ready`. This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the JAX maintainer assigned to your PR will communicate with you regarding diff --git a/docs/contributor_guide.rst b/docs/contributor_guide.rst index cb0c034be850..55094fc88958 100644 --- a/docs/contributor_guide.rst +++ b/docs/contributor_guide.rst @@ -1,18 +1,28 @@ .. 
_contributor-guide: -Developer Documentation -======================= +Developer notes +=============== JAX welcomes contributions from the community. -See below for various install guides to get setup as a developer -as well as developer-focused resources such as Jax Enhancement Proposals. +These are guides to get set up as a developer, as well as +developer-focused resources, such as JAX Enhancement Proposals. + +See also the :doc:`extension guides<../extensions>`, which document +some of JAX's (extensible) internals. + .. toctree:: :maxdepth: 1 + :caption: Contribution guides contributing developer - jax_internal_api + investigating_a_regression + +.. toctree:: + :maxdepth: 1 + :caption: Design and internals + autodidax jep/index - investigating_a_regression + jax_internal_api diff --git a/docs/cuda_custom_call/BUILD b/docs/cuda_custom_call/BUILD index 93715bdac171..4954ce3db4fa 100644 --- a/docs/cuda_custom_call/BUILD +++ b/docs/cuda_custom_call/BUILD @@ -16,7 +16,7 @@ load( "//jaxlib:jax.bzl", "cuda_library", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", ) licenses(["notice"]) @@ -28,14 +28,11 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "cuda_custom_call_test", srcs = ["cuda_custom_call_test.py"], data = [":foo"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], tags = ["notap"], deps = [ "//jax:extend", @@ -56,8 +53,8 @@ cuda_library( name = "foo_", srcs = ["foo.cu.cc"], deps = [ + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", - "@local_config_cuda//cuda:cuda_headers", ], ) diff --git a/docs/debugging.md b/docs/debugging.md index 7ee36f19f5bf..d07f42da5c85 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -21,9 +21,9 @@ This section introduces you to a set of built-in JAX debugging methods — {func Let's begin with {func}`jax.debug.print`. -## JAX `debug.print` for high-level +## `jax.debug.print` for simple inspection -**TL;DR** Here is a rule of thumb: +Here is a rule of thumb: - Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others. - Use Python {func}`print` for static values, such as dtypes and array shapes. @@ -111,9 +111,9 @@ f(1, 2) To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`. -## JAX `debug.breakpoint` for `pdb`-like debugging +## `jax.debug.breakpoint` for `pdb`-like debugging -**TL;DR** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. +**Summary:** Use {func}`jax.debug.breakpoint` to pause the execution of your JAX program to inspect values. To pause your compiled JAX program during certain points during debugging, you can use {func}`jax.debug.breakpoint`. The prompt is similar to Python `pdb`, and it allows you to inspect the values in the call stack. In fact, {func}`jax.debug.breakpoint` is an application of {func}`jax.debug.callback` that captures information about the call stack. @@ -160,7 +160,7 @@ f(2., 1.) # ==> No breakpoint f(2., 0.) 
# ==> Pauses during execution ``` -## JAX `debug.callback` for more control during debugging +## `jax.debug.callback` for more control during debugging Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using the more flexible {func}`jax.debug.callback`, which gives greater control over the diff --git a/docs/debugging/checkify_guide.md b/docs/debugging/checkify_guide.md index 2dad9b863b06..8b012e97ef28 100644 --- a/docs/debugging/checkify_guide.md +++ b/docs/debugging/checkify_guide.md @@ -2,7 +2,7 @@ -**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: +**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify diff --git a/docs/debugging/flags.md b/docs/debugging/flags.md index 1cf1829e5152..13e34a6c3ac4 100644 --- a/docs/debugging/flags.md +++ b/docs/debugging/flags.md @@ -6,7 +6,7 @@ JAX offers flags and context managers that enable catching errors more easily. ## `jax_debug_nans` configuration option and context manager -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code). `jax_debug_nans` is a JAX flag that when enabled, automatically raises an error when a NaN is detected. It has special handling for JIT-compiled -- when a NaN output is detected from a JIT-ted function, the function is re-run eagerly (i.e. without compilation) and will throw an error at the specific primitive that produced the NaN. @@ -41,7 +41,7 @@ jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ## `jax_disable_jit` configuration option and context manager -**TL;DR** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` +**Summary:** Enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb` `jax_disable_jit` is a JAX flag that when enabled, disables JIT-compilation throughout JAX (including in control flow functions like `jax.lax.cond` and `jax.lax.scan`). diff --git a/docs/debugging/index.md b/docs/debugging/index.md index b00fcc13d0a0..bcf561d06807 100644 --- a/docs/debugging/index.md +++ b/docs/debugging/index.md @@ -1,8 +1,8 @@ -# Runtime value debugging in JAX +# Debugging runtime values -Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has TL;DR summaries and you can click the "Read more" links at the bottom to learn more. +Do you have exploding gradients? Are NaNs making you gnash your teeth? Just want to poke around the intermediate values in your computation? Check out the following JAX debugging tools! This page has summaries and you can click the "Read more" links at the bottom to learn more. 
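To make the `jax.debug.print` rule of thumb above concrete (use it for traced array values; use Python's `print` for static information such as shapes and dtypes), here is a minimal sketch; the function and inputs are illustrative only:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Static information (shape, dtype) is known at trace time, so a plain
    # Python `print` works here -- but it only fires once, while tracing.
    print("static shape:", x.shape)
    # Traced (runtime) values require jax.debug.print.
    jax.debug.print("traced value: {x}", x=x)
    return x * 2

f(jnp.arange(3.0))
```

Under `jax.jit`, the Python `print` runs once at trace time, while `jax.debug.print` emits the runtime values on every execution of the compiled function.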
Table of contents: @@ -10,9 +10,11 @@ Table of contents: * [Functional error checks with jax.experimental.checkify](checkify_guide) * [Throwing Python errors with JAX’s debug flags](flags) -## [Interactive inspection with `jax.debug`](print_breakpoint) +## Interactive inspection with `jax.debug` - **TL;DR** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, +Complete guide [here](print_breakpoint) + + **Summary:** Use {func}`jax.debug.print` to print values to stdout in `jax.jit`-,`jax.pmap`-, and `pjit`-decorated functions, and {func}`jax.debug.breakpoint` to pause execution of your compiled function to inspect values in the call stack: ```python @@ -34,11 +36,13 @@ Table of contents: # 🤯 0.9092974662780762 🤯 ``` -Click [here](print_breakpoint) to learn more! +[Read more](print_breakpoint). + +## Functional error checks with `jax.experimental.checkify` -## [Functional error checks with `jax.experimental.checkify`](checkify_guide) +Complete guide [here](checkify_guide) - **TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: + **Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out of bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code: ```python from jax.experimental import checkify @@ -77,11 +81,13 @@ Click [here](print_breakpoint) to learn more! # ValueError: nan generated by primitive sin at <...>:8 (f) ``` -Click [here](checkify_guide) to learn more! +[Read more](checkify_guide). + +## Throwing Python errors with JAX's debug flags -## [Throwing Python errors with JAX's debug flags](flags) +Complete guide [here](flags) -**TL;DR** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. +**Summary:** Enable the `jax_debug_nans` flag to automatically detect when NaNs are produced in `jax.jit`-compiled code (but not in `jax.pmap` or `jax.pjit`-compiled code) and enable the `jax_disable_jit` flag to disable JIT-compilation, enabling use of traditional Python debugging tools like `print` and `pdb`. ```python import jax @@ -92,7 +98,7 @@ def f(x, y): jax.jit(f)(0., 0.) # ==> raises FloatingPointError exception! ``` -Click [here](flags) to learn more! +[Read more](flags). ```{toctree} :caption: Read more diff --git a/docs/debugging/print_breakpoint.md b/docs/debugging/print_breakpoint.md index d7cb68bd1b0b..73ac0262851d 100644 --- a/docs/debugging/print_breakpoint.md +++ b/docs/debugging/print_breakpoint.md @@ -1,13 +1,14 @@ -# `jax.debug.print` and `jax.debug.breakpoint` +# Compiled prints and breakpoints The {mod}`jax.debug` package offers some useful tools for inspecting values -inside of JIT-ted functions. +inside of compiled functions. ## Debugging with `jax.debug.print` and other debugging callbacks -**TL;DR** Use {func}`jax.debug.print` to print traced array values to stdout in `jit`- and `pmap`-decorated functions: +**Summary:** Use {func}`jax.debug.print` to print traced array values to +stdout in compiled (e.g. 
`jax.jit` or `jax.pmap`-decorated) functions: ```python import jax @@ -26,7 +27,6 @@ f(2.) # 🤯 0.9092974662780762 🤯 ``` - With some transformations, like `jax.grad` and `jax.vmap`, you can use Python's builtin `print` function to print out numerical values. But `print` won't work with `jax.jit` or `jax.pmap` because those transformations delay numerical evaluation. So use `jax.debug.print` instead! Semantically, `jax.debug.print` is roughly equivalent to the following Python function @@ -236,7 +236,7 @@ Furthermore, when using `jax.debug.print` with `jax.pjit`, a global synchronizat ## Interactive inspection with `jax.debug.breakpoint()` -**TL;DR** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: +**Summary:** Use `jax.debug.breakpoint()` to pause the execution of your JAX program to inspect values: ```python @jax.jit diff --git a/docs/deprecation.md b/docs/deprecation.md index 7a8b867b6f2e..385d31271421 100644 --- a/docs/deprecation.md +++ b/docs/deprecation.md @@ -13,24 +13,25 @@ nine months longer than SPEC-0 recommends. This means we support at least: -* All minor Python releases in the 45 months prior to each JAX release. For example: +* All Python feature releases in the 45 months prior to each JAX release. For example: - * **Python 3.9** was released October 2020, and will be supported in new JAX releases at least until **July 2024**. * **Python 3.10** was released October 2021, and will be supported in new JAX releases at least until **July 2025**. * **Python 3.11** was released October 2022, and will be supported in new JAX releases at least until **July 2026**. + * **Python 3.12** was released October 2023, and will be supported in new JAX releases at least until **July 2027**. -* All minor NumPy releases in the 24 months prior to each JAX release. For example: +* All NumPy feature releases in the 24 months prior to each JAX release. For example: - * **NumPy 1.22** was released December 2021, and will be supported in new JAX releases at least until **December 2023**. - * **NumPy 1.23** was released June 2022, and will be supported in new JAX releases at least until **June 2024**. * **NumPy 1.24** was released December 2022, and will be supported in new JAX releases at least until **December 2024**. + * **NumPy 1.25** was released June 2023, and will be supported in new JAX releases at least until **June 2025** + * **NumPy 1.26** was released September 2023, and will be supported in new JAX releases at least until **September 2025** + * **NumPy 2.0** was released June 2024, and will be supported in new JAX releases at least until **June 2026** -* All minor SciPy releases in the 24 months prior to each JAX release, starting - with SciPy version 1.9. For example: +* All SciPy feature releases in the 24 months prior to each JAX release. For example: - * **Scipy 1.9** was released July 2022, and will be supported in new JAX releases at least until **July 2024**. * **Scipy 1.10** was released January 2023, and will be supported in new JAX releases at least until **January 2025**. * **Scipy 1.11** was released June 2023, and will be supported in new JAX releases at least until **June 2025**. + * **Scipy 1.12** was released January 2024, and will be supported in new JAX releases at least until **January 2026**. + * **Scipy 1.13** was released April 2024, and will be supported in new JAX releases at least until **April 2026**. 
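As a small illustration of the support-window arithmetic above (45 months for Python feature releases, 24 months for NumPy and SciPy), the following sketch uses a hypothetical `support_until` helper; the dates reproduce the examples listed:

```python
from datetime import date

def support_until(release: date, months: int) -> date:
    # Hypothetical helper: add a whole number of months to a release date.
    total = release.year * 12 + (release.month - 1) + months
    return date(total // 12, total % 12 + 1, 1)

print(support_until(date(2021, 10, 1), 45))  # 2025-07-01 -> Python 3.10 supported until July 2025
print(support_until(date(2022, 12, 1), 24))  # 2024-12-01 -> NumPy 1.24 supported until December 2024
```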
JAX releases may support older versions of Python, NumPy, and SciPy than strictly required by this policy, but support for older versions may be dropped at any time beyond the listed diff --git a/docs/developer.md b/docs/developer.md index e2850d2a94e7..5f57b2499860 100644 --- a/docs/developer.md +++ b/docs/developer.md @@ -6,7 +6,7 @@ First, obtain the JAX source code: ``` -git clone https://github.com/google/jax +git clone https://github.com/jax-ml/jax cd jax ``` @@ -26,28 +26,38 @@ If you're only modifying Python portions of JAX, we recommend installing pip install jaxlib ``` -See the [JAX readme](https://github.com/google/jax#installation) for full +See the [JAX readme](https://github.com/jax-ml/jax#installation) for full guidance on pip installation (e.g., for GPU and TPU support). ### Building `jaxlib` from source +```{warning} +While it should typically be possible to compile `jaxlib` from source using +most modern compilers, the builds are only tested using clang. Pull requests +are welcomed to improve support for different toolchains, but other compilers +are not actively supported. +``` + To build `jaxlib` from source, you must also install some prerequisites: -- a C++ compiler (g++, clang, or MSVC) +- A C++ compiler: - On Ubuntu or Debian you can install the necessary prerequisites with: + As mentioned in the box above, it is best to use a recent version of clang + (at the time of writing, the version we test is 18), but other compilers (e.g. + g++ or MSVC) may work. - ``` - sudo apt install g++ python python3-dev - ``` + On Ubuntu or Debian you can follow the instructions from the + [LLVM](https://apt.llvm.org/) documentation to install the latest stable + version of clang. If you are building on a Mac, make sure XCode and the XCode command line tools are installed. See below for Windows build instructions. -- there is no need to install Python dependencies locally, as your system - Python will be ignored during the build; please check +- Python: for running the build helper script. Note that there is no need to + install Python dependencies locally, as your system Python will be ignored + during the build; please check [Managing hermetic Python](#managing-hermetic-python) for details. To build `jaxlib` for CPU or TPU, you can run: @@ -75,11 +85,11 @@ There are two ways to build `jaxlib` with CUDA support: (1) use `python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda support, or (2) use `python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12` -to generate three wheels (jaxlib without cuda, jax-cuda-plugin, -and jax-cuda-pjrt). You can set `gpu_plugin_cuda_version` to 11 or 12. +to generate three wheels (jaxlib without cuda, jax-cuda-plugin, and +jax-cuda-pjrt). By default all CUDA compilation steps performed by NVCC and +clang, but it can be restricted to clang via the `--nouse_cuda_nvcc` flag. -See `python build/build.py --help` for configuration options, including ways to -specify the paths to CUDA and CUDNN, which you must have installed. Here +See `python build/build.py --help` for configuration options. Here `python` should be the name of your Python 3 interpreter; on some systems, you may need to use `python3` instead. Despite calling the script with `python`, Bazel will always use its own hermetic Python interpreter and dependencies, only @@ -87,6 +97,33 @@ the `build/build.py` script itself will be processed by your system Python interpreter. 
By default, the wheel is written to the `dist/` subdirectory of the current directory. +* JAX versions starting from v.0.4.32: you can provide custom CUDA and CUDNN + versions in the configuration options. Bazel will download them and use them as + target dependencies. + + To download the specific versions of CUDA/CUDNN redistributions, you can use + the following command: + + ```bash + python build/build.py --enable_cuda \ + --cuda_version=12.3.2 --cudnn_version=9.1.1 + ``` + + To point to CUDA/CUDNN/NCCL redistributions on local file system, you can use + the following command: + + ```bash + python build/build.py --enable_cuda \ + --bazel_options=--repo_env=LOCAL_CUDA_PATH="/foo/bar/nvidia/cuda" \ + --bazel_options=--repo_env=LOCAL_CUDNN_PATH="/foo/bar/nvidia/cudnn" \ + --bazel_options=--repo_env=LOCAL_NCCL_PATH="/foo/bar/nvidia/nccl" + ``` + + Please see the full list of instructions in [XLA documentation](https://github.com/openxla/xla/blob/main/docs/hermetic_cuda.md). + +* JAX versions prior to v.0.4.32: you must have CUDA and CUDNN installed and + provide paths to them using configuration options. + ### Building jaxlib from source with a modified XLA repository. JAX depends on XLA, whose source code is in the @@ -112,6 +149,8 @@ particular before each `jaxlib` release. ### Additional Notes for Building `jaxlib` from source on Windows +Note: JAX does not support CUDA on Windows; use WSL2 for CUDA support. + On Windows, follow [Install Visual Studio](https://docs.microsoft.com/en-us/visualstudio/install/install-visual-studio?view=vs-2019) to set up a C++ toolchain. Visual Studio 2019 version 16.5 or newer is required. @@ -231,8 +270,8 @@ together with their corresponding hashes are specified in `build/requirements_lock_.txt` files ( e.g. `build/requirements_lock_3_12.txt` for `Python 3.12`). -To update the lock files, make sure `build/requirements.in` contains the desired -direct dependencies list and then execute the following command (which will call +To update the lock files, make sure `build/requirements.in` contains the desired +direct dependencies list and then execute the following command (which will call [pip-compile](https://pypi.org/project/pip-tools/) under the hood): ``` @@ -327,17 +366,24 @@ sudo apt-get install libopenblas-dev -y has `custom_python_interpreter()` entry there, pointing to the version of Python you want to build. -3) Run `bazel build @python_dev//:python_dev` to build Python interpreter. By default it will - be built with GCC compiler. If you wish to build with clang, you need to set - corresponding env variables to do so ( +3) Run `bazel build @python_dev//:python_dev --repo_env=HERMETIC_PYTHON_VERSION=3.12` + to build Python interpreter. Note, it is easy to confuse Python version used + to conduct the build (which is needed for technical reasons and is defined by + `HERMETIC_PYTHON_VERSION=3.12`) and the version of Python you are building + (defined by whichever version you specified in `custom_python_interpreter()` + on step 2). For build to succeed, please make sure that hermetic Python you + choose to conduct the build already exists in your configuration (the actual + version does not matter, as long as it is a working one). By default, Python + binary will be built with GCC compiler. If you wish to build it with clang, + you need to set corresponding env variables to do so ( e.g. `--repo_env=CC=/usr/lib/llvm-17/bin/clang --repo_env=CXX=/usr/lib/llvm-17/bin/clang++`). 4) Check the output of the previous command.
At the very end of it you will find a code snippet for `python_register_toolchains()` entry with your newly built Python in it. Copy that code snippet in your `WORKSPACE` file either right after `python_init_toolchains()` entry (to add the new version of Python) or - instead of it (to replace an existing version, like replacing 3.12 with - custom built variant of 3.12). The code snippet is generated to match your + instead of it (to replace an existing version, like replacing `3.12` with + custom built variant of `3.12`). The code snippet is generated to match your actual setup, so it should work as is, but you can customize it if you choose so (for example to change location of Python's `.tgz` file so it could be downloaded remotely instead of being on local machine). @@ -345,7 +391,11 @@ sudo apt-get install libopenblas-dev -y 5) Make sure there is an entry for your Python's version in `requirements` parameter for `python_init_repositories()` in your WORKSPACE file. For example for `Python 3.13` it should have something - like `"3.13": "//build:requirements_lock_3_13.txt"`. + like `"3.13": "//build:requirements_lock_3_13.txt"`. Note, the key in the + `requirements` parameter must always be in `"major.minor"` version format, so + even if you are building Python version `3.13.0rc1` the corresponding + `requirements` entry must still be `"3.13": "//build:requirements_lock_3_13.txt"`, + **not** `"3.13.0rc1": "//build:requirements_lock_3_13_0rc1.txt"`. 6) For unstable versions of Python, optionally (but highly recommended) run `bazel build //build:all_py_deps --repo_env=HERMETIC_PYTHON_VERSION="3.13"`, @@ -572,7 +622,7 @@ pytest --doctest-modules jax/_src/numpy/lax_numpy.py Keep in mind that there are several files that are marked to be skipped when the doctest command is run on the full package; you can see the details in -[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml) +[`ci-build.yaml`](https://github.com/jax-ml/jax/blob/main/.github/workflows/ci-build.yaml) ## Type checking @@ -658,12 +708,12 @@ using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` notebooks; for example: ``` -pip install jupytext==1.16.0 +pip install jupytext==1.16.4 jupytext --sync docs/notebooks/thinking_in_jax.ipynb ``` The jupytext version should match that specified in -[.pre-commit-config.yaml](https://github.com/google/jax/blob/main/.pre-commit-config.yaml). +[.pre-commit-config.yaml](https://github.com/jax-ml/jax/blob/main/.pre-commit-config.yaml). To check that the markdown and ipynb files are properly synced, you may use the [pre-commit](https://pre-commit.com/) framework to perform the same check used @@ -691,12 +741,12 @@ desired formats, and which the `jupytext --sync` command recognizes when invoked Some of the notebooks are built automatically as part of the pre-submit checks and as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build. The build will fail if cells raise errors. If the errors are intentional, you can either catch them, -or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/google/jax/pull/2402/files)). +or tag the cell with `raises-exceptions` metadata ([example PR](https://github.com/jax-ml/jax/pull/2402/files)). You have to add this metadata by hand in the `.ipynb` file. It will be preserved when somebody else re-saves the notebook. We exclude some notebooks from the build, e.g., because they contain long computations. 
-See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py). +See `exclude_patterns` in [conf.py](https://github.com/jax-ml/jax/blob/main/docs/conf.py). ### Documentation building on `readthedocs.io` @@ -723,7 +773,7 @@ I saw in the Readthedocs logs: mkvirtualenv jax-docs # A new virtualenv mkdir jax-docs # A new directory cd jax-docs -git clone --no-single-branch --depth 50 https://github.com/google/jax +git clone --no-single-branch --depth 50 https://github.com/jax-ml/jax cd jax git checkout --force origin/test-docs git clean -d -f -f diff --git a/docs/device_memory_profiling.md b/docs/device_memory_profiling.md index e4d871b780f3..a2fd3f68780c 100644 --- a/docs/device_memory_profiling.md +++ b/docs/device_memory_profiling.md @@ -1,4 +1,4 @@ -# Device Memory Profiling +# Profiling device memory @@ -9,7 +9,7 @@ profile, open the `memory_viewer` tab of the Tensorboard profiler for more detailed and understandable device memory usage. ``` -The JAX Device Memory Profiler allows us to explore how and why JAX programs are +The JAX device memory profiler allows us to explore how and why JAX programs are using GPU or TPU memory. For example, it can be used to: * Figure out which arrays and executables are in GPU memory at a given time, or diff --git a/docs/distributed_data_loading.md b/docs/distributed_data_loading.md index 70cbd26baa5c..4f4dd7839c37 100644 --- a/docs/distributed_data_loading.md +++ b/docs/distributed_data_loading.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python name: python3 --- -# Distributed data loading in a multi-host/multi-process environment +# Distributed data loading @@ -243,35 +243,10 @@ ds = ds.shard(num_shards=jax.process_count(), index=jax.process_index()) # Grab just the first batch from the Dataset for this example per_process_batch = ds.as_numpy_iterator().next() -per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim - # isn't 0 - -per_replica_batch_size = per_process_batch_size // jax.local_device_count() -assert per_process_batch_size % per_replica_batch_size == 0, \ - "This example doesn't implement padding." -per_replica_batches = np.split(per_process_batch, jax.local_device_count()) - -# Thanks to the very important trick about data parallelism, no need to care what -# order the devices appear in the sharding. -sharding = jax.sharding.PositionalSharding(jax.devices()) -# PositionalSharding must have same rank as data being sharded. -sharding = sharding.reshape((jax.device_count(),) + - (1,) * (per_process_batch.ndim - 1)) - -global_batch_size = per_replica_batch_size * jax.device_count() -global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:]) - -global_batch_array = jax.make_array_from_single_device_arrays( - global_batch_shape, sharding, - # Thanks again to the very important trick, no need to care which device gets - # which per-replica batch. 
- arrays=[jax.device_put(batch, device) - for batch, device - in zip(per_replica_batches, sharding.addressable_devices)]) - -assert global_batch_array.shape == global_batch_shape -assert (global_batch_array.addressable_shards[0].data.shape == - per_replica_batches[0].shape) +mesh = jax.make_mesh((jax.device_count(),), ('batch',)) +sharding = jax.NamedSharding(mesh, jax.sharding.PartitionSpec('batch')) +global_batch_array = jax.make_array_from_process_local_data( + sharding, per_process_batch) ``` ## Data + model parallelism @@ -366,16 +341,6 @@ per_process_batch = ds.as_numpy_iterator().next() num_model_replicas_per_process = 2 # set according to your parallelism strategy num_model_replicas_total = num_model_replicas_per_process * jax.process_count() -per_process_batch_size = per_process_batch.shape[0] # adjust if your batch dim - # isn't 0 - -per_replica_batch_size = (per_process_batch_size // - num_model_replicas_per_process) -assert per_process_batch_size % per_replica_batch_size == 0, \ - "This example doesn't implement padding." -per_replica_batches = np.split(per_process_batch, - num_model_replicas_per_process) - # Create an example `Mesh` for per-process data parallelism. Make sure all devices # are grouped by process, and then resize so each row is a model replica. mesh_devices = np.array([jax.local_devices(process_idx) @@ -393,35 +358,8 @@ mesh = jax.sharding.Mesh(mesh_devices, ["model_replicas", "data_parallelism"]) sharding = jax.sharding.NamedSharding( mesh, jax.sharding.PartitionSpec("model_replicas")) -global_batch_size = per_replica_batch_size * num_model_replicas_total -global_batch_shape = ((global_batch_size,) + per_process_batch.shape[1:]) - -# Create the final jax.Array using jax.make_array_from_callback. The callback -# will be called for each local device, and passed the N-D numpy-style index -# that describes what shard of the global data that device should receive. -# -# You don't need care exactly which index is passed in due to the very important data -# parallelism, but you do use the index argument to make sure you replicate each -# per-replica batch correctly -- the `index` argument will be the same for -# devices in the same model replica, and different for devices in different -# model replicas. - -index_to_batch = {} -def callback(index: tuple[slice, ...]) -> np.ndarray: - # Python `slice` objects aren't hashable, so manually create dict key. - index_key = tuple((slice_.start, slice_.stop) for slice_ in index) - if index_key not in index_to_batch: - # You don't care which per-replica batch goes to which replica, just take the - # next unused one. - index_to_batch[index_key] = per_replica_batches[len(index_to_batch)] - return index_to_batch[index_key] - -global_batch_array = jax.make_array_from_callback( - global_batch_shape, sharding, callback) - -assert global_batch_array.shape == global_batch_shape -assert (global_batch_array.addressable_shards[0].data.shape == - per_replica_batches[0].shape) +global_batch_array = jax.make_array_from_process_local_data( + sharding, per_process_batch) ``` ### Model parallelism across processes diff --git a/docs/errors.rst b/docs/errors.rst index 23dbaf29c46f..9965d6698bd4 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -1,13 +1,15 @@ .. _jax-errors: -JAX Errors -========== +Errors +====== + This page lists a few of the errors you might encounter when using JAX, along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError .. 
autoclass:: KeyReuseError +.. autoclass:: JaxRuntimeError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/docs/export/export.md b/docs/export/export.md index 9e6597cef49b..4e4d50556d8e 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -153,7 +153,7 @@ JAX runtime system that are: an inference system that is already deployed when the exporting is done. (The particular compatibility window lengths are the same that JAX -[promised for jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), +[promised for jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model), and are based on [TensorFlow Compatibility](https://www.tensorflow.org/guide/versions#graph_and_checkpoint_compatibility_when_extending_tensorflow). The terminology “backward compatibility” is from the perspective of the consumer, e.g., the inference system.) @@ -626,7 +626,7 @@ We list here a history of the calling convention version numbers: June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15). @@ -721,7 +721,7 @@ that live in jaxlib): 2. Day “D”, we add the new custom call target `T_NEW`. We should create a new custom call target, and clean up the old target roughly after 6 months, rather than updating `T` in place: - * See the example [PR #20997](https://github.com/google/jax/pull/20997) + * See the example [PR #20997](https://github.com/jax-ml/jax/pull/20997) implementing the steps below. * We add the custom call target `T_NEW`. * We change the JAX lowering rules that were previous using `T`, @@ -732,10 +732,7 @@ that live in jaxlib): from jax._src.lib import version as jaxlib_version def my_lowering_rule(ctx: LoweringRuleContext, ...): - lowering_parameters = ctx.module_context.lowering_parameters - forward_compat_mode = (lowering_parameters.for_export and - not lowering_parameters.export_ignore_forward_compatibility) - if forward_compat_mode or jaxlib_version < (0, 4, 31): + if ctx.is_forward_compat() or jaxlib_version < (0, 4, 31): # this is the old lowering, using target T, while we # are in forward compatibility mode for T, or we # are in OSS and are using an old jaxlib. diff --git a/docs/export/jax2tf.md b/docs/export/jax2tf.md index 498a0418f232..9c0ee90a0d93 100644 --- a/docs/export/jax2tf.md +++ b/docs/export/jax2tf.md @@ -2,4 +2,4 @@ ## Interoperation with TensorFlow -See the [JAX2TF documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md). +See the [JAX2TF documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md). 
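As a concrete sketch of the export workflow whose compatibility guarantees are discussed above, the following round-trips a jitted function through `jax.export`; the entry points used here (`export.export`, `Exported.serialize`, `export.deserialize`, `Exported.call`) are the documented ones, though exact options and behavior may vary across JAX versions:

```python
import jax
import jax.numpy as jnp
from jax import export

# Export a jitted function for the current default platform.
exported = export.export(jax.jit(jnp.sin))(jnp.zeros((4,), jnp.float32))

# Serialize to bytes (the artifact covered by the compatibility guarantees),
# then deserialize and call it again.
blob = exported.serialize()
rehydrated = export.deserialize(blob)
print(rehydrated.call(jnp.arange(4.0, dtype=jnp.float32)))
```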
diff --git a/docs/export/shape_poly.md b/docs/export/shape_poly.md index 695ca6cd21d9..b1ce80638706 100644 --- a/docs/export/shape_poly.md +++ b/docs/export/shape_poly.md @@ -353,7 +353,7 @@ symbolic constraints: E.g., `floordiv(a, b) == c` works by replacing all occurences of `floordiv(a, b)` with `c`. Equality constraints must not contain addition or - subtraction at the top-leve on the left-hand-side. Examples of + subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are `a * b`, or `4 * a`, or `floordiv(a + c, b)`. @@ -530,7 +530,7 @@ Array([[ 9, 8, 7], >>> k, = export.symbolic_shape("k", constraints=["k <= 10"]) >>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): -KeyError: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments +UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments ``` @@ -619,7 +619,7 @@ compilation. ### Division of symbolic dimensions is partially supported JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `6*a + 4 % 3 == 1`. +e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. In particular, JAX will handle the cases when either (a) there is no remainder, or (b) the divisor is a constant in which case there may be a constant remainder. diff --git a/docs/extensions.rst b/docs/extensions.rst new file mode 100644 index 000000000000..856153cd8723 --- /dev/null +++ b/docs/extensions.rst @@ -0,0 +1,21 @@ +.. _extensions: + +Extension guides +================ + +Guides for extending JAX's capabilities, and for building libraries +that use or interface with JAX. + +.. toctree:: + :caption: Extensible JAX internals + :maxdepth: 1 + + notebooks/Writing_custom_interpreters_in_Jax + Custom_Operation_for_GPUs + jax.extend + +.. toctree:: + :caption: Libraries and extensions + :maxdepth: 1 + + building_on_jax diff --git a/docs/_tutorials/external-callbacks.md b/docs/external-callbacks.md similarity index 99% rename from docs/_tutorials/external-callbacks.md rename to docs/external-callbacks.md index a46927e6a8b4..c404f320fca7 100644 --- a/docs/_tutorials/external-callbacks.md +++ b/docs/external-callbacks.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/faq.rst b/docs/faq.rst index 3b63128d2c28..af14f382b1d7 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -1,5 +1,5 @@ -JAX Frequently Asked Questions (FAQ) -==================================== +Frequently asked questions (FAQ) +================================ .. comment RST primer for Sphinx: https://thomas-cokelaer.info/tutorials/sphinx/rest_syntax.html .. comment Some links referenced here. Use `JAX - The Sharp Bits`_ (underscore at the end) to reference @@ -372,7 +372,7 @@ device. Jitted functions behave like any other primitive operations—they will follow the data and will show errors if invoked on data committed on more than one device. -(Before `PR #6002 `_ in March 2021 +(Before `PR #6002 `_ in March 2021 there was some laziness in creation of array constants, so that ``jax.device_put(jnp.zeros(...), jax.devices()[1])`` or similar would actually create the array of zeros on ``jax.devices()[1]``, instead of creating the @@ -385,7 +385,7 @@ and its use is not recommended.) 
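As a quick sketch of the committed-data behavior described above (the device choice below is illustrative; with a single device every array lives there anyway):

```python
import jax
import jax.numpy as jnp

# Commit an array to an explicit device; jitted computation follows the data.
dev = jax.devices()[0]
x = jax.device_put(jnp.arange(4.0), dev)

@jax.jit
def double(v):
    return v * 2

y = double(x)
print(y.devices())  # the result stays on the device the input was committed to
```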
For a worked-out example, we recommend reading through ``test_computation_follows_data`` in -`multi_device_test.py `_. +`multi_device_test.py `_. .. _faq-benchmark: @@ -691,7 +691,7 @@ The inner ``jnp.where`` may be needed in addition to the original one, e.g.:: Additional reading: - * `Issue: gradients through jnp.where when one of branches is nan `_. + * `Issue: gradients through jnp.where when one of branches is nan `_. * `How to avoid NaN gradients when using where `_. diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index 9dc49a74ec36..a8cd5219d4b5 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -4,7 +4,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# JAX's foreign function interface\n", + "(ffi-tutorial)=\n", + "\n", + "# Foreign function interface (FFI)\n", "\n", "_This tutorial requires JAX v0.4.31 or newer._\n", "\n", @@ -362,7 +364,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" + "We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:" ] }, { @@ -404,7 +406,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues)." + "If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)." 
] }, { @@ -490,7 +492,7 @@ "source": [ "At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n", "One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n", - "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice.\n", + "JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice.\n", "\n", "One other JAX feature that this example doesn't support is higher-order AD.\n", "It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n", diff --git a/docs/ffi.md b/docs/ffi.md index aa861d9a094f..cc3863ed99b2 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -5,14 +5,16 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python name: python3 --- -# JAX's foreign function interface +(ffi-tutorial)= + +# Foreign function interface (FFI) _This tutorial requires JAX v0.4.31 or newer._ @@ -309,7 +311,7 @@ Our implementation of `rms_norm` has the appropriate semantics, and it supports np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5) ``` -We can inspect the [jaxpr](understanding-jaxprs) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: +We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`: ```{code-cell} ipython3 jax.make_jaxpr(jax.vmap(rms_norm))(x) @@ -331,7 +333,7 @@ def rms_norm_not_vectorized(x, eps=1e-5): jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x) ``` -If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/google/jax/issues). +If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues). +++ @@ -404,7 +406,7 @@ np.testing.assert_allclose( At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`. 
One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode. -JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/google/jax/issues) describing you use case if you hit this limitation in practice. +JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing you use case if you hit this limitation in practice. One other JAX feature that this example doesn't support is higher-order AD. It would be possible to work around this by wrapping the `res_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here. diff --git a/docs/glossary.rst b/docs/glossary.rst index 78b7fcd246f3..286b07e21a66 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -1,5 +1,5 @@ -JAX Glossary of Terms -===================== +Glossary of terms +================= .. glossary:: @@ -28,9 +28,9 @@ JAX Glossary of Terms able to target GPUs for fast operations on arrays (see also :term:`CPU` and :term:`TPU`). jaxpr - Short for *JAX Expression*, a jaxpr is an intermediate representation of a computation that + Short for *JAX expression*, a jaxpr is an intermediate representation of a computation that is generated by JAX, and is forwarded to :term:`XLA` for compilation and execution. - See :ref:`understanding-jaxprs` for more discussion and examples. + See :ref:`jax-internals-jaxpr` for more discussion and examples. JIT Short for *Just In Time* compilation, JIT in JAX generally refers to the compilation of diff --git a/docs/_tutorials/gradient-checkpointing.md b/docs/gradient-checkpointing.md similarity index 99% rename from docs/_tutorials/gradient-checkpointing.md rename to docs/gradient-checkpointing.md index b768514e4bb0..14a532b54dd1 100644 --- a/docs/_tutorials/gradient-checkpointing.md +++ b/docs/gradient-checkpointing.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/index.rst b/docs/index.rst index 2e13c109dbbe..2dd856ab88ef 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,13 +1,9 @@ -JAX: High-Performance Array Computing +JAX: High performance array computing ===================================== JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning. -If you're looking to train neural networks, use Flax_ and start with its documentation. -Some associated tools are Optax_ and Orbax_. -For an end-to-end transformer library built on JAX, see MaxText_. - .. grid:: 3 :margin: 0 :padding: 0 @@ -27,7 +23,7 @@ For an end-to-end transformer library built on JAX, see MaxText_. JAX includes composable function transformations for compilation, batching, automatic differentiation, and parallelization. - .. grid-item-card:: Run Anywhere + .. grid-item-card:: Run anywhere :columns: 12 6 6 4 :class-card: sd-border-0 :shadow: None @@ -36,34 +32,95 @@ For an end-to-end transformer library built on JAX, see MaxText_. .. grid:: 3 - .. 
grid-item-card:: :material-regular:`rocket_launch;2em` Getting Started + .. grid-item-card:: :material-regular:`rocket_launch;2em` Getting started :columns: 12 6 6 4 :link: beginner-guide :link-type: ref :class-card: getting-started - .. grid-item-card:: :material-regular:`library_books;2em` User Guides + .. grid-item-card:: :material-regular:`library_books;2em` User guides :columns: 12 6 6 4 :link: user-guides :link-type: ref :class-card: user-guides - .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer Docs + .. grid-item-card:: :material-regular:`laptop_chromebook;2em` Developer notes :columns: 12 6 6 4 :link: contributor-guide :link-type: ref :class-card: developer-docs +If you're looking to train neural networks, use Flax_ and start with its tutorials. +For an end-to-end transformer library built on JAX, see MaxText_. + +Ecosystem +--------- +JAX itself is narrowly-scoped and focuses on efficient array operations & program +transformations. Built around JAX is an evolving ecosystem of machine learning and +numerical computing tools; the following is just a small sample of what is out there: + +.. grid:: 4 + :class-container: ecosystem-grid + + .. grid-item:: :material-outlined:`hub;2em` **Neural networks** + + - Flax_ + - NNX_ + - Equinox_ + - Keras_ + + .. grid-item:: :material-regular:`show_chart;2em` **Optimizers & solvers** + + - Optax_ + - Optimistix_ + - Lineax_ + - Diffrax_ + + .. grid-item:: :material-outlined:`storage;2em` **Data loading** + + - Grain_ + - `Tensorflow datasets`_ + - `Hugging Face datasets`_ + + .. grid-item:: :material-regular:`construction;2em` **Miscellaneous tools** + + - Orbax_ + - Chex_ + + .. grid-item:: :material-regular:`lan;2em` **Probabilistic programming** + + - Blackjax_ + - Numpyro_ + - PyMC_ + + .. grid-item:: :material-regular:`bar_chart;2em` **Probabilistic modeling** + + - `Tensorflow probabilty`_ + - Distrax_ + + .. grid-item:: :material-outlined:`animation;2em` **Physics & simulation** + + - `JAX MD`_ + - Brax_ + + .. grid-item:: :material-regular:`language;2em` **LLMs** + + - MaxText_ + - AXLearn_ + - Levanter_ + - EasyLM_ + + +Many more JAX-based libraries have been developed; the community-run `Awesome JAX`_ page +maintains an up-to-date list. .. toctree:: :hidden: :maxdepth: 1 - :caption: Getting Started + :caption: Getting started installation quickstart - notebooks/Common_Gotchas_in_JAX - faq .. toctree:: :hidden: @@ -71,16 +128,19 @@ For an end-to-end transformer library built on JAX, see MaxText_. tutorials + notebooks/Common_Gotchas_in_JAX + + faq .. toctree:: :hidden: :maxdepth: 2 - :caption: Further Resources + :caption: More guides/resources user_guides advanced_guide contributor_guide - building_on_jax + extensions notes jax @@ -93,7 +153,28 @@ For an end-to-end transformer library built on JAX, see MaxText_. glossary +.. _Awesome JAX: https://github.com/n2cholas/awesome-jax +.. _AXLearn: https://github.com/apple/axlearn +.. _Blackjax: https://blackjax-devs.github.io/blackjax/ +.. _Brax: https://github.com/google/brax/ +.. _Chex: https://chex.readthedocs.io/ +.. _Diffrax: https://docs.kidger.site/diffrax/ +.. _Distrax: https://github.com/google-deepmind/distrax +.. _EasyLM: https://github.com/young-geng/EasyLM +.. _Equinox: https://docs.kidger.site/equinox/ .. _Flax: https://flax.readthedocs.io/ -.. _Orbax: https://orbax.readthedocs.io/ -.. _Optax: https://optax.readthedocs.io/ +.. _Grain: https://github.com/google/grain +.. _Hugging Face datasets: https://huggingface.co/docs/datasets/ +.. 
_JAX MD: https://jax-md.readthedocs.io/ +.. _Keras: https://keras.io/ +.. _Levanter: https://github.com/stanford-crfm/levanter +.. _Lineax: https://github.com/patrick-kidger/lineax .. _MaxText: https://github.com/google/maxtext/ +.. _NNX: https://flax.readthedocs.io/en/latest/nnx/ +.. _Numpyro: https://num.pyro.ai/en/latest/index.html +.. _Optax: https://optax.readthedocs.io/ +.. _Optimistix: https://github.com/patrick-kidger/optimistix +.. _Orbax: https://orbax.readthedocs.io/ +.. _PyMC: https://www.pymc.io/ +.. _Tensorflow datasets: https://www.tensorflow.org/datasets +.. _Tensorflow probabilty: https://www.tensorflow.org/probability diff --git a/docs/installation.md b/docs/installation.md index 20ffe436ff8a..acb802ea939c 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -1,5 +1,5 @@ (installation)= -# Installing JAX +# Installation @@ -7,7 +7,7 @@ Using JAX requires installing two packages: `jax`, which is pure Python and cross-platform, and `jaxlib` which contains compiled binaries, and requires different builds for different operating systems and accelerators. -**TL;DR** For most users, a typical JAX installation may look something like this: +**Summary:** For most users, a typical JAX installation may look something like this: * **CPU-only (Linux/macOS/Windows)** ``` @@ -176,7 +176,7 @@ installation. JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. Make sure that it is present in your CUDA installation. -Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) +Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues) if you run into any errors or problems with the pre-built wheels. (docker-containers-nvidia-gpu)= @@ -216,7 +216,7 @@ refer to **Note:** There are several caveats with the Metal plugin: * The Metal plugin is new and experimental and has a number of - [known issues](https://github.com/google/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). + [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22). Please report any issues on the JAX issue tracker. * The Metal plugin currently requires very specific versions of `jax` and `jaxlib`. This restriction will be relaxed over time as the plugin API @@ -269,22 +269,26 @@ for more details. Nightly releases reflect the state of the main JAX repository at the time they are built, and may not pass the full test suite. +Unlike the instructions for installing a JAX release, here we name all of JAX's +packages explicitly on the command line, so `pip` will upgrade them if a newer +version is available. 
+ - CPU only: ```bash -pip install -U --pre jax -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - Google Cloud TPU: ```bash -pip install -U --pre jax[tpu] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html ``` - NVIDIA GPU (CUDA 12): ```bash -pip install -U --pre jax[cuda12] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +pip install -U --pre jax jaxlib jax-cuda12-plugin[with_cuda] jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html ``` - NVIDIA GPU (CUDA 12) legacy: diff --git a/docs/investigating_a_regression.md b/docs/investigating_a_regression.md index 4affae3a65d8..61d219d1bae1 100644 --- a/docs/investigating_a_regression.md +++ b/docs/investigating_a_regression.md @@ -9,7 +9,7 @@ Let's first make a JAX issue. But if you can pinpoint the commit that triggered the regression, it will really help us. This document explains how we identified the commit that caused a -[15% performance regression](https://github.com/google/jax/issues/17686). +[15% performance regression](https://github.com/jax-ml/jax/issues/17686). ## Steps @@ -23,9 +23,9 @@ Here is a suggested investigation strategy: 2. Hourly recompilation while keeping XLA and JAX in sync. 3. Final verification: maybe a manual check of a few commits (or a git bisect). -## Nightly investigation. +## Nightly investigation -This can be done by using [JAX-Toolbox nightly +This can be done by using the [NVIDIA JAX-Toolbox nightly containers](https://github.com/NVIDIA/JAX-Toolbox). - Some days, bugs prevent the container from being built, or there are temporary regressions. Just discard those days. @@ -34,7 +34,7 @@ containers](https://github.com/NVIDIA/JAX-Toolbox). - test_runner.sh: will start the containers and the test. - test.sh: will install missing dependencies and run the test -Here are real example scripts used for the issue: https://github.com/google/jax/issues/17686 +Here are real example scripts used for the issue: https://github.com/jax-ml/jax/issues/17686 - test_runner.sh: ``` for m in 7 8 9; do @@ -128,7 +128,7 @@ investigate hourly between 8-24 and 8-26. There was a smaller slowdown earlier, lets ignore it for this example. It would be only another hourly investigation between those dates. -## Hourly investigation. +## Hourly investigation This does a checkout of JAX and XLA at each hour between the 2 dates, rebuilds everything and runs the test. The scripts are structured diff --git a/docs/_tutorials/jax-primitives.md b/docs/jax-primitives.md similarity index 99% rename from docs/_tutorials/jax-primitives.md rename to docs/jax-primitives.md index 51abe2916693..abdc8be6d0a8 100644 --- a/docs/_tutorials/jax-primitives.md +++ b/docs/jax-primitives.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -306,7 +306,7 @@ from jax.interpreters import mlir mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') ``` -You will now succeed to apply `jax.jit`. 
Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_xla_translation`. +You will now succeed to apply `jax.jit`. Notice below that JAX first evaluates the function abstractly, which triggers the `multiply_add_abstract_eval` function, and then compiles the set of primitives it has encountered, including `multiply_add`. At this point JAX invokes `multiply_add_lowering`. ```{code-cell} assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. diff --git a/docs/jax.lax.rst b/docs/jax.lax.rst index 7b19955d3d78..3a03665b3217 100644 --- a/docs/jax.lax.rst +++ b/docs/jax.lax.rst @@ -119,6 +119,7 @@ Operators ne neg nextafter + optimization_barrier pad platform_dependent polygamma @@ -248,6 +249,7 @@ Argument classes .. autoclass:: ConvDimensionNumbers .. autoclass:: ConvGeneralDilatedDimensionNumbers +.. autoclass:: DotAlgorithm .. autoclass:: GatherDimensionNumbers .. autoclass:: GatherScatterMode .. autoclass:: Precision diff --git a/docs/jax.random.rst b/docs/jax.random.rst index 9d6369d2d2b1..6c5427c05e66 100644 --- a/docs/jax.random.rst +++ b/docs/jax.random.rst @@ -12,13 +12,13 @@ Key Creation & Manipulation .. autosummary:: :toctree: _autosummary - PRNGKey key key_data wrap_key_data fold_in split clone + PRNGKey Random Samplers ~~~~~~~~~~~~~~~ diff --git a/docs/jax.rst b/docs/jax.rst index b112490a0912..ecfeaf29e3c0 100644 --- a/docs/jax.rst +++ b/docs/jax.rst @@ -1,7 +1,7 @@ .. currentmodule:: jax -Public API: jax package -======================= +Public API: ``jax`` package +=========================== Subpackages ----------- @@ -69,7 +69,6 @@ Just-in-time compilation (:code:`jit`) jit disable_jit ensure_compile_time_eval - xla_computation make_jaxpr eval_shape ShapeDtypeStruct @@ -92,6 +91,7 @@ Automatic differentiation grad value_and_grad + jacobian jacfwd jacrev hessian @@ -99,12 +99,29 @@ Automatic differentiation linearize linear_transpose vjp - custom_jvp - custom_vjp custom_gradient closure_convert checkpoint +``custom_jvp`` +~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + custom_jvp + custom_jvp.defjvp + custom_jvp.defjvps + +``custom_vjp`` +~~~~~~~~~~~~~~ + +.. autosummary:: + :toctree: _autosummary + + custom_vjp + custom_vjp.defvjp + jax.Array (:code:`jax.Array`) ----------------------------- @@ -116,6 +133,73 @@ jax.Array (:code:`jax.Array`) make_array_from_single_device_arrays make_array_from_process_local_data +Array properties and methods +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. 
autosummary:: + :toctree: _autosummary + + Array.addressable_shards + Array.all + Array.any + Array.argmax + Array.argmin + Array.argpartition + Array.argsort + Array.astype + Array.at + Array.choose + Array.clip + Array.compress + Array.conj + Array.conjugate + Array.copy + Array.copy_to_host_async + Array.cumprod + Array.cumsum + Array.device + Array.diagonal + Array.dot + Array.dtype + Array.flat + Array.flatten + Array.global_shards + Array.imag + Array.is_fully_addressable + Array.is_fully_replicated + Array.item + Array.itemsize + Array.max + Array.mean + Array.min + Array.nbytes + Array.ndim + Array.nonzero + Array.prod + Array.ptp + Array.ravel + Array.real + Array.repeat + Array.reshape + Array.round + Array.searchsorted + Array.shape + Array.sharding + Array.size + Array.sort + Array.squeeze + Array.std + Array.sum + Array.swapaxes + Array.take + Array.to_device + Array.trace + Array.transpose + Array.var + Array.view + Array.T + Array.mT + Vectorization (:code:`vmap`) ---------------------------- @@ -138,6 +222,7 @@ Parallelization (:code:`pmap`) device_count local_device_count process_count + process_indices Callbacks --------- diff --git a/docs/jax.scipy.rst b/docs/jax.scipy.rst index f6d8a151440b..abdf5069ee08 100644 --- a/docs/jax.scipy.rst +++ b/docs/jax.scipy.rst @@ -164,6 +164,7 @@ jax.scipy.special expit expn factorial + fresnel gamma gammainc gammaincc diff --git a/docs/jax_internal_api.rst b/docs/jax_internal_api.rst index fe65054d22c1..1ece596d88ef 100644 --- a/docs/jax_internal_api.rst +++ b/docs/jax_internal_api.rst @@ -1,5 +1,5 @@ -Internal APIs -============= +Internal API reference +====================== core ---- diff --git a/docs/_tutorials/jaxpr.md b/docs/jaxpr.md similarity index 99% rename from docs/_tutorials/jaxpr.md rename to docs/jaxpr.md index 9fe990c0a8ba..974ed39c1663 100644 --- a/docs/_tutorials/jaxpr.md +++ b/docs/jaxpr.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/jaxpr.rst b/docs/jaxpr.rst deleted file mode 100644 index 56be62162a9e..000000000000 --- a/docs/jaxpr.rst +++ /dev/null @@ -1,472 +0,0 @@ -.. _understanding-jaxprs: - -Understanding Jaxprs -==================== - -Updated: May 3, 2020 (for commit f1a46fe). - -Conceptually, one can think of JAX transformations as first trace-specializing -the Python function to be transformed into a small and well-behaved -intermediate form that is then interpreted with transformation-specific -interpretation rules. One of the reasons JAX can pack so much power into such a -small software package is that it starts with a familiar and flexible -programming interface (Python with NumPy) and it uses the actual Python -interpreter to do most of the heavy lifting to distill the essence of the -computation into a simple statically-typed expression language with limited -higher-order features. That language is the jaxpr language. - -Not all Python programs can be processed this way, but it turns out that many -scientific computing and machine learning programs can. - -Before we proceed, it is important to point out that not all JAX -transformations literally materialize a jaxpr as described above; some, e.g., -differentiation or batching, will apply transformations incrementally during -tracing. Nevertheless, if one wants to understand how JAX works internally, or -to make use of the result of JAX tracing, it is useful to understand jaxprs. 
- -A jaxpr instance represents a function with one or more typed parameters (input -variables) and one or more typed results. The results depend only on the input -variables; there are no free variables captured from enclosing scopes. The -inputs and outputs have types, which in JAX are represented as abstract values. -There are two related representations in the code for jaxprs, -:py:class:`jax.core.Jaxpr` and :py:class:`jax.core.ClosedJaxpr`. A -:py:class:`jax.core.ClosedJaxpr` represents a partially-applied -:py:class:`jax.core.Jaxpr`, and is what you obtain when you use -:py:func:`jax.make_jaxpr` to inspect jaxprs. It has the following fields: - - * ``jaxpr`` is a :py:class:`jax.core.Jaxpr` representing the actual - computation content of the function (described below). - * ``consts`` is a list of constants. - -The most interesting part of the ClosedJaxpr is the actual execution content, -represented as a :py:class:`jax.core.Jaxpr` as printed using the following -grammar:: - - Jaxpr ::= { lambda Var* ; Var+. let - Eqn* - in [Expr+] } - -where: - * The parameters of the jaxpr are shown as two lists of variables separated by - ``;``. The first set of variables are the ones that have been introduced - to stand for constants that have been hoisted out. These are called the - ``constvars``, and in a :py:class:`jax.core.ClosedJaxpr` the ``consts`` - field holds corresponding values. The second list of variables, called - ``invars``, correspond to the inputs of the traced Python function. - * ``Eqn*`` is a list of equations, defining intermediate variables referring to - intermediate expressions. Each equation defines one or more variables as the - result of applying a primitive on some atomic expressions. Each equation uses only - input variables and intermediate variables defined by previous equations. - * ``Expr+``: is a list of output atomic expressions (literals or variables) - for the jaxpr. - -Equations are printed as follows:: - - Eqn ::= Var+ = Primitive [ Param* ] Expr+ - -where: - * ``Var+`` are one or more intermediate variables to be defined as the output - of a primitive invocation (some primitives can return multiple values). - * ``Expr+`` are one or more atomic expressions, each either a variable or a - literal constant. A special variable ``unitvar`` or literal ``unit``, - printed as ``*``, represents a value that is not needed - in the rest of the computation and has been elided. That is, units are just - placeholders. - * ``Param*`` are zero or more named parameters to the primitive, printed in - square brackets. Each parameter is shown as ``Name = Value``. - - -Most jaxpr primitives are first-order (they take just one or more ``Expr`` as arguments):: - - Primitive := add | sub | sin | mul | ... - - -The jaxpr primitives are documented in the :py:mod:`jax.lax` module. - -For example, here is the jaxpr produced for the function ``func1`` below - ->>> from jax import make_jaxpr ->>> import jax.numpy as jnp ->>> def func1(first, second): -... temp = first + jnp.sin(second) * 3. -... return jnp.sum(temp) -... ->>> print(make_jaxpr(func1)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - -Here there are no constvars, ``a`` and ``b`` are the input variables -and they correspond respectively to -``first`` and ``second`` function parameters. The scalar literal ``3.0`` is kept -inline. 
-The ``reduce_sum`` primitive has named parameter ``axes``, in addition to the -operand ``e``. - -Note that even though execution of a program that calls into JAX builds a jaxpr, -Python-level control-flow and Python-level functions execute normally. -This means that just because a Python program contains functions and control-flow, -the resulting jaxpr does not have to contain control-flow or higher-order features. - -For example, when tracing the function ``func3`` JAX will inline the call to -``inner`` and the conditional ``if second.shape[0] > 4``, and will produce the same -jaxpr as before - ->>> def func2(inner, first, second): -... temp = first + inner(second) * 3. -... return jnp.sum(temp) -... ->>> def inner(second): -... if second.shape[0] > 4: -... return jnp.sin(second) -... else: -... assert False -... ->>> def func3(first, second): -... return func2(inner, first, second) -... ->>> print(make_jaxpr(func3)(jnp.zeros(8), jnp.ones(8))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - - -Handling PyTrees ----------------- - -In jaxpr there are no tuple types; instead primitives take multiple inputs -and produce multiple outputs. When processing a function that has structured -inputs or outputs, JAX will flatten those and in jaxpr they will appear as lists -of inputs and outputs. For more details, please see the documentation for -PyTrees (:ref:`pytrees`). - -For example, the following code produces an identical jaxpr to what we saw -before (with two input vars, one for each element of the input tuple) - - ->>> def func4(arg): # Arg is a pair -... temp = arg[0] + jnp.sin(arg[1]) * 3. -... return jnp.sum(temp) -... ->>> print(make_jaxpr(func4)((jnp.zeros(8), jnp.ones(8)))) -{ lambda ; a:f32[8] b:f32[8]. let - c:f32[8] = sin b - d:f32[8] = mul c 3.0 - e:f32[8] = add a d - f:f32[] = reduce_sum[axes=(0,)] e - in (f,) } - - - -Constant Vars -------------- - -Some values in jaxprs are constants, in that their value does not depend on the -jaxpr's arguments. When these values are scalars they are represented directly -in the jaxpr equations; non-scalar array constants are instead hoisted out to -the top-level jaxpr, where they correspond to constant variables ("constvars"). -These constvars differ from the other jaxpr parameters ("invars") only as a -bookkeeping convention. - - -Higher-order primitives ------------------------ - -jaxpr includes several higher-order primitives. They are more complicated because -they include sub-jaxprs. - -Conditionals -^^^^^^^^^^^^ - -JAX traces through normal Python conditionals. To capture a -conditional expression for dynamic execution, one must use the -:py:func:`jax.lax.switch` and :py:func:`jax.lax.cond` constructors, -which have the signatures:: - - lax.switch(index: int, branches: Sequence[A -> B], operand: A) -> B - - lax.cond(pred: bool, true_body: A -> B, false_body: A -> B, operand: A) -> B - -Both of these will bind a primitive called ``cond`` internally. The -``cond`` primitive in jaxprs reflects the more general signature of -:py:func:`lax.switch`: it takes an integer denoting the index of the branch -to execute (clamped into valid indexing range). - -For example: - ->>> from jax import lax ->>> ->>> def one_of_three(index, arg): -... return lax.switch(index, [lambda x: x + 1., -... lambda x: x - 2., -... lambda x: x + 3.], -... arg) -... ->>> print(make_jaxpr(one_of_three)(1, 5.)) -{ lambda ; a:i32[] b:f32[]. 
let - c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a - d:i32[] = clamp 0 c 2 - e:f32[] = cond[ - branches=( - { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) } - { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) } - { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) } - ) - ] d b - in (e,) } - -The `branches` parameter to the cond primitive corresponds to the branch -functionals. In this example, those functionals each take one input variable, -corresponding to ``x``. - -The above instance of the cond primitive takes two operands. The first -one (``d``) is the branch index, then ``b`` is the operand (``arg``) to -be passed to whichever jaxpr in ``branches`` is selected by the branch -index. - -Another example, using :py:func:`lax.cond`: - ->>> from jax import lax ->>> ->>> def func7(arg): -... return lax.cond(arg >= 0., -... lambda xtrue: xtrue + 3., -... lambda xfalse: xfalse - 3., -... arg) -... ->>> print(make_jaxpr(func7)(5.)) -{ lambda ; a:f32[]. let - b:bool[] = ge a 0.0 - c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b - d:f32[] = cond[ - branches=( - { lambda ; e:f32[]. let f:f32[] = sub e 3.0 in (f,) } - { lambda ; g:f32[]. let h:f32[] = add g 3.0 in (h,) } - ) - ] c a - in (d,) } - -In this case, the boolean predicate is converted to an integer index -(0 or 1), and ``branches`` are jaxprs that correspond to the false and -true branch functionals, in that order. Again, each functional takes -one input variable, corresponding to ``xfalse`` and ``xtrue`` -respectively. - -The following example shows a more complicated situation when the input -to the branch functionals is a tuple, and the `false` branch functional -contains a constant ``jnp.ones(1)`` that is hoisted as a `constvar` - ->>> def func8(arg1, arg2): # arg2 is a pair -... return lax.cond(arg1 >= 0., -... lambda xtrue: xtrue[0], -... lambda xfalse: jnp.array([1]) + xfalse[1], -... arg2) -... ->>> print(make_jaxpr(func8)(5., (jnp.zeros(1), 2.))) -{ lambda a:i32[1]; b:f32[] c:f32[1] d:f32[]. let - e:bool[] = ge b 0.0 - f:i32[] = convert_element_type[new_dtype=int32 weak_type=False] e - g:f32[1] = cond[ - branches=( - { lambda ; h:i32[1] i:f32[1] j:f32[]. let - k:f32[1] = convert_element_type[new_dtype=float32 weak_type=True] h - l:f32[1] = add k j - in (l,) } - { lambda ; m_:i32[1] n:f32[1] o:f32[]. let in (n,) } - ) - ] f a c d - in (g,) } - - - -While -^^^^^ - -Just like for conditionals, Python loops are inlined during tracing. -If you want to capture a loop for dynamic execution, you must use one of several -special operations, :py:func:`jax.lax.while_loop` (a primitive) -and :py:func:`jax.lax.fori_loop` -(a helper that generates a while_loop primitive):: - - lax.while_loop(cond_fun: (C -> bool), body_fun: (C -> C), init: C) -> C - lax.fori_loop(start: int, end: int, body: (int -> C -> C), init: C) -> C - - -In the above signature, “C” stands for the type of the loop “carry” value. -For example, here is an example fori loop - ->>> import numpy as np ->>> ->>> def func10(arg, n): -... ones = jnp.ones(arg.shape) # A constant -... return lax.fori_loop(0, n, -... lambda i, carry: carry + ones * 3. + arg, -... arg + ones) -... ->>> print(make_jaxpr(func10)(np.ones(16), 5)) -{ lambda ; a:f32[16] b:i32[]. let - c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 - d:f32[16] = add a c - _:i32[] _:i32[] e:f32[16] = while[ - body_jaxpr={ lambda ; f:f32[16] g:f32[16] h:i32[] i:i32[] j:f32[16]. 
let - k:i32[] = add h 1 - l:f32[16] = mul f 3.0 - m:f32[16] = add j l - n:f32[16] = add m g - in (k, i, n) } - body_nconsts=2 - cond_jaxpr={ lambda ; o:i32[] p:i32[] q:f32[16]. let - r:bool[] = lt o p - in (r,) } - cond_nconsts=0 - ] c a 0 b d - in (e,) } - -The while primitive takes 5 arguments: ``c a 0 b d``, as follows: - - * 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0) - * 2 constants for ``body_jaxpr`` (``c``, and ``a``) - * 3 parameters for the initial value of carry - -Scan -^^^^ - -JAX supports a special form of loop over the elements of an array (with -statically known shape). The fact that there are a fixed number of iterations -makes this form of looping easily reverse-differentiable. Such loops are -constructed with the :py:func:`jax.lax.scan` function:: - - lax.scan(body_fun: (C -> A -> (C, B)), init_carry: C, in_arr: Array[A]) -> (C, Array[B]) - -This is written in terms of a `Haskell Type Signature`_: -``C`` is the type of the scan carry, ``A`` is the element type of the -input array(s), and ``B`` is the element type of the output array(s). - -For the example consider the function ``func11`` below - ->>> def func11(arr, extra): -... ones = jnp.ones(arr.shape) # A constant -... def body(carry, aelems): -... # carry: running dot-product of the two arrays -... # aelems: a pair with corresponding elements from the two arrays -... ae1, ae2 = aelems -... return (carry + ae1 * ae2 + extra, carry) -... return lax.scan(body, 0., (arr, ones)) -... ->>> print(make_jaxpr(func11)(np.ones(16), 5.)) -{ lambda ; a:f32[16] b:f32[]. let - c:f32[16] = broadcast_in_dim[broadcast_dimensions=() shape=(16,)] 1.0 - d:f32[] e:f32[16] = scan[ - _split_transpose=False - jaxpr={ lambda ; f:f32[] g:f32[] h:f32[] i:f32[]. let - j:f32[] = mul h i - k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] g - l:f32[] = add k j - m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f - n:f32[] = add l m - in (n, g) } - length=16 - linear=(False, False, False, False) - num_carry=1 - num_consts=1 - reverse=False - unroll=1 - ] b 0.0 a c - in (d, e) } - -The ``linear`` parameter describes for each of the input variables whether they -are guaranteed to be used linearly in the body. Once the scan goes through -linearization, more arguments will be linear. - -The scan primitive takes 4 arguments: ``b 0.0 a c``, of which: - - * one is the free variable for the body - * one is the initial value of the carry - * The next 2 are the arrays over which the scan operates. - -XLA_call -^^^^^^^^ - -The call primitive arises from JIT compilation, and it encapsulates -a sub-jaxpr along with parameters that specify the backend and the device on -which the computation should run. For example - ->>> from jax import jit ->>> ->>> def func12(arg): -... @jit -... def inner(x): -... return x + arg * jnp.ones(1) # Include a constant in the inner function -... return arg + inner(arg - 2.) -... ->>> print(make_jaxpr(func12)(1.)) # doctest:+ELLIPSIS -{ lambda ; a:f32[]. let - b:f32[] = sub a 2.0 - c:f32[1] = pjit[ - name=inner - jaxpr={ lambda ; d:f32[] e:f32[]. 
let - f:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 - g:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d - h:f32[1] = mul g f - i:f32[] = convert_element_type[new_dtype=float32 weak_type=False] e - j:f32[1] = add i h - in (j,) } - ] a b - k:f32[] = convert_element_type[new_dtype=float32 weak_type=False] a - l:f32[1] = add k c - in (l,) } - - -XLA_pmap -^^^^^^^^ - -If you use the :py:func:`jax.pmap` transformation, the function to be mapped is -captured using the ``xla_pmap`` primitive. Consider this example - ->>> from jax import pmap ->>> ->>> def func13(arr, extra): -... def inner(x): -... # use a free variable "extra" and a constant jnp.ones(1) -... return (x + extra + jnp.ones(1)) / lax.psum(x, axis_name='rows') -... return pmap(inner, axis_name='rows')(arr) -... ->>> print(make_jaxpr(func13)(jnp.ones((1, 3)), 5.)) -{ lambda ; a:f32[1,3] b:f32[]. let - c:f32[1,3] = xla_pmap[ - axis_name=rows - axis_size=1 - backend=None - call_jaxpr={ lambda ; d:f32[] e:f32[3]. let - f:f32[] = convert_element_type[new_dtype=float32 weak_type=False] d - g:f32[3] = add e f - h:f32[1] = broadcast_in_dim[broadcast_dimensions=() shape=(1,)] 1.0 - i:f32[3] = add g h - j:f32[3] = psum[axes=('rows',) axis_index_groups=None] e - k:f32[3] = div i j - in (k,) } - devices=None - donated_invars=(False, False) - global_axis_size=1 - in_axes=(None, 0) - is_explicit_global_axis_size=False - name=inner - out_axes=(0,) - ] b a - in (c,) } - -The ``xla_pmap`` primitive specifies the name of the axis (parameter -``axis_name``) and the body of the function to be mapped as the ``call_jaxpr`` -parameter. The value of this parameter is a Jaxpr with 2 input variables. - -The parameter ``in_axes`` specifies which of the input variables should be -mapped and which should be broadcast. In our example, the value of ``extra`` -is broadcast and the value of ``arr`` is mapped. - -.. _Haskell Type Signature: https://wiki.haskell.org/Type_signature diff --git a/docs/jep/11830-new-remat-checkpoint.md b/docs/jep/11830-new-remat-checkpoint.md index da0adaf18060..019188349257 100644 --- a/docs/jep/11830-new-remat-checkpoint.md +++ b/docs/jep/11830-new-remat-checkpoint.md @@ -14,7 +14,7 @@ ## What’s going on? -As of [#11830](https://github.com/google/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) +As of [#11830](https://github.com/jax-ml/jax/pull/11830) we're switching on a new implementation of {func}`jax.checkpoint`, aka {func}`jax.remat` (the two names are aliases of one another). **For most code, there will be no changes.** But there may be some observable differences in edge cases; see [What are the possible issues after the upgrade?](#what-are-the-possible-issues-after-the-upgrade) ## How can I disable the change, and go back to the old behavior for now? @@ -29,7 +29,7 @@ If you need to revert to the old implementation, **please reach out** on a GitHu As of `jax==0.3.17` the `jax_new_checkpoint` config option is no longer available. If you have an issue, please reach out on [the issue -tracker](https://github.com/google/jax/issues) so we can help fix it! +tracker](https://github.com/jax-ml/jax/issues) so we can help fix it! ## Why are we doing this? 
@@ -82,7 +82,7 @@ The old `jax.checkpoint` implementation was forced to save the value of `a`, whi ### Significantly less Python overhead in some cases -The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/google/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! +The new `jax.checkpoint` incurs significantly less Python overhead in some cases. [Simple overhead benchmarks](https://github.com/jax-ml/jax/blob/88636d2b649bfa31fa58a30ea15c925f35637397/benchmarks/api_benchmark.py#L511-L539) got 10x faster. These overheads only arise in eager op-by-op execution, so in the common case of using a `jax.checkpoint` under a `jax.jit` or similar the speedups aren't relevant. But still, nice! ### Enabling new JAX features by simplifying internals diff --git a/docs/jep/12049-type-annotations.md b/docs/jep/12049-type-annotations.md index 9137e3e71232..7a20958c5cab 100644 --- a/docs/jep/12049-type-annotations.md +++ b/docs/jep/12049-type-annotations.md @@ -12,7 +12,7 @@ The current state of type annotations in JAX is a bit patchwork, and efforts to This doc attempts to summarize those issues and generate a roadmap for the goals and non-goals of type annotations in JAX. Why do we need such a roadmap? Better/more comprehensive type annotations are a frequent request from users, both internally and externally. -In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/google/jax/pull/9917), [PR #10322](https://github.com/google/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. +In addition, we frequently receive pull requests from external users (for example, [PR #9917](https://github.com/jax-ml/jax/pull/9917), [PR #10322](https://github.com/jax-ml/jax/pull/10322)) seeking to improve JAX's type annotations: it's not always clear to the JAX team member reviewing the code whether such contributions are beneficial, particularly when they introduce complex Protocols to address the challenges inherent to full-fledged annotation of JAX's use of Python. This document details JAX's goals and recommendations for type annotations within the package. ## Why type annotations? @@ -21,7 +21,7 @@ There are a number of reasons that a Python project might wish to annotate their ### Level 1: Annotations as documentation -When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. 
An example can be found in `lax/slicing.py` [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: +When originally introduced in [PEP 3107](https://peps.python.org/pep-3107/), type annotations were motivated partly by the ability to use them as concise, inline documentation of function parameter types and return types. JAX has long utilized annotations in this manner; an example is the common pattern of creating type names aliased to `Any`. An example can be found in `lax/slicing.py` [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/lax/slicing.py#L47-L58)]: ```python Array = Any @@ -44,14 +44,14 @@ Many modern IDEs take advantage of type annotations as inputs to [intelligent co This use of type checking requires going further than the simple aliases used above; for example, knowing that the `slice` function returns an alias of `Any` named `Array` does not add any useful information to the code completion engine. However, were we to annotate the function with a `DeviceArray` return type, the autocomplete would know how to populate the namespace of the result, and thus be able to suggest more relevant autocompletions during the course of development. -JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: +JAX has begun to add this level of type annotation in a few places; one example is the `jnp.ndarray` return type within the `jax.random` package [[source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/random.py#L359)]: ```python def shuffle(key: KeyArray, x: Array, axis: int = 0) -> jnp.ndarray: ... ``` -In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/google/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: +In this case `jnp.ndarray` is an abstract base class that forward-declares the attributes and methods of JAX arrays ([see source](https://github.com/jax-ml/jax/blob/2bc3e39cd9104071ee39dacac22abd51b94eb27e/jax/_src/numpy/ndarray.py#L41)), and so Pylance in VSCode can offer the full set of autocompletions on results from this function. Here is a screenshot showing the result: ![VSCode Intellisense Screenshot](../_static/vscode-completion.png) @@ -232,7 +232,7 @@ assert jit(f)(x) # x will be a tracer ``` Again, there are a couple mechanisms that could be used for this: -- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/google/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). +- override `type(ArrayInstance).__instancecheck__` to return `True` for both `Array` and `Tracer` objects; this is how `jnp.ndarray` is currently implemented ([source](https://github.com/jax-ml/jax/blob/jax-v0.3.17/jax/_src/numpy/ndarray.py#L24-L49)). 
- define `ArrayInstance` as an abstract base class and dynamically register it to `Array` and `Tracer` - restructure `Array` and `Tracer` so that `ArrayInstance` is a true base class of both `Array` and `Tracer` diff --git a/docs/jep/15856-jex.md b/docs/jep/15856-jex.md index bec06000194e..a5625abf8930 100644 --- a/docs/jep/15856-jex.md +++ b/docs/jep/15856-jex.md @@ -170,7 +170,7 @@ print(jax.jit(mul_add_p.bind)(2, 3, 4)) # -> Array(10, dtype=int32) This module could expose our mechanism for defining new RNG implementations, and functions for working with PRNG key internals -(see issue [#9263](https://github.com/google/jax/issues/9263)), +(see issue [#9263](https://github.com/jax-ml/jax/issues/9263)), such as the current `jax._src.prng.random_wrap` and `random_unwrap`. diff --git a/docs/jep/18137-numpy-scipy-scope.md b/docs/jep/18137-numpy-scipy-scope.md index 2371e11ee07e..eaebe8fb8997 100644 --- a/docs/jep/18137-numpy-scipy-scope.md +++ b/docs/jep/18137-numpy-scipy-scope.md @@ -78,8 +78,8 @@ to JAX which have relatively complex implementations which are difficult to vali and introduce outsized maintenance burdens; an example is {func}`jax.scipy.special.bessel_jn`: as of the writing of this JEP, its current implementation is a non-straightforward iterative approximation that has -[convergence issues in some domains](https://github.com/google/jax/issues/12402#issuecomment-1384828637), -and [proposed fixes](https://github.com/google/jax/pull/17038/files) introduce further +[convergence issues in some domains](https://github.com/jax-ml/jax/issues/12402#issuecomment-1384828637), +and [proposed fixes](https://github.com/jax-ml/jax/pull/17038/files) introduce further complexity. Had we more carefully weighed the complexity and robustness of the implementation when accepting the contribution, we may have chosen not to accept this contribution to the package. diff --git a/docs/jep/2026-custom-derivatives.md b/docs/jep/2026-custom-derivatives.md index aa568adc0d9a..ce149fa6fb35 100644 --- a/docs/jep/2026-custom-derivatives.md +++ b/docs/jep/2026-custom-derivatives.md @@ -35,9 +35,9 @@ behavior of their code. This customization Python control flow and workflows for NaN debugging. As **JAX developers** we want to write library functions, like -[`logit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) +[`logit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L83) and -[`expit`](https://github.com/google/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), +[`expit`](https://github.com/jax-ml/jax/blob/01039299304b148b405ef9b9fa5e82bbb527471d/jax/scipy/special.py#L91), that are defined in terms of other primitives, but for the purposes of differentiation have primitive-like behavior in the sense that we want to define custom differentiation rules for them, which may be more numerically stable or @@ -50,9 +50,9 @@ looking to add custom differentiation rules for higher-order functions like want to be confident we’re not going to preclude good solutions to that problem. That is, our primary goals are -1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/google/jax/issues/1249)), and +1. solve the vmap-removes-custom-jvp semantics problem ([#1249](https://github.com/jax-ml/jax/issues/1249)), and 2. allow Python in custom VJPs, e.g. to debug NaNs - ([#1275](https://github.com/google/jax/issues/1275)). 
+ ([#1275](https://github.com/jax-ml/jax/issues/1275)). Secondary goals are 3. clean up and simplify user experience (symbolic zeros, kwargs, etc) @@ -60,18 +60,18 @@ Secondary goals are `odeint`, `root`, etc. Overall, we want to close -[#116](https://github.com/google/jax/issues/116), -[#1097](https://github.com/google/jax/issues/1097), -[#1249](https://github.com/google/jax/issues/1249), -[#1275](https://github.com/google/jax/issues/1275), -[#1366](https://github.com/google/jax/issues/1366), -[#1723](https://github.com/google/jax/issues/1723), -[#1670](https://github.com/google/jax/issues/1670), -[#1875](https://github.com/google/jax/issues/1875), -[#1938](https://github.com/google/jax/issues/1938), +[#116](https://github.com/jax-ml/jax/issues/116), +[#1097](https://github.com/jax-ml/jax/issues/1097), +[#1249](https://github.com/jax-ml/jax/issues/1249), +[#1275](https://github.com/jax-ml/jax/issues/1275), +[#1366](https://github.com/jax-ml/jax/issues/1366), +[#1723](https://github.com/jax-ml/jax/issues/1723), +[#1670](https://github.com/jax-ml/jax/issues/1670), +[#1875](https://github.com/jax-ml/jax/issues/1875), +[#1938](https://github.com/jax-ml/jax/issues/1938), and replace the custom_transforms machinery (from -[#636](https://github.com/google/jax/issues/636), -[#818](https://github.com/google/jax/issues/818), +[#636](https://github.com/jax-ml/jax/issues/636), +[#818](https://github.com/jax-ml/jax/issues/818), and others). ## Non-goals @@ -400,7 +400,7 @@ There are some other bells and whistles to the API: resolved to positions using the `inspect` module. This is a bit of an experiment with Python 3’s improved ability to programmatically inspect argument signatures. I believe it is sound but not complete, which is a fine place to be. - (See also [#2069](https://github.com/google/jax/issues/2069).) + (See also [#2069](https://github.com/jax-ml/jax/issues/2069).) * Arguments can be marked non-differentiable using `nondiff_argnums`, and as with `jit`’s `static_argnums` these arguments don’t have to be JAX types. We need to set a convention for how these arguments are passed to the rules. For a primal @@ -433,5 +433,5 @@ There are some other bells and whistles to the API: `custom_lin` to the tangent values; `custom_lin` carries with it the user’s custom backward-pass function, and as a primitive it only has a transpose rule. - * This mechanism is described more in [#636](https://github.com/google/jax/issues/636). + * This mechanism is described more in [#636](https://github.com/jax-ml/jax/issues/636). * To prevent diff --git a/docs/jep/4008-custom-vjp-update.md b/docs/jep/4008-custom-vjp-update.md index 65235dc64337..1e2270e052a6 100644 --- a/docs/jep/4008-custom-vjp-update.md +++ b/docs/jep/4008-custom-vjp-update.md @@ -9,7 +9,7 @@ notebook. ## What to update -After JAX [PR #4008](https://github.com/google/jax/pull/4008), the arguments +After JAX [PR #4008](https://github.com/jax-ml/jax/pull/4008), the arguments passed into a `custom_vjp` function's `nondiff_argnums` can't be `Tracer`s (or containers of `Tracer`s), which basically means to allow for arbitrarily-transformable code `nondiff_argnums` shouldn't be used for @@ -95,7 +95,7 @@ acted very much like lexical closure. But lexical closure over `Tracer`s wasn't at the time intended to work with `custom_jvp`/`custom_vjp`. Implementing `nondiff_argnums` that way was a mistake! 
-**[PR #4008](https://github.com/google/jax/pull/4008) fixes all lexical closure +**[PR #4008](https://github.com/jax-ml/jax/pull/4008) fixes all lexical closure issues with `custom_jvp` and `custom_vjp`.** Woohoo! That is, now `custom_jvp` and `custom_vjp` functions and rules can close over `Tracer`s to our hearts' content. For all non-autodiff transformations, things will Just Work. For @@ -120,9 +120,9 @@ manageable, until you think through how we have to handle arbitrary pytrees! Moreover, that complexity isn't necessary: if user code treats array-like non-differentiable arguments just like regular arguments and residuals, everything already works. (Before -[#4039](https://github.com/google/jax/pull/4039) JAX might've complained about +[#4039](https://github.com/jax-ml/jax/pull/4039) JAX might've complained about involving integer-valued inputs and outputs in autodiff, but after -[#4039](https://github.com/google/jax/pull/4039) those will just work!) +[#4039](https://github.com/jax-ml/jax/pull/4039) those will just work!) Unlike `custom_vjp`, it was easy to make `custom_jvp` work with `nondiff_argnums` arguments that were `Tracer`s. So these updates only need to diff --git a/docs/jep/4410-omnistaging.md b/docs/jep/4410-omnistaging.md index eb68ee5f0e0a..f95c15f404b6 100644 --- a/docs/jep/4410-omnistaging.md +++ b/docs/jep/4410-omnistaging.md @@ -20,7 +20,7 @@ This is more of an upgrade guide than a design doc. ### What's going on? A change to JAX's tracing infrastructure called “omnistaging” -([google/jax#3370](https://github.com/google/jax/pull/3370)) was switched on in +([jax-ml/jax#3370](https://github.com/jax-ml/jax/pull/3370)) was switched on in jax==0.2.0. This change improves memory performance, trace execution time, and simplifies jax internals, but may cause some existing code to break. Breakage is usually a result of buggy code, so long-term it’s best to fix the bugs, but @@ -191,7 +191,7 @@ and potentially even fragmenting memory. (The `broadcast` that corresponds to the construction of the zeros array for `jnp.zeros_like(x)` is staged out because JAX is lazy about very simple -expressions from [google/jax#1668](https://github.com/google/jax/pull/1668). After +expressions from [jax-ml/jax#1668](https://github.com/jax-ml/jax/pull/1668). After omnistaging, we can remove that lazy sublanguage and simplify JAX internals.) The reason the creation of `mask` is not staged out is that, before omnistaging, diff --git a/docs/jep/9263-typed-keys.md b/docs/jep/9263-typed-keys.md index 828b95e8ce00..d520f6f63df9 100644 --- a/docs/jep/9263-typed-keys.md +++ b/docs/jep/9263-typed-keys.md @@ -321,7 +321,7 @@ Why introduce extended dtypes in generality, beyond PRNGs? We reuse this same extended dtype mechanism elsewhere internally. For example, the `jax._src.core.bint` object, a bounded integer type used for experimental work on dynamic shapes, is another extended dtype. In recent JAX versions it satisfies -the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/google/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). +the properties above (See [jax/_src/core.py#L1789-L1802](https://github.com/jax-ml/jax/blob/jax-v0.4.14/jax/_src/core.py#L1789-L1802)). ### PRNG dtypes PRNG dtypes are defined as a particular case of extended dtypes. 
Specifically, diff --git a/docs/jep/9407-type-promotion.ipynb b/docs/jep/9407-type-promotion.ipynb index 2aef1768112f..a1ede3177a3a 100644 --- a/docs/jep/9407-type-promotion.ipynb +++ b/docs/jep/9407-type-promotion.ipynb @@ -8,7 +8,7 @@ "source": [ "# Design of Type Promotion Semantics for JAX\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb)\n", "\n", "*Jake VanderPlas, December 2021*\n", "\n", @@ -3317,7 +3317,6 @@ ], "source": [ "# @title\n", - "from jax import dtypes\n", "import jax\n", "import jax.numpy as jnp\n", "import pandas as pd\n", @@ -3802,7 +3801,6 @@ ], "source": [ "# @title\n", - "from jax import dtypes\n", "import jax\n", "import jax.numpy as jnp\n", "import pandas as pd\n", diff --git a/docs/jep/9407-type-promotion.md b/docs/jep/9407-type-promotion.md index 107bcd8c968b..ff67a8c21399 100644 --- a/docs/jep/9407-type-promotion.md +++ b/docs/jep/9407-type-promotion.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -16,7 +16,7 @@ kernelspec: # Design of Type Promotion Semantics for JAX -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jep/9407-type-promotion.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/jep/9407-type-promotion.ipynb) *Jake VanderPlas, December 2021* @@ -908,7 +908,6 @@ display.HTML(table.to_html()) :tags: [hide-input] # @title -from jax import dtypes import jax import jax.numpy as jnp import pandas as pd @@ -963,7 +962,6 @@ display.HTML(table.to_html()) :tags: [hide-input] # @title -from jax import dtypes import jax import jax.numpy as jnp import pandas as pd diff --git a/docs/jep/9419-jax-versioning.md b/docs/jep/9419-jax-versioning.md index 759a9be86713..b964aa2af45d 100644 --- a/docs/jep/9419-jax-versioning.md +++ b/docs/jep/9419-jax-versioning.md @@ -58,11 +58,11 @@ These constraints imply the following rules for releases: * If a new `jaxlib` is released, a `jax` release must be made at the same time. 
These -[version constraints](https://github.com/google/jax/blob/main/jax/version.py) +[version constraints](https://github.com/jax-ml/jax/blob/main/jax/version.py) are currently checked by `jax` at import time, instead of being expressed as Python package version constraints. `jax` checks the `jaxlib` version at runtime rather than using a `pip` package version constraint because we -[provide separate `jaxlib` wheels](https://github.com/google/jax#installation) +[provide separate `jaxlib` wheels](https://github.com/jax-ml/jax#installation) for a variety of hardware and software versions (e.g, GPU, TPU, etc.). Since we do not know which is the right choice for any given user, we do not want `pip` to install a `jaxlib` package for us automatically. @@ -119,7 +119,7 @@ no released `jax` version uses that API. ## How is the source to `jaxlib` laid out? `jaxlib` is split across two main repositories, namely the -[`jaxlib/` subdirectory in the main JAX repository](https://github.com/google/jax/tree/main/jaxlib) +[`jaxlib/` subdirectory in the main JAX repository](https://github.com/jax-ml/jax/tree/main/jaxlib) and in the [XLA source tree, which lives inside the XLA repository](https://github.com/openxla/xla). The JAX-specific pieces inside XLA are primarily in the @@ -146,7 +146,7 @@ level. `jaxlib` is built using Bazel out of the `jax` repository. The pieces of `jaxlib` from the XLA repository are incorporated into the build -[as a Bazel submodule](https://github.com/google/jax/blob/main/WORKSPACE). +[as a Bazel submodule](https://github.com/jax-ml/jax/blob/main/WORKSPACE). To update the version of XLA used during the build, one must update the pinned version in the Bazel `WORKSPACE`. This is done manually on an as-needed basis, but can be overridden on a build-by-build basis. diff --git a/docs/jep/index.rst b/docs/jep/index.rst index 194eb0cb9d69..f9dda2657ced 100644 --- a/docs/jep/index.rst +++ b/docs/jep/index.rst @@ -32,7 +32,7 @@ should be linked to this issue. Then create a pull request that adds a file named `%d-{short-title}.md` - with the number being the issue number. -.. _JEP label: https://github.com/google/jax/issues?q=label%3AJEP +.. _JEP label: https://github.com/jax-ml/jax/issues?q=label%3AJEP .. toctree:: :maxdepth: 1 diff --git a/docs/jit-compilation.md b/docs/jit-compilation.md index 2d442c8411aa..59c7bbd8fb90 100644 --- a/docs/jit-compilation.md +++ b/docs/jit-compilation.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -51,7 +51,7 @@ def log2(x): print(jax.make_jaxpr(log2)(3.0)) ``` -The {ref}`understanding-jaxprs` section of the documentation provides more information on the meaning of the above output. +The {ref}`jax-internals-jaxpr` section of the documentation provides more information on the meaning of the above output. Importantly, notice that the jaxpr does not capture the side-effect present in the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code. 
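The jit-compilation hunk above relies on the fact that `jax.make_jaxpr` records only the pure array operations of a function. As a quick check of that claim, here is a minimal sketch (the `log2_with_side_effect` function and `trace_log` list are illustrative helpers, not part of the docs being patched) showing that a Python-level side effect runs once at trace time and leaves no trace in the jaxpr:

```python
# Minimal sketch: the jaxpr records only array operations; the Python-level
# append is a trace-time side effect and does not appear in the jaxpr.
import jax
import jax.numpy as jnp

trace_log = []  # illustrative helper, not from the patched docs

def log2_with_side_effect(x):
    trace_log.append("traced!")       # side effect: invisible to the jaxpr
    return jnp.log(x) / jnp.log(2.0)  # pure computation: captured

print(jax.make_jaxpr(log2_with_side_effect)(3.0))  # no sign of the append
print(trace_log)  # ['traced!'] -- ran exactly once, during tracing
```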
diff --git a/docs/key-concepts.md b/docs/key-concepts.md index 4b114c857460..daab2c9fdde4 100644 --- a/docs/key-concepts.md +++ b/docs/key-concepts.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -13,7 +13,7 @@ kernelspec: --- (key-concepts)= -# Key Concepts +# Key concepts @@ -23,13 +23,13 @@ This section briefly introduces some key concepts of the JAX package. ## JAX arrays ({class}`jax.Array`) The default array implementation in JAX is {class}`jax.Array`. In many ways it is similar to -the {class}`numpy.ndarray` type that you may be familar with from the NumPy package, but it +the {class}`numpy.ndarray` type that you may be familiar with from the NumPy package, but it has some important differences. ### Array creation We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions. -For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality +For example, {mod}`jax.numpy` provides familiar NumPy-style array construction functionality such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc. ```{code-cell} @@ -147,10 +147,10 @@ jaxprs later in {ref}`jax-internals-jaxpr`. ## Pytrees JAX functions and transformations fundamentally operate on arrays, but in practice it is -convenient to write code that work with collections of arrays: for example, a neural +convenient to write code that works with collections of arrays: for example, a neural network might organize its parameters in a dictionary of arrays with meaningful keys. Rather than handle such structures on a case-by-case basis, JAX relies on the {term}`pytree` -abstraction to treat such collections in a uniform matter. +abstraction to treat such collections in a uniform manner.
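To make the pytree point just above concrete, here is a minimal sketch (the `params` dictionary is an illustrative stand-in, not taken from the patched docs) of how a single `jax.tree_util.tree_map` call treats a collection of arrays uniformly:

```python
# Minimal sketch: a dict of arrays is a pytree, so one tree_map call reaches
# every leaf without any per-structure bookkeeping.
import jax
import jax.numpy as jnp

params = {"weights": jnp.ones((2, 3)), "bias": jnp.zeros(3)}  # illustrative
scaled = jax.tree_util.tree_map(lambda leaf: 0.1 * leaf, params)
print(jax.tree_util.tree_structure(scaled))  # same structure, new leaf values
```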
Here are some examples of objects that can be treated as pytrees: diff --git a/docs/multi_process.md b/docs/multi_process.md index 7d7083bde10f..32cfae126784 100644 --- a/docs/multi_process.md +++ b/docs/multi_process.md @@ -1,4 +1,4 @@ -# Using JAX in multi-host and multi-process environments +# Multi-host and multi-process environments diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 7ba192437a32..71bd4527644a 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb)" ] }, { @@ -19,8 +19,6 @@ "id": "4k5PVzEo2uJO" }, "source": [ - "*levskaya@ mattjj@*\n", - "\n", "When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_\"una anima di pura programmazione funzionale\"_](https://www.sscardapane.it/iaml-backup/jax-intro/).\n", "\n", "__JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU).\n", @@ -226,7 +224,6 @@ ], "source": [ "import jax.numpy as jnp\n", - "import jax.lax as lax\n", "from jax import make_jaxpr\n", "\n", "# lax.fori_loop\n", @@ -258,7 +255,7 @@ "id": "oBdKtkVW8Lha" }, "source": [ - "## 🔪 In-Place Updates" + "## 🔪 In-place updates" ] }, { @@ -533,7 +530,7 @@ "id": "oZ_jE2WAypdL" }, "source": [ - "## 🔪 Out-of-Bounds Indexing" + "## 🔪 Out-of-bounds indexing" ] }, { @@ -664,7 +661,7 @@ "source": [ "Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error.\n", "\n", - "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." + "Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior)." 
] }, { @@ -868,7 +865,7 @@ "id": "MUycRNh6e50W" }, "source": [ - "## 🔪 Random Numbers" + "## 🔪 Random numbers" ] }, { @@ -888,7 +885,7 @@ "id": "Qikt9pPW9L5K" }, "source": [ - "### RNGs and State\n", + "### RNGs and state\n", "You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness:" ] }, @@ -1006,7 +1003,7 @@ "id": "COjzGBpO4tzL" }, "source": [ - "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", + "JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation.\n", "\n", "The random state is described by a special array element that we call a __key__:" ] @@ -1031,7 +1028,6 @@ } ], "source": [ - "from jax import random\n", "key = random.key(0)\n", "key" ] @@ -1105,8 +1101,8 @@ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(\" \\---SPLIT --> new key \", key)\n", - "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" + "print(r\" \\---SPLIT --> new key \", key)\n", + "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ] }, { @@ -1140,8 +1136,8 @@ "print(\"old key\", key)\n", "key, subkey = random.split(key)\n", "normal_pseudorandom = random.normal(subkey, shape=(1,))\n", - "print(\" \\---SPLIT --> new key \", key)\n", - "print(\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" + "print(r\" \\---SPLIT --> new key \", key)\n", + "print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_pseudorandom)" ] }, { @@ -1183,7 +1179,7 @@ "id": "rg4CpMZ8c3ri" }, "source": [ - "## 🔪 Control Flow" + "## 🔪 Control flow" ] }, { @@ -1192,7 +1188,7 @@ "id": "izLTvT24dAq0" }, "source": [ - "### ✔ python control_flow + autodiff ✔\n", + "### ✔ Python control_flow + autodiff ✔\n", "\n", "If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager)." ] @@ -1231,7 +1227,7 @@ "id": "hIfPT7WMmZ2H" }, "source": [ - "### python control flow + JIT\n", + "### Python control flow + JIT\n", "\n", "Using control flow with `jit` is more complicated, and by default it has more constraints.\n", "\n", @@ -1353,7 +1349,7 @@ "\n", "For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time.\n", "\n", - "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. 
There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", + "To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels.\n", "\n", "By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time.\n", "\n", @@ -1701,7 +1697,7 @@ ], "source": [ "init_val = 0\n", - "cond_fun = lambda x: x<10\n", + "cond_fun = lambda x: x < 10\n", "body_fun = lambda x: x+1\n", "lax.while_loop(cond_fun, body_fun, init_val)\n", "# --> array(10, dtype=int32)" @@ -1791,7 +1787,7 @@ "id": "OxLsZUyRt_kF" }, "source": [ - "## 🔪 Dynamic Shapes" + "## 🔪 Dynamic shapes" ] }, { @@ -2194,7 +2190,7 @@ "id": "WAHjmL0E2XwO" }, "source": [ - "## 🔪 Miscellaneous Divergences from NumPy\n", + "## 🔪 Miscellaneous divergences from NumPy\n", "\n", "While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ.\n", "Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge.\n", diff --git a/docs/notebooks/Common_Gotchas_in_JAX.md b/docs/notebooks/Common_Gotchas_in_JAX.md index 98c4b391c7ce..741fa3af063c 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.md +++ b/docs/notebooks/Common_Gotchas_in_JAX.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -18,12 +18,10 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Common_Gotchas_in_JAX.ipynb) +++ {"id": "4k5PVzEo2uJO"} -*levskaya@ mattjj@* - When walking about the countryside of Italy, the people will not hesitate to tell you that __JAX__ has [_"una anima di pura programmazione funzionale"_](https://www.sscardapane.it/iaml-backup/jax-intro/). __JAX__ is a language for __expressing__ and __composing__ __transformations__ of numerical programs. __JAX__ is also able to __compile__ numerical programs for CPU or accelerators (GPU/TPU). 
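The `lax.fori_loop` and `lax.while_loop` hunks nearby only adjust imports and spacing; for readers skimming the diff, a brief sketch of those structured control-flow primitives (standard `jax.lax` APIs, shown here outside the notebook context) may be useful:

```python
from jax import lax

# while_loop: iterate body_fun while cond_fun(carry) is True.
out = lax.while_loop(lambda x: x < 10, lambda x: x + 1, 0)          # -> 10

# fori_loop: fixed trip count; body_fun takes (index, carry).
total = lax.fori_loop(0, 5, lambda i, acc: acc + i, 0)              # -> 0+1+2+3+4 = 10

# cond: a traceable branch on a runtime predicate.
doubled = lax.cond(total > 5, lambda x: x * 2, lambda x: x, total)  # -> 20
```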
@@ -130,7 +128,6 @@ It is not recommended to use iterators in any JAX function you want to `jit` or :outputId: 52d885fd-0239-4a08-f5ce-0c38cc008903 import jax.numpy as jnp -import jax.lax as lax from jax import make_jaxpr # lax.fori_loop @@ -158,7 +155,7 @@ iter_operand = iter(range(10)) +++ {"id": "oBdKtkVW8Lha"} -## 🔪 In-Place Updates +## 🔪 In-place updates +++ {"id": "JffAqnEW4JEb"} @@ -268,7 +265,7 @@ For more details on indexed array updates, see the [documentation for the `.at` +++ {"id": "oZ_jE2WAypdL"} -## 🔪 Out-of-Bounds Indexing +## 🔪 Out-of-bounds indexing +++ {"id": "btRFwEVzypdN"} @@ -315,7 +312,7 @@ jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan) Note that due to this behavior for index retrieval, functions like `jnp.nanargmin` and `jnp.nanargmax` return -1 for slices consisting of NaNs whereas Numpy would throw an error. -Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/google/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). +Note also that, as the two behaviors described above are not inverses of each other, reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) [will not preserve the semantics of out of bounds indexing](https://github.com/jax-ml/jax/issues/5760). Thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of [undefined behavior](https://en.wikipedia.org/wiki/Undefined_behavior). +++ {"id": "LwB07Kx5sgHu"} @@ -385,7 +382,7 @@ jnp.sum(jnp.array(x)) +++ {"id": "MUycRNh6e50W"} -## 🔪 Random Numbers +## 🔪 Random numbers +++ {"id": "O8vvaVt3MRG2"} @@ -395,7 +392,7 @@ jnp.sum(jnp.array(x)) +++ {"id": "Qikt9pPW9L5K"} -### RNGs and State +### RNGs and state You're used to _stateful_ pseudorandom number generators (PRNGs) from numpy and other libraries, which helpfully hide a lot of details under the hood to give you a ready fountain of pseudorandomness: ```{code-cell} ipython3 @@ -463,7 +460,7 @@ The Mersenne Twister PRNG is also known to have a [number](https://cs.stackexcha +++ {"id": "COjzGBpO4tzL"} -JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/google/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. +JAX instead implements an _explicit_ PRNG where entropy production and consumption are handled by explicitly passing and iterating PRNG state. JAX uses a modern [Threefry counter-based PRNG](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md) that's __splittable__. That is, its design allows us to __fork__ the PRNG state into new PRNGs for use with parallel stochastic generation. 
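To make the splitting discipline concrete before the key examples that follow, here is a minimal sketch using the same `jax.random` API the notebook relies on; the sample shapes are arbitrary:

```python
from jax import random

key = random.key(0)                      # explicit PRNG state (a typed key array)
key, subkey = random.split(key)          # fork: keep `key` for later splits, use `subkey` once
sample = random.normal(subkey, shape=(3,))

# For many independent streams, split a batch of subkeys up front:
subkeys = random.split(key, num=4)
samples = [random.normal(k) for k in subkeys]
```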
The random state is described by a special array element that we call a __key__: @@ -471,7 +468,6 @@ The random state is described by a special array element that we call a __key__: :id: yPHE7KTWgAWs :outputId: ae8af0ee-f19e-474e-81b6-45e894eb2fc3 -from jax import random key = random.key(0) key ``` @@ -504,8 +500,8 @@ Instead, we __split__ the PRNG to get usable __subkeys__ every time we need a ne print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(" \---SPLIT --> new key ", key) -print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom) +print(r" \---SPLIT --> new key ", key) +print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) ``` +++ {"id": "tqtFVE4MthO3"} @@ -519,8 +515,8 @@ We propagate the __key__ and make new __subkeys__ whenever we need a new random print("old key", key) key, subkey = random.split(key) normal_pseudorandom = random.normal(subkey, shape=(1,)) -print(" \---SPLIT --> new key ", key) -print(" \--> new subkey", subkey, "--> normal", normal_pseudorandom) +print(r" \---SPLIT --> new key ", key) +print(r" \--> new subkey", subkey, "--> normal", normal_pseudorandom) ``` +++ {"id": "0KLYUluz3lN3"} @@ -538,11 +534,11 @@ for subkey in subkeys: +++ {"id": "rg4CpMZ8c3ri"} -## 🔪 Control Flow +## 🔪 Control flow +++ {"id": "izLTvT24dAq0"} -### ✔ python control_flow + autodiff ✔ +### ✔ Python control_flow + autodiff ✔ If you just want to apply `grad` to your python functions, you can use regular python control-flow constructs with no problems, as if you were using [Autograd](https://github.com/hips/autograd) (or Pytorch or TF Eager). @@ -562,7 +558,7 @@ print(grad(f)(4.)) # ok! +++ {"id": "hIfPT7WMmZ2H"} -### python control flow + JIT +### Python control flow + JIT Using control flow with `jit` is more complicated, and by default it has more constraints. @@ -627,7 +623,7 @@ When we `jit`-compile a function, we usually want to compile a version of the fu For example, if we evaluate an `@jit` function on the array `jnp.array([1., 2., 3.], jnp.float32)`, we might want to compile code that we can reuse to evaluate the function on `jnp.array([4., 5., 6.], jnp.float32)` to save on compile time. -To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/google/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. +To get a view of your Python code that is valid for many different argument values, JAX traces it on _abstract values_ that represent sets of possible inputs. There are [multiple different levels of abstraction](https://github.com/jax-ml/jax/blob/main/jax/_src/abstract_arrays.py), and different transformations use different abstraction levels. By default, `jit` traces your code on the `ShapedArray` abstraction level, where each abstract value represents the set of all array values with a fixed shape and dtype. For example, if we trace using the abstract value `ShapedArray((3,), jnp.float32)`, we get a view of the function that can be reused for any concrete value in the corresponding set of arrays. That means we can save on compile time. 
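As a companion to the tracing discussion, a short sketch of how the `ShapedArray`-level abstraction shows up in practice, assuming current JAX behavior (the exact error class raised can differ between releases):

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.where(x > 0, x, 2 * x)     # value-level select: fine under abstract tracing

# Tracing at the ShapedArray level: one jaxpr covers every float32 array of shape (3,).
print(jax.make_jaxpr(f)(jnp.array([1., 2., 3.])))

@jax.jit
def g(x):
  if x > 0:                             # Python `if` needs a concrete bool, but x is abstract
    return x
  return 2 * x

try:
  g(jnp.array(1.))
except Exception as e:                  # a tracer-concretization error in current releases
  print(type(e).__name__)
```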
@@ -805,7 +801,7 @@ def while_loop(cond_fun, body_fun, init_val): :outputId: 552fe42f-4d32-4e25-c8c2-b951160a3f4e init_val = 0 -cond_fun = lambda x: x<10 +cond_fun = lambda x: x < 10 body_fun = lambda x: x+1 lax.while_loop(cond_fun, body_fun, init_val) # --> array(10, dtype=int32) @@ -865,7 +861,7 @@ $\ast$ = argument-value-independent loop condition - unrolls the loop +++ {"id": "OxLsZUyRt_kF"} -## 🔪 Dynamic Shapes +## 🔪 Dynamic shapes +++ {"id": "1tKXcAMduDR1"} @@ -1130,7 +1126,7 @@ x.dtype # --> dtype('float64') +++ {"id": "WAHjmL0E2XwO"} -## 🔪 Miscellaneous Divergences from NumPy +## 🔪 Miscellaneous divergences from NumPy While `jax.numpy` makes every attempt to replicate the behavior of numpy's API, there do exist corner cases where the behaviors differ. Many such cases are discussed in detail in the sections above; here we list several other known places where the APIs diverge. diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb index 3abb6d9cbaec..5c09a0a4f732 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb @@ -6,13 +6,11 @@ "id": "LqiaKasFjH82" }, "source": [ - "# Custom derivative rules for JAX-transformable Python functions\n", + "# Custom derivative rules\n", "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", - "\n", - "*mattjj@ Mar 19 2020, last updated Oct 14 2020*\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb)\n", "\n", "There are two ways to define differentiation rules in JAX:\n", "\n", @@ -30,7 +28,7 @@ "id": "9Fg3NFNY-2RY" }, "source": [ - "## TL;DR" + "## Summary" ] }, { @@ -247,7 +245,6 @@ } ], "source": [ - "import jax.numpy as jnp\n", "\n", "def log1pexp(x):\n", " return jnp.log(1. 
+ jnp.exp(x))\n", @@ -984,7 +981,7 @@ " (a, x_star, x_star_bar),\n", " x_star_bar))\n", " return a_bar, jnp.zeros_like(x_star)\n", - " \n", + "\n", "def rev_iter(f, packed, u):\n", " a, x_star, x_star_bar = packed\n", " _, vjp_x = vjp(lambda x: f(a, x), x_star)\n", @@ -1884,7 +1881,6 @@ } ], "source": [ - "from jax import vjp\n", "\n", "y, f_vjp = vjp(f, 3.)\n", "print(y)" @@ -1983,7 +1979,7 @@ " return x, x\n", "\n", "def debug_bwd(x, g):\n", - " import pdb; pdb.set_trace()\n", + " pdb.set_trace()\n", " return g\n", "\n", "debug.defvjp(debug_fwd, debug_bwd)" diff --git a/docs/notebooks/Custom_derivative_rules_for_Python_code.md b/docs/notebooks/Custom_derivative_rules_for_Python_code.md index ad577d55cd0d..8a9b931552d9 100644 --- a/docs/notebooks/Custom_derivative_rules_for_Python_code.md +++ b/docs/notebooks/Custom_derivative_rules_for_Python_code.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -13,13 +13,11 @@ kernelspec: +++ {"id": "LqiaKasFjH82"} -# Custom derivative rules for JAX-transformable Python functions +# Custom derivative rules -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) - -*mattjj@ Mar 19 2020, last updated Oct 14 2020* +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb) There are two ways to define differentiation rules in JAX: @@ -32,7 +30,7 @@ For an introduction to JAX's automatic differentiation API, see [The Autodiff Co +++ {"id": "9Fg3NFNY-2RY"} -## TL;DR +## Summary +++ {"id": "ZgMNRtXyWIW8"} @@ -145,7 +143,6 @@ Say we want to write a function called `log1pexp`, which computes $x \mapsto \lo :id: 6lWbTvs40ET- :outputId: 8caff99e-add1-4c70-ace3-212c0c5c6f4e -import jax.numpy as jnp def log1pexp(x): return jnp.log(1. + jnp.exp(x)) @@ -524,7 +521,7 @@ def fixed_point_rev(f, res, x_star_bar): (a, x_star, x_star_bar), x_star_bar)) return a_bar, jnp.zeros_like(x_star) - + def rev_iter(f, packed, u): a, x_star, x_star_bar = packed _, vjp_x = vjp(lambda x: f(a, x), x_star) @@ -965,7 +962,6 @@ print(grad(f)(3.)) :id: s1Pn_qCIODcF :outputId: 423d34e0-35b8-4b57-e89d-f70f20e28ea9 -from jax import vjp y, f_vjp = vjp(f, 3.) 
print(y) @@ -1015,7 +1011,7 @@ def debug_fwd(x): return x, x def debug_bwd(x, g): - import pdb; pdb.set_trace() + pdb.set_trace() return g debug.defvjp(debug_fwd, debug_bwd) diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb index 2face1d4a0b2..32d332d9ac7e 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb @@ -17,22 +17,20 @@ "id": "pFtQjv4SzHRj" }, "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb)\n", "\n", "This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": { "id": "FNxScTfq3vGF" }, "outputs": [], "source": [ - "import os\n", "\n", - "import functools\n", "from typing import Optional\n", "\n", "import numpy as np\n", @@ -52,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": { "id": "IZMLqOUV3vGG" }, @@ -70,7 +68,7 @@ "source": [ "## Intro and a quick example\n", "\n", - "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified\n", + "By reading this tutorial notebook, you'll learn about `jax.Array`, a unified \n", "datatype for representing arrays, even with physical storage spanning multiple\n", "devices. 
You'll also learn about how using `jax.Array`s together with `jax.jit`\n", "can provide automatic compiler-based parallelization.\n", @@ -81,57 +79,76 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": { "id": "Gf2lO4ii3vGG" }, "outputs": [], "source": [ "from jax.experimental import mesh_utils\n", - "from jax.sharding import PositionalSharding" + "from jax.sharding import Mesh, PartitionSpec as P, NamedSharding" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": { "id": "q-XBTEoy3vGG" }, "outputs": [], "source": [ "# Create a Sharding object to distribute a value across devices:\n", - "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" + "mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)),\n", + " axis_names=('x', 'y'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "vI39znW93vGH", - "outputId": "3b518df8-5c29-4848-acc3-e41df939f30b" + "outputId": "4f702753-8add-4b65-a4af-0f18f098cc46" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "# Create an array of random values:\n", "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", "# and use jax.device_put to distribute it across devices:\n", - "y = jax.device_put(x, sharding.reshape(4, 2))\n", + "y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))\n", "jax.debug.visualize_array_sharding(y)" ] }, @@ -147,26 +164,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "-qCnHZl83vGI", - "outputId": "9da9c29e-ce88-4425-e1ec-e93e5bcf3106" + "outputId": "0e131c23-5765-43ae-f232-6417ae1acbb2" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -186,18 +221,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "_VTzN0r03vGI", - "outputId": "c9208010-984b-442b-d105-c8c6a3a010e6" + "outputId": "c03eecab-4c86-4dac-d776-5fc72cbb5273" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The slowest run took 13.32 times longer than the fastest. This could mean that an intermediate result is being cached \n", - "5 loops, best of 5: 9.69 ms per loop\n" + "The slowest run took 8.96 times longer than the fastest. This could mean that an intermediate result is being cached.\n", + "25.2 ms ± 30.9 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -208,17 +246,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "QuzhU1g63vGI", - "outputId": "d48fc76e-79a7-47b9-d392-b18a1c33c798" + "outputId": "8135cca0-871b-4b6a-a7e5-02e78c2028c7" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 1.86 ms per loop\n" + "2.4 ms ± 61.4 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -245,7 +286,7 @@ "id": "W6HsXauGxL6w" }, "source": [ - "### Sharding basics, and the `PositionalSharding` subclass" + "### Sharding basics, and the `NamedSharding` subclass" ] }, { @@ -263,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "id": "VmoX4SUp3vGJ" }, @@ -275,511 +316,109 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "vNRabO2J3vGJ", - "outputId": "73db7b6e-c2e7-467d-a0ef-c35e29e582dd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "jax.debug.visualize_array_sharding(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HhCjhK0zXIqX" - }, - "source": [ - "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n", - "\n", - "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. 
First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "VUIEIzRp3vGK" - }, - "outputs": [], - "source": [ - "from jax.experimental import mesh_utils\n", - "devices = mesh_utils.create_device_mesh((8,))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lbOKFWmBX1iv" - }, - "source": [ - "Then, we create a `PositionalSharding` and use it with `device_put`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "jwrWfZeB3vGK", - "outputId": "e6f126bd-f6bd-48c7-c130-6f02757e3342" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0 │\n", - "├───────────────────────┤\n", - "│ TPU 1 │\n", - "├───────────────────────┤\n", - "│ TPU 2 │\n", - "├───────────────────────┤\n", - "│ TPU 3 │\n", - "├───────────────────────┤\n", - "│ TPU 6 │\n", - "├───────────────────────┤\n", - "│ TPU 7 │\n", - "├───────────────────────┤\n", - "│ TPU 4 │\n", - "├───────────────────────┤\n", - "│ TPU 5 │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "from jax.sharding import PositionalSharding\n", - "\n", - "sharding = PositionalSharding(devices)\n", - "\n", - "x = jax.device_put(x, sharding.reshape(8, 1))\n", - "jax.debug.visualize_array_sharding(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TUu69IWXZdTm" - }, - "source": [ - "Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zxWB82Kz3vGK", - "outputId": "11384a6b-fabc-4c4c-bcad-a3be51eb0465" + "outputId": "40fd7172-a16c-4dd8-e2e1-17bb3afe5409" }, "outputs": [ { "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│         TPU 0         │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], "text/plain": [ - "PositionalSharding([{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}])" + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" ] }, - "execution_count": 13, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "sharding" + "jax.debug.visualize_array_sharding(x)" ] }, { "cell_type": "markdown", "metadata": { - "id": "uRLpOcmNj_Vt" + "id": "HhCjhK0zXIqX" }, "source": [ - "The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device.\n", + "Here, we're using the `jax.debug.visualize_array_sharding` function to show where the value `x` is stored in memory. All of `x` is stored on a single device, so the visualization is pretty boring!\n", "\n", - "By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PLsnpSzc3vGL", - "outputId": "9f4db733-cafe-46ae-c057-dc31046a6f66" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "PositionalSharding([[{TPU 0}]\n", - " [{TPU 1}]\n", - " [{TPU 2}]\n", - " [{TPU 3}]\n", - " [{TPU 6}]\n", - " [{TPU 7}]\n", - " [{TPU 4}]\n", - " [{TPU 5}]])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "sharding.reshape(8, 1)" + "But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order:" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": { - "id": "iqKdI4LO3vGL", - "outputId": "6aa10fc2-cec4-4401-a0df-343e71646e0a" + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, + "id": "zpB1JxyK3vGN", + "outputId": "8e385462-1c2c-4256-c38a-84299d3bd02c" }, "outputs": [ { "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], "text/plain": [ - "PositionalSharding([[{TPU 0} {TPU 1}]\n", - " [{TPU 2} {TPU 3}]\n", - " [{TPU 6} {TPU 7}]\n", - " [{TPU 4} {TPU 5}]])" + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" ] }, - "execution_count": 15, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "sharding.reshape(4, 2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KBu6WLfhm7ra" - }, - "source": [ - "To use `device_put` with a data array `x`, we can reshape the `sharding` into a shape that is _congruent_ with `x.shape`, meaning a shape with the same length as `x.shape` and where each element evenly divides the corresponding element of `x.shape`:\n", - "```python\n", - "def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool:\n", - " return (len(x_shape) == len(sharding_shape) and\n", - " all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape)))\n", - "```\n", - "\n", - "For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "SELr4xNi3vGL", - "outputId": "b2f4acec-0cd3-4829-ca16-cae2e0e8ca60" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PositionalSharding([[{TPU 0} {TPU 1}]\n", - " [{TPU 2} {TPU 3}]\n", - " [{TPU 6} {TPU 7}]\n", - " [{TPU 4} {TPU 5}]])\n" - ] - } - ], - "source": [ - "sharding = sharding.reshape(4, 2)\n", - "print(sharding)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8IVIsqfX3vGL", - "outputId": "033d0e02-a643-4f4c-9d24-9cd8465bc69a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding)\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tyg9F-UIsU__" - }, - "source": [ - "Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories.\n", - "\n", - "Different `PositionalSharding` shapes result in different distributed layouts (i.e. 
shardings) of the result:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cCjt6QCz3vGM", - "outputId": "4ad8a611-596d-424f-b6c5-fc00f1adc306" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PositionalSharding([[{TPU 0} {TPU 1} {TPU 2} {TPU 3} {TPU 6} {TPU 7} {TPU 4} {TPU 5}]])\n" - ] - } - ], - "source": [ - "sharding = sharding.reshape(1, 8)\n", - "print(sharding)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yTK4Nz3u3vGM", - "outputId": "e445c6bc-4fe3-4e9d-cc9e-d82858f58312" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "│ │ │ │ │ │ │ │ │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding)\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0PuamOvXubcf" - }, - "source": [ - "In some cases, we don't just want to store each slice of `x` in a single device's memory; we might want to _replicate_ some slices, meaning storing copies of a slice's values in multiple devices' memories.\n", - "\n", - "With `PositionalSharding`, we can express replication by calling the reducer method `replicate`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "_jr6XYKx3vGM", - "outputId": "59c8b9a4-b8af-493a-ba8d-da5931e88f93" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "PositionalSharding([[{TPU 0, 2, 4, 6} {TPU 1, 3, 5, 7}]])\n" - ] - } - ], - "source": [ - "sharding = sharding.reshape(4, 2)\n", - "print(sharding.replicate(axis=0, keepdims=True))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "S5vzjFuH3vGN", - "outputId": "b6ce2675-7261-4e57-fa8c-b4e87abf7e52" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True))\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "FzeP0kpTvJv-" - }, - "source": [ - "Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories).\n", - "\n", - "The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. 
Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "DR7VV-6e3vGN", - "outputId": "f879fc2c-5723-4199-b306-295bc1b3681e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 2)\n", - "(4, 1)\n" - ] - } - ], - "source": [ - "print(sharding.replicate(0).shape)\n", - "print(sharding.replicate(1).shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "agUtVUVx3vGN", - "outputId": "0e9789ef-ce52-4ed6-8bd5-c876b95f66e6" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0,1 │\n", - "├───────────────────────┤\n", - "│ TPU 2,3 │\n", - "├───────────────────────┤\n", - "│ TPU 6,7 │\n", - "├───────────────────────┤\n", - "│ TPU 4,5 │\n", - "└───────────────────────┘\n" - ] - } - ], - "source": [ - "y = jax.device_put(x, sharding.replicate(1))\n", - "jax.debug.visualize_array_sharding(y)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "D31t5POXxHHJ" - }, - "source": [ - "### `NamedSharding` gives a way to express shardings with names" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ayMKWeTmxl-X" - }, - "source": [ - "So far we've worked with `PositionalSharding`, but there are alternative ways to express shardings. In fact, `Sharding` is an interface, and any class that implements that interface can be used with functions like `device_put`.\n", - "\n", - "Another convenient way to express sharding is with the `NamedSharding`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zpB1JxyK3vGN", - "outputId": "46d5da37-840c-49d8-8380-a162811bae8a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] - } - ], - "source": [ - "from jax.sharding import Mesh\n", - "from jax.sharding import PartitionSpec\n", - "from jax.sharding import NamedSharding\n", + "from jax.sharding import Mesh, PartitionSpec, NamedSharding\n", "from jax.experimental import mesh_utils\n", "\n", "P = PartitionSpec\n", @@ -801,7 +440,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": { "id": "8g0Md2Gd3vGO" }, @@ -820,26 +459,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "zp3MfS4Y3vGO", - "outputId": "2c2f7201-c2c1-49e5-f8a5-0730c124d89a" + "outputId": "032fdd7e-19a1-45da-e1ad-b3227fa43ee6" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -858,28 +515,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "FigK5Zsa3vGO", - "outputId": "eca784e8-33fe-4e9b-a41d-21e9ee781a35" + "outputId": "e488d073-9d02-4376-a6af-19d6d5509c7d" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "├───────┼───────┼───────┼───────┤\n", - "│ │ │ │ │\n", - "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "├───────┼───────┼───────┼───────┤\n",
+       "│       │       │       │       │\n",
+       "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "├───────┼───────┼───────┼───────┤\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -889,26 +566,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "hI-HD0xN3vGO", - "outputId": "c3e7dc3c-4048-448a-ef0b-50683532fcdc" + "outputId": "b0c2e863-3aee-4417-b45f-21b2187f6ef7" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0,1 │\n", - "├───────────────────────┤\n", - "│ TPU 2,3 │\n", - "├───────────────────────┤\n", - "│ TPU 6,7 │\n", - "├───────────────────────┤\n", - "│ TPU 4,5 │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│        TPU 0,1        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 2,3        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 6,7        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 4,5        │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -932,28 +627,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "EXBExMQC3vGP", - "outputId": "fe1c8d7e-3345-4438-b9d2-780e7854b4eb" + "outputId": "c80e6177-12a6-40ef-b4e4-934dad22da3d" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -963,28 +678,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "PjUpG8uz3vGP", - "outputId": "64d8224d-15d9-4ad4-d613-f7f85b1dc1af" + "outputId": "a0f59dc5-b509-4b8b-bd22-bcd69f696763" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│TPU 0,1│TPU 2,3│TPU 6,7│TPU 4,5│\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1003,34 +738,60 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 298 + }, "id": "fVcPbDUA3vGP", - "outputId": "7f524ba5-a6d8-4490-cda9-685ad11416f9" + "outputId": "da3f435d-dfc1-4a41-ec90-691cd7c748a0" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ TPU 0 │\n", - "├───────────────────────┤\n", - "│ TPU 1 │\n", - "├───────────────────────┤\n", - "│ TPU 2 │\n", - "├───────────────────────┤\n", - "│ TPU 3 │\n", - "├───────────────────────┤\n", - "│ TPU 6 │\n", - "├───────────────────────┤\n", - "│ TPU 7 │\n", - "├───────────────────────┤\n", - "│ TPU 4 │\n", - "├───────────────────────┤\n", - "│ TPU 5 │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│         TPU 0         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 1         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 2         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 3         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 6         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 7         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 4         │\n",
+       "├───────────────────────┤\n",
+       "│         TPU 5         │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m1\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m3\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m7\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m5\u001b[0m │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1069,54 +830,103 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": { "id": "_EmQwggc3vGQ" }, "outputs": [], "source": [ - "from jax.experimental import mesh_utils\n", - "from jax.sharding import PositionalSharding\n", - "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" + "devices = mesh_utils.create_device_mesh((4, 2))\n", + "mesh = Mesh(devices, axis_names=('a', 'b'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 349 + }, "id": "LnT0vWjc3vGQ", - "outputId": "8089effc-aa4c-49e3-dd19-7064881dbad0" + "outputId": "8e642049-61eb-458d-af79-ac449b58d11b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "input sharding:\n", - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n", - "output sharding:\n", - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" + "input sharding:\n" + ] + }, + { + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "output sharding:\n" ] + }, + { + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "x = jax.device_put(x, sharding.reshape(4, 2))\n", + "x = jax.device_put(x, NamedSharding(mesh, P('a', 'b')))\n", "print('input sharding:')\n", "jax.debug.visualize_array_sharding(x)\n", "\n", @@ -1140,54 +950,132 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 548 + }, "id": "Dq043GkP3vGQ", - "outputId": "350219a8-1e4a-4404-fe14-50f97ea3e7ba" + "outputId": "3eff7b67-d7f0-4212-c9d3-2cc271ac1f98" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "lhs sharding:\n", - "┌───────────────────────┐\n", - "│ TPU 0,1 │\n", - "├───────────────────────┤\n", - "│ TPU 2,3 │\n", - "├───────────────────────┤\n", - "│ TPU 6,7 │\n", - "├───────────────────────┤\n", - "│ TPU 4,5 │\n", - "└───────────────────────┘\n", - "rhs sharding:\n", - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n", - "out sharding:\n", - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" + "lhs sharding:\n" + ] + }, + { + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│        TPU 0,1        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 2,3        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 6,7        │\n",
+       "├───────────────────────┤\n",
+       "│        TPU 4,5        │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "├───────────────────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rhs sharding:\n" + ] + }, + { + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "out sharding:\n" ] + }, + { + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "y = jax.device_put(x, sharding.reshape(4, 2).replicate(1))\n", - "z = jax.device_put(x, sharding.reshape(4, 2).replicate(0))\n", + "y = jax.device_put(x, NamedSharding(mesh, P('a', None)))\n", + "z = jax.device_put(x, NamedSharding(mesh, P(None, 'b')))\n", "print('lhs sharding:')\n", "jax.debug.visualize_array_sharding(y)\n", "print('rhs sharding:')\n", @@ -1211,28 +1099,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "QjQ5u8qh3vGQ", - "outputId": "bd29edcd-b87c-486e-c568-906f06ae16be" + "outputId": "0aefc170-833c-4a6a-e003-5990d3db31d9" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│         TPU 0         │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1242,10 +1150,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "8tn8lOj73vGR", - "outputId": "5809b3c8-7333-4cd3-db97-a7aede943dce" + "outputId": "d9898c93-7afc-416b-8c40-4d9551613cd0" }, "outputs": [ { @@ -1254,7 +1165,7 @@ "True" ] }, - "execution_count": 36, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1266,17 +1177,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 29, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "D7PpZwhR3vGR", - "outputId": "4f0bd43d-0b32-4089-d3da-c8f1449e3526" + "outputId": "4901a11b-2354-4d26-a897-b88def07a716" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 19.3 ms per loop\n" + "49.7 ms ± 349 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1286,17 +1200,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "rgo_yVHF3vGR", - "outputId": "97f19052-f1c9-4d30-f453-07b3a7208aa9" + "outputId": "e51216cf-b073-4250-d422-67f9fd72f6aa" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 3.25 ms per loop\n" + "7.47 ms ± 44.8 µs per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1315,26 +1232,44 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 31, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 166 + }, "id": "f1Zw-2lH3vGR", - "outputId": "a796bed4-07b0-497d-8fd8-31a22ab9762e" + "outputId": "43d7a642-fde4-47a6-901f-dfdc64d6a613" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1365,7 +1300,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 94, "metadata": { "id": "1vAkZAOY3vGR" }, @@ -1375,54 +1310,63 @@ "from termcolor import colored\n", "\n", "def print_exception(e):\n", - " name = colored(f'{type(e).__name__}', 'red')\n", + " name = colored(f'{type(e).__name__}', 'red', force_color=True)\n", " print(textwrap.fill(f'{name}: {str(e)}'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 95, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "DHh0N3vn3vGS", - "outputId": "e7741882-0ebf-4237-e5d1-e48c9b9c178f" + "outputId": "8c4652f7-c484-423b-ad78-182134280187" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[31mValueError\u001b[0m: Devices of all `Array` inputs and outputs should\n", - "be the same. Got array device ids [0, 1, 2, 3] on platform TPU and\n", - "another array's device ids [4, 5, 6, 7] on platform TPU\n" + "\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n", + "computation. Got argument x1 of jax.numpy.add with shape int32[24] and\n", + "device ids [0, 1, 2, 3] on platform TPU and argument x2 of\n", + "jax.numpy.add with shape int32[24] and device ids [4, 5, 6, 7] on\n", + "platform TPU\n" ] } ], "source": [ - "sharding1 = PositionalSharding(jax.devices()[:4])\n", - "sharding2 = PositionalSharding(jax.devices()[4:])\n", + "sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x'))\n", + "sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x'))\n", "\n", - "y = jax.device_put(x, sharding1.reshape(2, 2))\n", - "z = jax.device_put(x, sharding2.reshape(2, 2))\n", + "y = jax.device_put(x, sharding1)\n", + "z = jax.device_put(x, sharding2)\n", "try: y + z\n", "except ValueError as e: print_exception(e)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 96, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "Im7DkoOl3vGS", - "outputId": "3adfe1cb-db52-4a9d-e98e-62c6455c3100" + "outputId": "1b6fcd7a-762b-4366-a96d-aea63bad7fe0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[31mValueError\u001b[0m: Devices of all `Array` inputs and outputs should\n", - "be the same. Got array device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform\n", - "TPU and another array's device ids [0, 1, 2, 3, 6, 7, 4, 5] on\n", - "platform TPU\n" + "\u001b[31mValueError\u001b[0m: Received incompatible devices for jitted\n", + "computation. 
Got argument x1 of jax.numpy.add with shape int32[24] and\n", + "device ids [0, 1, 2, 3, 4, 5, 6, 7] on platform TPU and argument x2 of\n", + "jax.numpy.add with shape int32[24] and device ids [0, 1, 2, 3, 6, 7,\n", + "4, 5] on platform TPU\n" ] } ], @@ -1430,11 +1374,11 @@ "devices = jax.devices()\n", "permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]]\n", "\n", - "sharding1 = PositionalSharding(devices)\n", - "sharding2 = PositionalSharding(permuted_devices)\n", + "sharding1 = NamedSharding(Mesh(devices, 'x'), P('x'))\n", + "sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x'))\n", "\n", - "y = jax.device_put(x, sharding1.reshape(4, 2))\n", - "z = jax.device_put(x, sharding2.reshape(4, 2))\n", + "y = jax.device_put(x, sharding1)\n", + "z = jax.device_put(x, sharding2)\n", "try: y + z\n", "except ValueError as e: print_exception(e)" ] @@ -1455,10 +1399,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 40, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "_QvtKL8r3vGS", - "outputId": "e0078805-bdfd-436e-f94f-7cd256d2574f" + "outputId": "761b1208-fe4b-4c09-a7d2-f62152183ef0" }, "outputs": [ { @@ -1470,7 +1417,7 @@ } ], "source": [ - "y = jax.device_put(x, sharding1.reshape(4, 2))\n", + "y = jax.device_put(x, sharding1)\n", "y + jnp.ones_like(y)\n", "y + jnp.arange(y.size).reshape(y.shape)\n", "print('no error!')" @@ -1496,30 +1443,30 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 41, "metadata": { "id": "jniSFm5V3vGT" }, "outputs": [], "source": [ - "sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))" + "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 42, "metadata": { "id": "Q1wuDp-L3vGT" }, "outputs": [], "source": [ "x = jax.random.normal(jax.random.key(0), (8192, 8192))\n", - "x = jax.device_put(x, sharding.reshape(4, 2))" + "x = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 44, "metadata": { "id": "rqEDj0wB3vGT" }, @@ -1528,43 +1475,83 @@ "@jax.jit\n", "def f(x):\n", " x = x + 1\n", - " y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4))\n", + " y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x')))\n", " return y" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 45, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 347 + }, "id": "zYFS-n4r3vGT", - "outputId": "d23a7938-cb7d-44b4-b9c7-83edf1d1145e" + "outputId": "0ac96b8f-ed23-4413-aed9-edd00a841c37" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n", - "┌───────┬───────┬───────┬───────┐\n", - "│ │ │ │ │\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "├───────┼───────┼───────┼───────┤\n", - "│ │ │ │ │\n", - "│ TPU 6 │ TPU 7 │ TPU 4 │ TPU 5 │\n", - "│ │ │ │ │\n", - "│ │ │ │ │\n", - "└───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┐\n",
+       "│       │       │       │       │\n",
+       "│ TPU 0 │ TPU 2 │ TPU 6 │ TPU 4 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "├───────┼───────┼───────┼───────┤\n",
+       "│       │       │       │       │\n",
+       "│ TPU 1 │ TPU 3 │ TPU 7 │ TPU 5 │\n",
+       "│       │       │       │       │\n",
+       "│       │       │       │       │\n",
+       "└───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┐\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "├───────┼───────┼───────┼───────┤\n", + "│ │ │ │ │\n", + "│ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "│ │ │ │ │\n", + "│ │ │ │ │\n", + "└───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1575,7 +1562,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 46, "metadata": { "id": "8g_2Y8wp3vGT" }, @@ -1584,43 +1571,83 @@ "@jax.jit\n", "def f(x):\n", " x = x + 1\n", - " y = jax.lax.with_sharding_constraint(x, sharding.replicate())\n", + " y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P()))\n", " return y" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 47, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 347 + }, "id": "AiRFtVsR3vGT", - "outputId": "f3e28a70-46cf-46fb-c801-82f0ddb447e4" + "outputId": "2edacc2c-ac80-4519-c9d1-bee364a22b31" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌──────────┬──────────┐\n", - "│ TPU 0 │ TPU 1 │\n", - "├──────────┼──────────┤\n", - "│ TPU 2 │ TPU 3 │\n", - "├──────────┼──────────┤\n", - "│ TPU 6 │ TPU 7 │\n", - "├──────────┼──────────┤\n", - "│ TPU 4 │ TPU 5 │\n", - "└──────────┴──────────┘\n", - "┌───────────────────────┐\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ TPU 0,1,2,3,4,5,6,7 │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌──────────┬──────────┐\n",
+       "│  TPU 0   │  TPU 1   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 2   │  TPU 3   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 6   │  TPU 7   │\n",
+       "├──────────┼──────────┤\n",
+       "│  TPU 4   │  TPU 5   │\n",
+       "└──────────┴──────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌──────────┬──────────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "├──────────┼──────────┤\n", + "│ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │\n", + "└──────────┴──────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│  TPU 0,1,2,3,4,5,6,7  │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1669,7 +1696,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 48, "metadata": { "id": "mEKF3zIF3vGU" }, @@ -1681,7 +1708,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 49, "metadata": { "id": "Mocs3oGe3vGU" }, @@ -1701,7 +1728,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "metadata": { "id": "glBB8tzW3vGU" }, @@ -1713,27 +1740,27 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "metadata": { "id": "R0x62AIa3vGU" }, "outputs": [], "source": [ "def init_layer(key, n_in, n_out):\n", - " k1, k2 = jax.random.split(key)\n", - " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", - " b = jax.random.normal(k2, (n_out,))\n", - " return W, b\n", + " k1, k2 = jax.random.split(key)\n", + " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", + " b = jax.random.normal(k2, (n_out,))\n", + " return W, b\n", "\n", "def init_model(key, layer_sizes, batch_size):\n", - " key, *keys = jax.random.split(key, len(layer_sizes))\n", - " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", + " key, *keys = jax.random.split(key, len(layer_sizes))\n", + " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", "\n", - " key, *keys = jax.random.split(key, 3)\n", - " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", - " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", + " key, *keys = jax.random.split(key, 3)\n", + " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", + " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", "\n", - " return params, (inputs, targets)\n", + " return params, (inputs, targets)\n", "\n", "layer_sizes = [784, 8192, 8192, 8192, 10]\n", "batch_size = 8192\n", @@ -1752,33 +1779,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 52, + "metadata": { + "id": "mJLqRPpSDX0i" + }, + "outputs": [], + "source": [ + "mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch')" + ] + }, + { + "cell_type": "code", + "execution_count": 54, "metadata": { "id": "_Q5NbdOn3vGV" }, "outputs": [], "source": [ - "sharding = PositionalSharding(jax.devices()).reshape(8, 1)" + "sharding = NamedSharding(mesh, P('batch'))\n", + "replicated_sharding = NamedSharding(mesh, P())" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 55, "metadata": { "id": "3KC6ieEe3vGV" }, "outputs": [], "source": [ "batch = jax.device_put(batch, sharding)\n", - "params = jax.device_put(params, sharding.replicate())" + "params = jax.device_put(params, replicated_sharding)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 56, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "MUb-QE2b3vGV", - "outputId": "1f831ea5-5a30-49ad-8195-977ff7ed476a" + "outputId": "5a27f007-c572-44f8-9f49-6e745ee739e8" }, "outputs": [ { @@ -1787,7 +1829,7 @@ "Array(23.469475, dtype=float32)" ] }, - "execution_count": 57, + "execution_count": 
56, "metadata": {}, "output_type": "execute_result" } @@ -1798,17 +1840,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 57, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "HUkw0u413vGV", - "outputId": "dfa2599c-9440-4657-9035-0dc3bbf625e1" + "outputId": "07e481a1-97fb-4bd0-d754-cb6d8317bff6" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.760101\n" + "10.760109\n" ] } ], @@ -1825,17 +1870,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 58, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "paCw6Zaj3vGV", - "outputId": "8ab1c32c-f2b1-465c-df71-f5a599e7f19e" + "outputId": "ad4cce34-3a6a-4d44-9a86-477a7fee4841" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 26.3 ms per loop\n" + "53.8 ms ± 1.14 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1845,7 +1893,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 59, "metadata": { "id": "BF86UWpg3vGV" }, @@ -1857,17 +1905,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 60, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "Z1wgUKXk3vGV", - "outputId": "74df8892-c349-41dc-cb1b-e0843ec5c994" + "outputId": "d66767b7-3f17-482f-b811-919bb1793277" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "5 loops, best of 5: 122 ms per loop\n" + "351 ms ± 81.2 ms per loop (mean ± std. dev. of 5 runs, 5 loops each)\n" ] } ], @@ -1886,50 +1937,88 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 61, "metadata": { - "id": "N5-zzgW03vGW" + "id": "k1hxOfgRDwo0" }, "outputs": [], "source": [ - "sharding = sharding.reshape(4, 2)" + "mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 62, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 314 + }, "id": "sgIWCjJK3vGW", - "outputId": "b2fdc556-05cc-4e68-fa04-48643d194dee" + "outputId": "8cb0f19f-3942-415c-c57a-31bb81784f46" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┐\n", - "│TPU 0,1│\n", - "├───────┤\n", - "│TPU 2,3│\n", - "├───────┤\n", - "│TPU 4,5│\n", - "├───────┤\n", - "│TPU 6,7│\n", - "└───────┘\n", - "┌───────┐\n", - "│TPU 0,1│\n", - "├───────┤\n", - "│TPU 2,3│\n", - "├───────┤\n", - "│TPU 4,5│\n", - "├───────┤\n", - "│TPU 6,7│\n", - "└───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┐\n",
+       "│TPU 0,1│\n",
+       "├───────┤\n",
+       "│TPU 2,3│\n",
+       "├───────┤\n",
+       "│TPU 6,7│\n",
+       "├───────┤\n",
+       "│TPU 4,5│\n",
+       "└───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┐\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n", + "└───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────┐\n",
+       "│TPU 0,1│\n",
+       "├───────┤\n",
+       "│TPU 2,3│\n",
+       "├───────┤\n",
+       "│TPU 6,7│\n",
+       "├───────┤\n",
+       "│TPU 4,5│\n",
+       "└───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┐\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m1\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m2\u001b[0m,\u001b[1;36m3\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m6\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "├───────┤\n", + "│TPU \u001b[1;36m4\u001b[0m,\u001b[1;36m5\u001b[0m│\n", + "└───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "batch = jax.device_put(batch, sharding.replicate(1))\n", + "batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None)))\n", "jax.debug.visualize_array_sharding(batch[0])\n", "jax.debug.visualize_array_sharding(batch[1])" ] @@ -1937,6 +2026,17 @@ { "cell_type": "code", "execution_count": null, + "metadata": { + "id": "q9PQP-0eEAO6" + }, + "outputs": [], + "source": [ + "replicated_sharding = NamedSharding(mesh, P())" + ] + }, + { + "cell_type": "code", + "execution_count": 67, "metadata": { "id": "BqCjYCgg3vGW" }, @@ -1944,45 +2044,65 @@ "source": [ "(W1, b1), (W2, b2), (W3, b3), (W4, b4) = params\n", "\n", - "W1 = jax.device_put(W1, sharding.replicate())\n", - "b1 = jax.device_put(b1, sharding.replicate())\n", + "W1 = jax.device_put(W1, replicated_sharding)\n", + "b1 = jax.device_put(b1, replicated_sharding)\n", "\n", - "W2 = jax.device_put(W2, sharding.replicate(0))\n", - "b2 = jax.device_put(b2, sharding.replicate(0))\n", + "W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))\n", + "b2 = jax.device_put(b2, NamedSharding(mesh, P('model')))\n", "\n", - "W3 = jax.device_put(W3, sharding.replicate(0).T)\n", - "b3 = jax.device_put(b3, sharding.replicate())\n", + "W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))\n", + "b3 = jax.device_put(b3, replicated_sharding)\n", "\n", - "W4 = jax.device_put(W4, sharding.replicate())\n", - "b4 = jax.device_put(b4, sharding.replicate())\n", + "W4 = jax.device_put(W4, replicated_sharding)\n", + "b4 = jax.device_put(b4, replicated_sharding)\n", "\n", "params = (W1, b1), (W2, b2), (W3, b3), (W4, b4)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 68, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "_lSJ63sh3vGW", - "outputId": "5b37aa8b-3226-4805-8282-876e8d06edda" + "outputId": "bcd3e33e-36b5-4787-9cd2-60623fd6e5fa" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -1991,28 +2111,48 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 69, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 199 + }, "id": "fxkfWYkk3vGW", - "outputId": "8a1063c3-540b-47c1-d990-a6845da861f7" + "outputId": "59e60b16-fe37-47d4-8214-96096ffbd79c" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────────────────┐\n", - "│ │\n", - "│ TPU 0,2,4,6 │\n", - "│ │\n", - "│ │\n", - "├───────────────────────┤\n", - "│ │\n", - "│ TPU 1,3,5,7 │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│      TPU 0,2,4,6      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "├───────────────────────┤\n",
+       "│                       │\n",
+       "│      TPU 1,3,5,7      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "│ │\n", + "│ │\n", + "├───────────────────────┤\n", + "│ │\n", + "│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2021,17 +2161,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 70, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "uPCVs-_k3vGW", - "outputId": "de01cdfc-36cb-4823-c692-22c692ef4220" + "outputId": "618516e9-9736-4ca0-dd22-09d094ce57a2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.760103\n" + "10.760109\n" ] } ], @@ -2041,7 +2184,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 71, "metadata": { "id": "L9JebLK_3vGW" }, @@ -2057,17 +2200,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 72, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "c9Sbl69e3vGX", - "outputId": "8272c5fa-e59f-4953-c2d5-658c42a28712" + "outputId": "2ee3d432-7172-46ca-e01a-614e83345808" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10.752466\n" + "10.752513\n" ] } ], @@ -2077,39 +2223,81 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 73, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 380 + }, "id": "lkAF0dAb3vGX", - "outputId": "acf0df31-c5e1-4683-b73f-b0cd1b0929f8" + "outputId": "6c1e317e-cded-4af4-8080-0de835fa4c71" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────────┬───────────┐\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│TPU 0,2,4,6│TPU 1,3,5,7│\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "│ │ │\n", - "└───────────┴───────────┘\n", - "┌───────────────────────┐\n", - "│ │\n", - "│ TPU 0,2,4,6 │\n", - "│ │\n", - "│ │\n", - "├───────────────────────┤\n", - "│ │\n", - "│ TPU 1,3,5,7 │\n", - "│ │\n", - "│ │\n", - "└───────────────────────┘\n" - ] + "data": { + "text/html": [ + "
┌───────────┬───────────┐\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│TPU 0,2,4,6│TPU 1,3,5,7│\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "│           │           │\n",
+       "└───────────┴───────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────┬───────────┐\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m│TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m│\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "│ │ │\n", + "└───────────┴───────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
┌───────────────────────┐\n",
+       "│                       │\n",
+       "│      TPU 0,2,4,6      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "├───────────────────────┤\n",
+       "│                       │\n",
+       "│      TPU 1,3,5,7      │\n",
+       "│                       │\n",
+       "│                       │\n",
+       "└───────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────────────────────┐\n", + "│ │\n", + "│ TPU \u001b[1;36m0\u001b[0m,\u001b[1;36m2\u001b[0m,\u001b[1;36m4\u001b[0m,\u001b[1;36m6\u001b[0m │\n", + "│ │\n", + "│ │\n", + "├───────────────────────┤\n", + "│ │\n", + "│ TPU \u001b[1;36m1\u001b[0m,\u001b[1;36m3\u001b[0m,\u001b[1;36m5\u001b[0m,\u001b[1;36m7\u001b[0m │\n", + "│ │\n", + "│ │\n", + "└───────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2120,17 +2308,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 74, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "I1Npor3i3vGX", - "outputId": "4099f6dd-7b46-4123-c1cb-5173c3d3278e" + "outputId": "479c4d81-cb0b-40a5-89ba-394c10dc3297" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "10 loops, best of 10: 30.5 ms per loop\n" + "51.4 ms ± 454 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)\n" ] } ], @@ -2173,7 +2364,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 75, "metadata": { "id": "kwS-aQE_3vGX" }, @@ -2185,7 +2376,8 @@ " return x + numbers\n", "\n", "key = jax.random.key(42)\n", - "x_sharding = jax.sharding.PositionalSharding(jax.devices())\n", + "mesh = Mesh(jax.devices(), 'x')\n", + "x_sharding = NamedSharding(mesh, P('x'))\n", "x = jax.device_put(jnp.arange(24), x_sharding)" ] }, @@ -2200,20 +2392,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 76, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 67 + }, "id": "Oi97rpLz3vGY", - "outputId": "204a7e8d-dc88-4b77-b7e3-0e72f306c5d3" + "outputId": "9dd63254-a483-4847-c0f5-5a4367bf08e9" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+       "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
+       "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2231,10 +2435,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 77, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "64wIZuSJ3vGY", - "outputId": "1054fe99-0476-44ec-9693-b0d8f98bf6a8" + "outputId": "fa166d45-ca9c-457a-be84-bcc9236d0730" }, "outputs": [ { @@ -2261,10 +2468,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 78, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "1I7bqxA63vGY", - "outputId": "ec4c579d-f446-4b48-ceda-785c09ba299b" + "outputId": "756e0a36-ff14-438f-bbd4-3ef03f97a47b" }, "outputs": [ { @@ -2292,20 +2502,32 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 79, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 67 + }, "id": "zHPJzdn23vGY", - "outputId": "a8904d20-4d04-4f59-8eae-281e47d29246" + "outputId": "3332de0f-4827-4f0b-b9ef-69249b7c6bc6" }, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", - "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n", - "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" - ] + "data": { + "text/html": [ + "
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n",
+       "│ TPU 0 │ TPU 1 │ TPU 2 │ TPU 3 │ TPU 4 │ TPU 5 │ TPU 6 │ TPU 7 │\n",
+       "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n",
+       "
\n" + ], + "text/plain": [ + "┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐\n", + "│ TPU \u001b[1;36m0\u001b[0m │ TPU \u001b[1;36m1\u001b[0m │ TPU \u001b[1;36m2\u001b[0m │ TPU \u001b[1;36m3\u001b[0m │ TPU \u001b[1;36m4\u001b[0m │ TPU \u001b[1;36m5\u001b[0m │ TPU \u001b[1;36m6\u001b[0m │ TPU \u001b[1;36m7\u001b[0m │\n", + "└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -2323,10 +2545,13 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 80, "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, "id": "nBUHBBal3vGY", - "outputId": "f194c213-0688-4b7a-ffb8-c4453b82b1f1" + "outputId": "4b9be948-ccab-4a31-a06f-37ec9c7b5235" }, "outputs": [ { @@ -2371,10 +2596,10 @@ "metadata": { "accelerator": "TPU", "colab": { + "gpuType": "V28", "provenance": [], "toc_visible": true }, - "gpuClass": "standard", "jupytext": { "formats": "ipynb,md:myst" }, diff --git a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md index b9ec9dc694d2..2142db9866ae 100644 --- a/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md +++ b/docs/notebooks/Distributed_arrays_and_automatic_parallelization.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -19,16 +19,14 @@ kernelspec: +++ {"id": "pFtQjv4SzHRj"} -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Distributed_arrays_and_automatic_parallelization.ipynb) This tutorial discusses parallelism via `jax.Array`, the unified array object model available in JAX v0.4.1 and newer. ```{code-cell} :id: FNxScTfq3vGF -import os -import functools from typing import Optional import numpy as np @@ -52,7 +50,7 @@ if len(jax.local_devices()) < 8: ## Intro and a quick example -By reading this tutorial notebook, you'll learn about `jax.Array`, a unified +By reading this tutorial notebook, you'll learn about `jax.Array`, a unified datatype for representing arrays, even with physical storage spanning multiple devices. You'll also learn about how using `jax.Array`s together with `jax.jit` can provide automatic compiler-based parallelization. 
@@ -64,24 +62,29 @@ First, we'll create a `jax.Array` sharded across multiple devices: :id: Gf2lO4ii3vGG from jax.experimental import mesh_utils -from jax.sharding import PositionalSharding +from jax.sharding import Mesh, PartitionSpec as P, NamedSharding ``` ```{code-cell} :id: q-XBTEoy3vGG # Create a Sharding object to distribute a value across devices: -sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +mesh = Mesh(devices=mesh_utils.create_device_mesh((4, 2)), + axis_names=('x', 'y')) ``` ```{code-cell} -:id: vI39znW93vGH -:outputId: 3b518df8-5c29-4848-acc3-e41df939f30b - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: vI39znW93vGH +outputId: 4f702753-8add-4b65-a4af-0f18f098cc46 +--- # Create an array of random values: x = jax.random.normal(jax.random.key(0), (8192, 8192)) # and use jax.device_put to distribute it across devices: -y = jax.device_put(x, sharding.reshape(4, 2)) +y = jax.device_put(x, NamedSharding(mesh, P('x', 'y'))) jax.debug.visualize_array_sharding(y) ``` @@ -91,9 +94,13 @@ Next, we'll apply a computation to it and visualize how the result values are stored across multiple devices too: ```{code-cell} -:id: -qCnHZl83vGI -:outputId: 9da9c29e-ce88-4425-e1ec-e93e5bcf3106 - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: -qCnHZl83vGI +outputId: 0e131c23-5765-43ae-f232-6417ae1acbb2 +--- z = jnp.sin(y) jax.debug.visualize_array_sharding(z) ``` @@ -104,17 +111,23 @@ The evaluation of the `jnp.sin` application was automatically parallelized across the devices on which the input values (and output values) are stored: ```{code-cell} -:id: _VTzN0r03vGI -:outputId: c9208010-984b-442b-d105-c8c6a3a010e6 - +--- +colab: + base_uri: https://localhost:8080/ +id: _VTzN0r03vGI +outputId: c03eecab-4c86-4dac-d776-5fc72cbb5273 +--- # `x` is present on a single device %timeit -n 5 -r 5 jnp.sin(x).block_until_ready() ``` ```{code-cell} -:id: QuzhU1g63vGI -:outputId: d48fc76e-79a7-47b9-d392-b18a1c33c798 - +--- +colab: + base_uri: https://localhost:8080/ +id: QuzhU1g63vGI +outputId: 8135cca0-871b-4b6a-a7e5-02e78c2028c7 +--- # `y` is sharded across 8 devices. %timeit -n 5 -r 5 jnp.sin(y).block_until_ready() ``` @@ -128,7 +141,7 @@ Now let's look at each of these pieces in more detail! +++ {"id": "W6HsXauGxL6w"} -### Sharding basics, and the `PositionalSharding` subclass +### Sharding basics, and the `NamedSharding` subclass +++ {"id": "NWDyp_EjVHkg"} @@ -146,9 +159,13 @@ x = jax.random.normal(jax.random.key(0), (8192, 8192)) ``` ```{code-cell} -:id: vNRabO2J3vGJ -:outputId: 73db7b6e-c2e7-467d-a0ef-c35e29e582dd - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: vNRabO2J3vGJ +outputId: 40fd7172-a16c-4dd8-e2e1-17bb3afe5409 +--- jax.debug.visualize_array_sharding(x) ``` @@ -159,169 +176,14 @@ Here, we're using the `jax.debug.visualize_array_sharding` function to show wher But we can shard `x` across multiple devices by using `jax.device_put` and a `Sharding` object. 
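For orientation, here is a rough end-to-end sketch of the placement this section builds up to, assuming eight local devices as in the rest of the notebook; the commented lines recall the older `PositionalSharding` spelling, and both forms describe the same 4x2 layout.

```python
# Rough sketch of the new-style placement, assuming 8 local devices.
# Previously this was spelled:
#   sharding = PositionalSharding(mesh_utils.create_device_mesh((8,)))
#   y = jax.device_put(x, sharding.reshape(4, 2))
# The named-mesh form below describes the same 4x2 layout.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = mesh_utils.create_device_mesh((4, 2))            # 4x2 grid over the 8 devices
mesh = Mesh(devices, axis_names=('x', 'y'))

x = jax.random.normal(jax.random.key(0), (8192, 8192))
y = jax.device_put(x, NamedSharding(mesh, P('x', 'y')))    # rows over 'x', columns over 'y'
jax.debug.visualize_array_sharding(y)
```

The cells that follow build up this same placement step by step.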
First, we make a `numpy.ndarray` of `Devices` using `mesh_utils.create_device_mesh`, which takes hardware topology into account for the `Device` order: ```{code-cell} -:id: VUIEIzRp3vGK - -from jax.experimental import mesh_utils -devices = mesh_utils.create_device_mesh((8,)) -``` - -+++ {"id": "lbOKFWmBX1iv"} - -Then, we create a `PositionalSharding` and use it with `device_put`: - -```{code-cell} -:id: jwrWfZeB3vGK -:outputId: e6f126bd-f6bd-48c7-c130-6f02757e3342 - -from jax.sharding import PositionalSharding - -sharding = PositionalSharding(devices) - -x = jax.device_put(x, sharding.reshape(8, 1)) -jax.debug.visualize_array_sharding(x) -``` - -+++ {"id": "TUu69IWXZdTm"} - -Here `sharding` is a `PositionalSharding` which acts like an array with sets of devices as elements: - -```{code-cell} -:id: zxWB82Kz3vGK -:outputId: 11384a6b-fabc-4c4c-bcad-a3be51eb0465 - -sharding -``` - -+++ {"id": "uRLpOcmNj_Vt"} - -The device numbers here are not in numerical order, because the mesh reflects the underlying toroidal topology of the device. - -By writing `PositionalSharding(ndarray_of_devices)`, we fix the device order and the initial shape. Then we can reshape it: - -```{code-cell} -:id: PLsnpSzc3vGL -:outputId: 9f4db733-cafe-46ae-c057-dc31046a6f66 - -sharding.reshape(8, 1) -``` - -```{code-cell} -:id: iqKdI4LO3vGL -:outputId: 6aa10fc2-cec4-4401-a0df-343e71646e0a - -sharding.reshape(4, 2) -``` - -+++ {"id": "KBu6WLfhm7ra"} - -To use `device_put` with a data array `x`, we can reshape the `sharding` into a shape that is _congruent_ with `x.shape`, meaning a shape with the same length as `x.shape` and where each element evenly divides the corresponding element of `x.shape`: -```python -def is_congruent(x_shape: Sequence[int], sharding_shape: Sequence[int]) -> bool: - return (len(x_shape) == len(sharding_shape) and - all(d1 % d2 == 0 for d1, d2 in zip(x_shape, sharding_shape))) -``` - -For example, we can reshape `sharding` to have shape `(4, 2)`, then use it in a `device_put`: - -```{code-cell} -:id: SELr4xNi3vGL -:outputId: b2f4acec-0cd3-4829-ca16-cae2e0e8ca60 - -sharding = sharding.reshape(4, 2) -print(sharding) -``` - -```{code-cell} -:id: 8IVIsqfX3vGL -:outputId: 033d0e02-a643-4f4c-9d24-9cd8465bc69a - -y = jax.device_put(x, sharding) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "tyg9F-UIsU__"} - -Here `y` represents the same _value_ as `x`, but its shards (i.e. slices) are stored in different devices' memories. - -Different `PositionalSharding` shapes result in different distributed layouts (i.e. shardings) of the result: - -```{code-cell} -:id: cCjt6QCz3vGM -:outputId: 4ad8a611-596d-424f-b6c5-fc00f1adc306 - -sharding = sharding.reshape(1, 8) -print(sharding) -``` - -```{code-cell} -:id: yTK4Nz3u3vGM -:outputId: e445c6bc-4fe3-4e9d-cc9e-d82858f58312 - -y = jax.device_put(x, sharding) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "0PuamOvXubcf"} - -In some cases, we don't just want to store each slice of `x` in a single device's memory; we might want to _replicate_ some slices, meaning storing copies of a slice's values in multiple devices' memories. 
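In the named-sharding style used in this version of the notebook, the same idea is expressed by leaving mesh axes out of the `PartitionSpec`; a minimal sketch, with mesh axis names chosen to match the section below:

```python
# Sketch: in the named-sharding style, any mesh axis *not* mentioned in the
# PartitionSpec is replicated over. Axis names mirror the section below.
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('a', 'b'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))

# Sharded 4 ways over 'a', each shard copied across the 2 devices of 'b':
jax.debug.visualize_array_sharding(jax.device_put(x, NamedSharding(mesh, P('a', None))))
# Fully replicated: every device holds a complete copy of x:
jax.debug.visualize_array_sharding(jax.device_put(x, NamedSharding(mesh, P())))
```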
- -With `PositionalSharding`, we can express replication by calling the reducer method `replicate`: - -```{code-cell} -:id: _jr6XYKx3vGM -:outputId: 59c8b9a4-b8af-493a-ba8d-da5931e88f93 - -sharding = sharding.reshape(4, 2) -print(sharding.replicate(axis=0, keepdims=True)) -``` - -```{code-cell} -:id: S5vzjFuH3vGN -:outputId: b6ce2675-7261-4e57-fa8c-b4e87abf7e52 - -y = jax.device_put(x, sharding.replicate(axis=0, keepdims=True)) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "FzeP0kpTvJv-"} - -Here the visualization shows that `x` is sharded two ways along its second dimension (and not sharded along the first dimension), and each of those shards is replicated four ways (i.e. stored in four device memories). - -The `replicate` method is analogous to the familiar NumPy array reduction methods like `.sum()` and `.prod()`. It operates along an axis performing a set union. So if `sharding` has shape `(4, 2)`, then `sharding.replicate(0, keepdims=True)` has shape `(1, 2)`, and `sharding.replicate(1, keepdims=True)` has shape `(4, 1)`. Unlike analogous NumPy methods, `keepdims=True` is actually the default, so reduced-over axes aren't squeezed: - -```{code-cell} -:id: DR7VV-6e3vGN -:outputId: f879fc2c-5723-4199-b306-295bc1b3681e - -print(sharding.replicate(0).shape) -print(sharding.replicate(1).shape) -``` - -```{code-cell} -:id: agUtVUVx3vGN -:outputId: 0e9789ef-ce52-4ed6-8bd5-c876b95f66e6 - -y = jax.device_put(x, sharding.replicate(1)) -jax.debug.visualize_array_sharding(y) -``` - -+++ {"id": "D31t5POXxHHJ"} - -### `NamedSharding` gives a way to express shardings with names - -+++ {"id": "ayMKWeTmxl-X"} - -So far we've worked with `PositionalSharding`, but there are alternative ways to express shardings. In fact, `Sharding` is an interface, and any class that implements that interface can be used with functions like `device_put`. - -Another convenient way to express sharding is with the `NamedSharding`: - -```{code-cell} -:id: zpB1JxyK3vGN -:outputId: 46d5da37-840c-49d8-8380-a162811bae8a - -from jax.sharding import Mesh -from jax.sharding import PartitionSpec -from jax.sharding import NamedSharding +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: zpB1JxyK3vGN +outputId: 8e385462-1c2c-4256-c38a-84299d3bd02c +--- +from jax.sharding import Mesh, PartitionSpec, NamedSharding from jax.experimental import mesh_utils P = PartitionSpec @@ -351,9 +213,13 @@ def mesh_sharding( ``` ```{code-cell} -:id: zp3MfS4Y3vGO -:outputId: 2c2f7201-c2c1-49e5-f8a5-0730c124d89a - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: zp3MfS4Y3vGO +outputId: 032fdd7e-19a1-45da-e1ad-b3227fa43ee6 +--- y = jax.device_put(x, mesh_sharding(P('a', 'b'))) jax.debug.visualize_array_sharding(y) ``` @@ -363,17 +229,25 @@ jax.debug.visualize_array_sharding(y) Here, we use `P('a', 'b')` to express that the first and second axes of `x` should be sharded over the device mesh axes `'a'` and `'b'`, respectively. 
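The sketch below collects, in one place, the main `PartitionSpec` variants that the following cells exercise one at a time; it is a rough sketch, assuming a 4x2 mesh with the same axis names.

```python
# Quick reference (sketch) for the PartitionSpecs used in this section, on a
# 4x2 mesh with axis names ('a', 'b').
import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('a', 'b'))
x = jax.random.normal(jax.random.key(0), (8192, 8192))

for spec in (P('a', 'b'),           # rows 4-way over 'a', columns 2-way over 'b'
             P('b', 'a'),           # rows over 'b', columns over 'a'
             P(None, 'b'),          # only columns sharded; rows replicated over 'a'
             P(('a', 'b'), None)):  # rows 8-way over both axes together
  print(spec)
  jax.debug.visualize_array_sharding(jax.device_put(x, NamedSharding(mesh, spec)))
```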
We can easily switch to `P('b', 'a')` to shard the axes of `x` over different devices: ```{code-cell} -:id: FigK5Zsa3vGO -:outputId: eca784e8-33fe-4e9b-a41d-21e9ee781a35 - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: FigK5Zsa3vGO +outputId: e488d073-9d02-4376-a6af-19d6d5509c7d +--- y = jax.device_put(x, mesh_sharding(P('b', 'a'))) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: hI-HD0xN3vGO -:outputId: c3e7dc3c-4048-448a-ef0b-50683532fcdc - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: hI-HD0xN3vGO +outputId: b0c2e863-3aee-4417-b45f-21b2187f6ef7 +--- # This `None` means that `x` is not sharded on its second dimension, # and since the Mesh axis name 'b' is not mentioned, shards are # replicated across it. @@ -388,17 +262,25 @@ Here, because `P('a', None)` doesn't mention the `Mesh` axis name `'b'`, we get To shard only over the second axis of `x`, we can use a `None` placeholder in the `PartitionSpec`: ```{code-cell} -:id: EXBExMQC3vGP -:outputId: fe1c8d7e-3345-4438-b9d2-780e7854b4eb - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: EXBExMQC3vGP +outputId: c80e6177-12a6-40ef-b4e4-934dad22da3d +--- y = jax.device_put(x, mesh_sharding(P(None, 'b'))) jax.debug.visualize_array_sharding(y) ``` ```{code-cell} -:id: PjUpG8uz3vGP -:outputId: 64d8224d-15d9-4ad4-d613-f7f85b1dc1af - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: PjUpG8uz3vGP +outputId: a0f59dc5-b509-4b8b-bd22-bcd69f696763 +--- y = jax.device_put(x, mesh_sharding(P(None, 'a'))) jax.debug.visualize_array_sharding(y) ``` @@ -408,9 +290,13 @@ jax.debug.visualize_array_sharding(y) For a fixed mesh, we can even partition one logical axis of `x` over multiple device mesh axes: ```{code-cell} -:id: fVcPbDUA3vGP -:outputId: 7f524ba5-a6d8-4490-cda9-685ad11416f9 - +--- +colab: + base_uri: https://localhost:8080/ + height: 298 +id: fVcPbDUA3vGP +outputId: da3f435d-dfc1-4a41-ec90-691cd7c748a0 +--- y = jax.device_put(x, mesh_sharding(P(('a', 'b'), None))) jax.debug.visualize_array_sharding(y) ``` @@ -432,16 +318,19 @@ For example, the simplest computation is an elementwise one: ```{code-cell} :id: _EmQwggc3vGQ -from jax.experimental import mesh_utils -from jax.sharding import PositionalSharding -sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +devices = mesh_utils.create_device_mesh((4, 2)) +mesh = Mesh(devices, axis_names=('a', 'b')) ``` ```{code-cell} -:id: LnT0vWjc3vGQ -:outputId: 8089effc-aa4c-49e3-dd19-7064881dbad0 - -x = jax.device_put(x, sharding.reshape(4, 2)) +--- +colab: + base_uri: https://localhost:8080/ + height: 349 +id: LnT0vWjc3vGQ +outputId: 8e642049-61eb-458d-af79-ac449b58d11b +--- +x = jax.device_put(x, NamedSharding(mesh, P('a', 'b'))) print('input sharding:') jax.debug.visualize_array_sharding(x) @@ -459,11 +348,15 @@ In other words, even though we wrote the `jnp.sin` computation as if a single ma We can do the same for more than just elementwise operations too. 
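The resulting placement can also be checked programmatically, via the array's `.sharding` attribute, rather than only visually; a rough sketch, assuming the sharded `x` from the cell above:

```python
# Sketch: inspect where an elementwise result lives via its .sharding
# attribute. Assumes the sharded `x` defined in the cell above.
import jax
import jax.numpy as jnp

z = jnp.sin(x)
print('input sharding: ', x.sharding)
print('output sharding:', z.sharding)
jax.debug.visualize_array_sharding(z)
```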
Consider a matrix multiplication with sharded inputs: ```{code-cell} -:id: Dq043GkP3vGQ -:outputId: 350219a8-1e4a-4404-fe14-50f97ea3e7ba - -y = jax.device_put(x, sharding.reshape(4, 2).replicate(1)) -z = jax.device_put(x, sharding.reshape(4, 2).replicate(0)) +--- +colab: + base_uri: https://localhost:8080/ + height: 548 +id: Dq043GkP3vGQ +outputId: 3eff7b67-d7f0-4212-c9d3-2cc271ac1f98 +--- +y = jax.device_put(x, NamedSharding(mesh, P('a', None))) +z = jax.device_put(x, NamedSharding(mesh, P(None, 'b'))) print('lhs sharding:') jax.debug.visualize_array_sharding(y) print('rhs sharding:') @@ -481,32 +374,45 @@ Here the compiler chose the output sharding so that it could maximally paralleli How can we be sure it's actually running in parallel? We can do a simple timing experiment: ```{code-cell} -:id: QjQ5u8qh3vGQ -:outputId: bd29edcd-b87c-486e-c568-906f06ae16be - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: QjQ5u8qh3vGQ +outputId: 0aefc170-833c-4a6a-e003-5990d3db31d9 +--- x_single = jax.device_put(x, jax.devices()[0]) jax.debug.visualize_array_sharding(x_single) ``` ```{code-cell} -:id: 8tn8lOj73vGR -:outputId: 5809b3c8-7333-4cd3-db97-a7aede943dce - +--- +colab: + base_uri: https://localhost:8080/ +id: 8tn8lOj73vGR +outputId: d9898c93-7afc-416b-8c40-4d9551613cd0 +--- np.allclose(jnp.dot(x_single, x_single), jnp.dot(y, z)) ``` ```{code-cell} -:id: D7PpZwhR3vGR -:outputId: 4f0bd43d-0b32-4089-d3da-c8f1449e3526 - +--- +colab: + base_uri: https://localhost:8080/ +id: D7PpZwhR3vGR +outputId: 4901a11b-2354-4d26-a897-b88def07a716 +--- %timeit -n 5 -r 5 jnp.dot(x_single, x_single).block_until_ready() ``` ```{code-cell} -:id: rgo_yVHF3vGR -:outputId: 97f19052-f1c9-4d30-f453-07b3a7208aa9 - +--- +colab: + base_uri: https://localhost:8080/ +id: rgo_yVHF3vGR +outputId: e51216cf-b073-4250-d422-67f9fd72f6aa +--- %timeit -n 5 -r 5 jnp.dot(y, z).block_until_ready() ``` @@ -515,9 +421,13 @@ np.allclose(jnp.dot(x_single, x_single), Even copying a sharded `Array` produces a result with the sharding of the input: ```{code-cell} -:id: f1Zw-2lH3vGR -:outputId: a796bed4-07b0-497d-8fd8-31a22ab9762e - +--- +colab: + base_uri: https://localhost:8080/ + height: 166 +id: f1Zw-2lH3vGR +outputId: 43d7a642-fde4-47a6-901f-dfdc64d6a613 +--- w_copy = jnp.copy(w) jax.debug.visualize_array_sharding(w_copy) ``` @@ -540,35 +450,41 @@ import textwrap from termcolor import colored def print_exception(e): - name = colored(f'{type(e).__name__}', 'red') + name = colored(f'{type(e).__name__}', 'red', force_color=True) print(textwrap.fill(f'{name}: {str(e)}')) ``` ```{code-cell} -:id: DHh0N3vn3vGS -:outputId: e7741882-0ebf-4237-e5d1-e48c9b9c178f - -sharding1 = PositionalSharding(jax.devices()[:4]) -sharding2 = PositionalSharding(jax.devices()[4:]) +--- +colab: + base_uri: https://localhost:8080/ +id: DHh0N3vn3vGS +outputId: 8c4652f7-c484-423b-ad78-182134280187 +--- +sharding1 = NamedSharding(Mesh(jax.devices()[:4], 'x'), P('x')) +sharding2 = NamedSharding(Mesh(jax.devices()[4:], 'x'), P('x')) -y = jax.device_put(x, sharding1.reshape(2, 2)) -z = jax.device_put(x, sharding2.reshape(2, 2)) +y = jax.device_put(x, sharding1) +z = jax.device_put(x, sharding2) try: y + z except ValueError as e: print_exception(e) ``` ```{code-cell} -:id: Im7DkoOl3vGS -:outputId: 3adfe1cb-db52-4a9d-e98e-62c6455c3100 - +--- +colab: + base_uri: https://localhost:8080/ +id: Im7DkoOl3vGS +outputId: 1b6fcd7a-762b-4366-a96d-aea63bad7fe0 +--- devices = jax.devices() permuted_devices = [devices[i] for i in [0, 1, 2, 3, 6, 7, 4, 5]] -sharding1 = 
PositionalSharding(devices) -sharding2 = PositionalSharding(permuted_devices) +sharding1 = NamedSharding(Mesh(devices, 'x'), P('x')) +sharding2 = NamedSharding(Mesh(permuted_devices, 'x'), P('x')) -y = jax.device_put(x, sharding1.reshape(4, 2)) -z = jax.device_put(x, sharding2.reshape(4, 2)) +y = jax.device_put(x, sharding1) +z = jax.device_put(x, sharding2) try: y + z except ValueError as e: print_exception(e) ``` @@ -583,10 +499,13 @@ Unlike committed arrays, uncommitted arrays can be moved and resharded automatic For example, the output of `jnp.zeros`, `jnp.arange`, and `jnp.array` are uncommitted: ```{code-cell} -:id: _QvtKL8r3vGS -:outputId: e0078805-bdfd-436e-f94f-7cd256d2574f - -y = jax.device_put(x, sharding1.reshape(4, 2)) +--- +colab: + base_uri: https://localhost:8080/ +id: _QvtKL8r3vGS +outputId: 761b1208-fe4b-4c09-a7d2-f62152183ef0 +--- +y = jax.device_put(x, sharding1) y + jnp.ones_like(y) y + jnp.arange(y.size).reshape(y.shape) print('no error!') @@ -603,14 +522,14 @@ While the compiler will attempt to decide how a function's intermediate values a ```{code-cell} :id: jniSFm5V3vGT -sharding = PositionalSharding(mesh_utils.create_device_mesh((8,))) +mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('x', 'y')) ``` ```{code-cell} :id: Q1wuDp-L3vGT x = jax.random.normal(jax.random.key(0), (8192, 8192)) -x = jax.device_put(x, sharding.reshape(4, 2)) +x = jax.device_put(x, NamedSharding(mesh, P('x', 'y'))) ``` ```{code-cell} @@ -619,14 +538,18 @@ x = jax.device_put(x, sharding.reshape(4, 2)) @jax.jit def f(x): x = x + 1 - y = jax.lax.with_sharding_constraint(x, sharding.reshape(2, 4)) + y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('y', 'x'))) return y ``` ```{code-cell} -:id: zYFS-n4r3vGT -:outputId: d23a7938-cb7d-44b4-b9c7-83edf1d1145e - +--- +colab: + base_uri: https://localhost:8080/ + height: 347 +id: zYFS-n4r3vGT +outputId: 0ac96b8f-ed23-4413-aed9-edd00a841c37 +--- jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y) @@ -638,14 +561,18 @@ jax.debug.visualize_array_sharding(y) @jax.jit def f(x): x = x + 1 - y = jax.lax.with_sharding_constraint(x, sharding.replicate()) + y = jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P())) return y ``` ```{code-cell} -:id: AiRFtVsR3vGT -:outputId: f3e28a70-46cf-46fb-c801-82f0ddb447e4 - +--- +colab: + base_uri: https://localhost:8080/ + height: 347 +id: AiRFtVsR3vGT +outputId: 2edacc2c-ac80-4519-c9d1-bee364a22b31 +--- jax.debug.visualize_array_sharding(x) y = f(x) jax.debug.visualize_array_sharding(y) @@ -702,20 +629,20 @@ gradfun = jax.jit(jax.grad(loss)) :id: R0x62AIa3vGU def init_layer(key, n_in, n_out): - k1, k2 = jax.random.split(key) - W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) - b = jax.random.normal(k2, (n_out,)) - return W, b + k1, k2 = jax.random.split(key) + W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) + b = jax.random.normal(k2, (n_out,)) + return W, b def init_model(key, layer_sizes, batch_size): - key, *keys = jax.random.split(key, len(layer_sizes)) - params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) + key, *keys = jax.random.split(key, len(layer_sizes)) + params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) - key, *keys = jax.random.split(key, 3) - inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) - targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) + key, *keys = jax.random.split(key, 3) + inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) + targets = 
jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) - return params, (inputs, targets) + return params, (inputs, targets) layer_sizes = [784, 8192, 8192, 8192, 10] batch_size = 8192 @@ -727,30 +654,43 @@ params, batch = init_model(jax.random.key(0), layer_sizes, batch_size) ### 8-way batch data parallelism +```{code-cell} +:id: mJLqRPpSDX0i + +mesh = Mesh(mesh_utils.create_device_mesh((8,)), 'batch') +``` + ```{code-cell} :id: _Q5NbdOn3vGV -sharding = PositionalSharding(jax.devices()).reshape(8, 1) +sharding = NamedSharding(mesh, P('batch')) +replicated_sharding = NamedSharding(mesh, P()) ``` ```{code-cell} :id: 3KC6ieEe3vGV batch = jax.device_put(batch, sharding) -params = jax.device_put(params, sharding.replicate()) +params = jax.device_put(params, replicated_sharding) ``` ```{code-cell} -:id: MUb-QE2b3vGV -:outputId: 1f831ea5-5a30-49ad-8195-977ff7ed476a - +--- +colab: + base_uri: https://localhost:8080/ +id: MUb-QE2b3vGV +outputId: 5a27f007-c572-44f8-9f49-6e745ee739e8 +--- loss_jit(params, batch) ``` ```{code-cell} -:id: HUkw0u413vGV -:outputId: dfa2599c-9440-4657-9035-0dc3bbf625e1 - +--- +colab: + base_uri: https://localhost:8080/ +id: HUkw0u413vGV +outputId: 07e481a1-97fb-4bd0-d754-cb6d8317bff6 +--- step_size = 1e-5 for _ in range(30): @@ -762,9 +702,12 @@ print(loss_jit(params, batch)) ``` ```{code-cell} -:id: paCw6Zaj3vGV -:outputId: 8ab1c32c-f2b1-465c-df71-f5a599e7f19e - +--- +colab: + base_uri: https://localhost:8080/ +id: paCw6Zaj3vGV +outputId: ad4cce34-3a6a-4d44-9a86-477a7fee4841 +--- %timeit -n 5 -r 5 gradfun(params, batch)[0][0].block_until_ready() ``` @@ -776,9 +719,12 @@ params_single = jax.device_put(params, jax.devices()[0]) ``` ```{code-cell} -:id: Z1wgUKXk3vGV -:outputId: 74df8892-c349-41dc-cb1b-e0843ec5c994 - +--- +colab: + base_uri: https://localhost:8080/ +id: Z1wgUKXk3vGV +outputId: d66767b7-3f17-482f-b811-919bb1793277 +--- %timeit -n 5 -r 5 gradfun(params_single, batch_single)[0][0].block_until_ready() ``` @@ -787,58 +733,79 @@ params_single = jax.device_put(params, jax.devices()[0]) ### 4-way batch data parallelism and 2-way model tensor parallelism ```{code-cell} -:id: N5-zzgW03vGW +:id: k1hxOfgRDwo0 -sharding = sharding.reshape(4, 2) +mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model')) ``` ```{code-cell} -:id: sgIWCjJK3vGW -:outputId: b2fdc556-05cc-4e68-fa04-48643d194dee - -batch = jax.device_put(batch, sharding.replicate(1)) +--- +colab: + base_uri: https://localhost:8080/ + height: 314 +id: sgIWCjJK3vGW +outputId: 8cb0f19f-3942-415c-c57a-31bb81784f46 +--- +batch = jax.device_put(batch, NamedSharding(mesh, P('batch', None))) jax.debug.visualize_array_sharding(batch[0]) jax.debug.visualize_array_sharding(batch[1]) ``` +```{code-cell} +:id: q9PQP-0eEAO6 + +replicated_sharding = NamedSharding(mesh, P()) +``` + ```{code-cell} :id: BqCjYCgg3vGW (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params -W1 = jax.device_put(W1, sharding.replicate()) -b1 = jax.device_put(b1, sharding.replicate()) +W1 = jax.device_put(W1, replicated_sharding) +b1 = jax.device_put(b1, replicated_sharding) -W2 = jax.device_put(W2, sharding.replicate(0)) -b2 = jax.device_put(b2, sharding.replicate(0)) +W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model'))) +b2 = jax.device_put(b2, NamedSharding(mesh, P('model'))) -W3 = jax.device_put(W3, sharding.replicate(0).T) -b3 = jax.device_put(b3, sharding.replicate()) +W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None))) +b3 = jax.device_put(b3, replicated_sharding) -W4 = jax.device_put(W4, sharding.replicate()) -b4 
= jax.device_put(b4, sharding.replicate()) +W4 = jax.device_put(W4, replicated_sharding) +b4 = jax.device_put(b4, replicated_sharding) params = (W1, b1), (W2, b2), (W3, b3), (W4, b4) ``` ```{code-cell} -:id: _lSJ63sh3vGW -:outputId: 5b37aa8b-3226-4805-8282-876e8d06edda - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: _lSJ63sh3vGW +outputId: bcd3e33e-36b5-4787-9cd2-60623fd6e5fa +--- jax.debug.visualize_array_sharding(W2) ``` ```{code-cell} -:id: fxkfWYkk3vGW -:outputId: 8a1063c3-540b-47c1-d990-a6845da861f7 - +--- +colab: + base_uri: https://localhost:8080/ + height: 199 +id: fxkfWYkk3vGW +outputId: 59e60b16-fe37-47d4-8214-96096ffbd79c +--- jax.debug.visualize_array_sharding(W3) ``` ```{code-cell} -:id: uPCVs-_k3vGW -:outputId: de01cdfc-36cb-4823-c692-22c692ef4220 - +--- +colab: + base_uri: https://localhost:8080/ +id: uPCVs-_k3vGW +outputId: 618516e9-9736-4ca0-dd22-09d094ce57a2 +--- print(loss_jit(params, batch)) ``` @@ -854,25 +821,35 @@ for _ in range(30): ``` ```{code-cell} -:id: c9Sbl69e3vGX -:outputId: 8272c5fa-e59f-4953-c2d5-658c42a28712 - +--- +colab: + base_uri: https://localhost:8080/ +id: c9Sbl69e3vGX +outputId: 2ee3d432-7172-46ca-e01a-614e83345808 +--- print(loss_jit(params, batch)) ``` ```{code-cell} -:id: lkAF0dAb3vGX -:outputId: acf0df31-c5e1-4683-b73f-b0cd1b0929f8 - +--- +colab: + base_uri: https://localhost:8080/ + height: 380 +id: lkAF0dAb3vGX +outputId: 6c1e317e-cded-4af4-8080-0de835fa4c71 +--- (W1, b1), (W2, b2), (W3, b3), (W4, b4) = params jax.debug.visualize_array_sharding(W2) jax.debug.visualize_array_sharding(W3) ``` ```{code-cell} -:id: I1Npor3i3vGX -:outputId: 4099f6dd-7b46-4123-c1cb-5173c3d3278e - +--- +colab: + base_uri: https://localhost:8080/ +id: I1Npor3i3vGX +outputId: 479c4d81-cb0b-40a5-89ba-394c10dc3297 +--- %timeit -n 10 -r 10 gradfun(params, batch)[0][0].block_until_ready() ``` @@ -903,7 +880,8 @@ def f(key, x): return x + numbers key = jax.random.key(42) -x_sharding = jax.sharding.PositionalSharding(jax.devices()) +mesh = Mesh(jax.devices(), 'x') +x_sharding = NamedSharding(mesh, P('x')) x = jax.device_put(jnp.arange(24), x_sharding) ``` @@ -912,9 +890,13 @@ x = jax.device_put(jnp.arange(24), x_sharding) On a partitioned input, the function `f` produces output that is also partitioned: ```{code-cell} -:id: Oi97rpLz3vGY -:outputId: 204a7e8d-dc88-4b77-b7e3-0e72f306c5d3 - +--- +colab: + base_uri: https://localhost:8080/ + height: 67 +id: Oi97rpLz3vGY +outputId: 9dd63254-a483-4847-c0f5-5a4367bf08e9 +--- jax.debug.visualize_array_sharding(f(key, x)) ``` @@ -923,9 +905,12 @@ jax.debug.visualize_array_sharding(f(key, x)) But if we inspect the compiled computation for `f` on this partitioned input, we see that it does involve some communication: ```{code-cell} -:id: 64wIZuSJ3vGY -:outputId: 1054fe99-0476-44ec-9693-b0d8f98bf6a8 - +--- +colab: + base_uri: https://localhost:8080/ +id: 64wIZuSJ3vGY +outputId: fa166d45-ca9c-457a-be84-bcc9236d0730 +--- f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text()) ``` @@ -935,9 +920,12 @@ print('Communicating?', 'collective-permute' in f_exe.as_text()) One way to work around this is to configure JAX with the experimental upgrade flag `jax_threefry_partitionable`. 
With the flag on, the "collective permute" operation is now gone from the compiled computation: ```{code-cell} -:id: 1I7bqxA63vGY -:outputId: ec4c579d-f446-4b48-ceda-785c09ba299b - +--- +colab: + base_uri: https://localhost:8080/ +id: 1I7bqxA63vGY +outputId: 756e0a36-ff14-438f-bbd4-3ef03f97a47b +--- jax.config.update('jax_threefry_partitionable', True) f_exe = f.lower(key, x).compile() print('Communicating?', 'collective-permute' in f_exe.as_text()) @@ -948,9 +936,13 @@ print('Communicating?', 'collective-permute' in f_exe.as_text()) The output is still partitioned: ```{code-cell} -:id: zHPJzdn23vGY -:outputId: a8904d20-4d04-4f59-8eae-281e47d29246 - +--- +colab: + base_uri: https://localhost:8080/ + height: 67 +id: zHPJzdn23vGY +outputId: 3332de0f-4827-4f0b-b9ef-69249b7c6bc6 +--- jax.debug.visualize_array_sharding(f(key, x)) ``` @@ -959,9 +951,12 @@ jax.debug.visualize_array_sharding(f(key, x)) One caveat to the `jax_threefry_partitionable` option, however, is that _the random values produced may be different than without the flag set_, even though they were generated by the same random key: ```{code-cell} -:id: nBUHBBal3vGY -:outputId: f194c213-0688-4b7a-ffb8-c4453b82b1f1 - +--- +colab: + base_uri: https://localhost:8080/ +id: nBUHBBal3vGY +outputId: 4b9be948-ccab-4a31-a06f-37ec9c7b5235 +--- jax.config.update('jax_threefry_partitionable', False) print('Stable:') print(f(key, x)) diff --git a/docs/notebooks/How_JAX_primitives_work.ipynb b/docs/notebooks/How_JAX_primitives_work.ipynb deleted file mode 100644 index f42e3f74b4e3..000000000000 --- a/docs/notebooks/How_JAX_primitives_work.ipynb +++ /dev/null @@ -1,1532 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "vfxqky4PCUnh" - }, - "source": [ - "# How JAX primitives work\n", - "\n", - "\n", - "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb)\n", - "\n", - "*necula@google.com*, October 2019.\n", - "\n", - "JAX implements certain transformations of Python functions, e.g., `jit`, `grad`,\n", - "`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, \n", - "which means that as the Python function executes\n", - "the only operations it applies to the data are either inspections of data\n", - "attributes such as shape or type, or special operations called JAX primitives.\n", - "In particular, a JAX-traceable function is sometimes invoked by JAX with\n", - "abstract arguments. 
An example of a JAX abstract value is `ShapedArray(float32[2,2])`, \n", - "which captures the type and the shape of values, but not the concrete data values.\n", - "JAX primitives know how to operate on both concrete data\n", - "values and on the JAX abstract values.\n", - "\n", - "\n", - "The JAX-transformed functions must themselves be JAX-traceable functions,\n", - "to ensure that these transformations\n", - "can be composed, e.g., `jit(jacfwd(grad(f)))`.\n", - "\n", - "There are pre-defined JAX primitives corresponding to most XLA operations, \n", - "e.g., add, matmul, sin, cos, indexing.\n", - "JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs\n", - "using JAX’s implementation of numpy are JAX-traceable and therefore transformable.\n", - "Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives.\n", - "\n", - "The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives,\n", - "one can define a new primitive that encapsulates the behavior of the function.\n", - "\n", - "**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.**\n", - "\n", - "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n", - "as \"multiply_add(x, y, z) = x * y + z\". \n", - "This function operates on 3 identically-shaped tensors of floating point \n", - "values and performs the operations pointwise." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HIJYIHNTD1yI" - }, - "source": [ - "## Using existing primitives\n", - "\n", - "The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other\n", - "functions that are themselves written using JAX primitives, e.g., those \n", - "defined in the `jax.lax` module:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "tbOF0LB0EMne", - "outputId": "3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "square_add_lax = 14.0\n", - "grad(square_add_lax) = 4.0\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.6/dist-packages/jax/lib/xla_bridge.py:115: UserWarning: No GPU/TPU found, falling back to CPU.\n", - " warnings.warn('No GPU/TPU found, falling back to CPU.')\n" - ] - } - ], - "source": [ - "from jax import lax\n", - "from jax._src import api\n", - "\n", - "def multiply_add_lax(x, y, z):\n", - " \"\"\"Implementation of multiply-add using the jax.lax primitives.\"\"\"\n", - " return lax.add(lax.mul(x, y), z)\n", - "\n", - "\n", - "def square_add_lax(a, b):\n", - " \"\"\"A square-add function using the newly defined multiply-add.\"\"\"\n", - " return multiply_add_lax(a, a, b)\n", - "\n", - "print(\"square_add_lax = \", square_add_lax(2., 10.))\n", - "# Differentiate w.r.t. the first argument\n", - "print(\"grad(square_add_lax) = \", api.grad(square_add_lax, argnums=0)(2.0, 10.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Cgv60Wm3E_D5" - }, - "source": [ - "In order to understand how JAX is internally using the primitives,\n", - "we add some helpers for tracing function calls." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "mQRQGEGiE53K" - }, - "outputs": [], - "source": [ - "#@title Helper functions (execute this cell)\n", - "import functools\n", - "import traceback\n", - "\n", - "_indentation = 0\n", - "def _trace(msg=None):\n", - " \"\"\"Print a message at current indentation.\"\"\"\n", - " if msg is not None:\n", - " print(\" \" * _indentation + msg)\n", - "\n", - "def _trace_indent(msg=None):\n", - " \"\"\"Print a message and then indent the rest.\"\"\"\n", - " global _indentation\n", - " _trace(msg)\n", - " _indentation = 1 + _indentation\n", - "\n", - "def _trace_unindent(msg=None):\n", - " \"\"\"Unindent then print a message.\"\"\"\n", - " global _indentation\n", - " _indentation = _indentation - 1\n", - " _trace(msg)\n", - "\n", - "def trace(name):\n", - " \"\"\"A decorator for functions to trace arguments and results.\"\"\"\n", - "\n", - " def trace_func(func): # pylint: disable=missing-docstring\n", - " def pp(v):\n", - " \"\"\"Print certain values more succinctly\"\"\"\n", - " vtype = str(type(v))\n", - " if \"jax._src.xla_bridge._JaxComputationBuilder\" in vtype:\n", - " return \"\"\n", - " elif \"jaxlib.xla_extension.XlaOp\" in vtype:\n", - " return \"\".format(id(v))\n", - " elif (\"partial_eval.JaxprTracer\" in vtype or\n", - " \"batching.BatchTracer\" in vtype or\n", - " \"ad.JVPTracer\" in vtype):\n", - " return \"Traced<{}>\".format(v.aval)\n", - " elif isinstance(v, tuple):\n", - " return \"({})\".format(pp_values(v))\n", - " else:\n", - " return str(v)\n", - " def pp_values(args):\n", - " return \", \".join([pp(arg) for arg in args])\n", - " \n", - " @functools.wraps(func)\n", - " def func_wrapper(*args):\n", - " _trace_indent(\"call {}({})\".format(name, pp_values(args)))\n", - " res = func(*args)\n", - " _trace_unindent(\"|<- {} = {}\".format(name, pp(res)))\n", - " return res\n", - "\n", - " return func_wrapper\n", - "\n", - " return trace_func\n", - "\n", - "class expectNotImplementedError(object):\n", - " \"\"\"Context manager to check for NotImplementedError.\"\"\"\n", - " def __enter__(self): pass\n", - " def __exit__(self, type, value, tb):\n", - " global _indentation\n", - " _indentation = 0\n", - " if type is NotImplementedError:\n", - " print(\"\\nFound expected exception:\")\n", - " traceback.print_exc(limit=3)\n", - " return True\n", - " elif type is None: # No exception\n", - " assert False, \"Expected NotImplementedError\"\n", - " else:\n", - " return False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Qf4eLrLCFYDl" - }, - "source": [ - "Instead of using `jax.lax` primitives directly, we can use other functions \n", - "that are already written in terms of those primitives, such as those in `jax.numpy`:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "QhKorz6cFRJb", - "outputId": "aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "Normal evaluation:\n", - "call square_add_numpy(2.0, 10.0)\n", - " call multiply_add_numpy(2.0, 2.0, 10.0)\n", - " |<- multiply_add_numpy = 14.0\n", - "|<- square_add_numpy = 14.0\n", - "square_add_numpy = 14.0\n", - "\n", - "Gradient evaluation:\n", - "call square_add_numpy(Traced, 10.0)\n", - " call multiply_add_numpy(Traced, Traced, 10.0)\n", - " |<- multiply_add_numpy = Traced\n", - "|<- square_add_numpy = Traced\n", - "grad(square_add_numpy) = 4.0\n" - ] - } - ], - "source": [ - "import jax.numpy as 
jnp\n", - "import numpy as np\n", - "\n", - "@trace(\"multiply_add_numpy\")\n", - "def multiply_add_numpy(x, y, z):\n", - " return jnp.add(jnp.multiply(x, y), z)\n", - "\n", - "@trace(\"square_add_numpy\")\n", - "def square_add_numpy(a, b):\n", - " return multiply_add_numpy(a, a, b)\n", - "\n", - "print(\"\\nNormal evaluation:\") \n", - "print(\"square_add_numpy = \", square_add_numpy(2., 10.))\n", - "print(\"\\nGradient evaluation:\")\n", - "print(\"grad(square_add_numpy) = \", api.grad(square_add_numpy)(2.0, 10.))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sg-D8EdeFn4a" - }, - "source": [ - "Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and\n", - "`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further \n", - "below in this colab). \n", - "It is important to remember that a JAX-traceable function must be able to \n", - "operate not only on concrete arguments but also on special abstract arguments\n", - "that JAX may use to abstract the function execution.\n", - "\n", - "The JAX traceability property is satisfied as long as the function is written \n", - "in terms of JAX primitives." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WxrQO7-XGLcg" - }, - "source": [ - "## Defining new JAX primitives\n", - "\n", - "The right way to add support for multiply-add is in terms of existing\n", - "JAX primitives, as shown above. However, in order to demonstrate how JAX\n", - "primitives work let us pretend that we want to add a new primitive to \n", - "JAX for the multiply-add functionality." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cPqAH1XOGTN4" - }, - "outputs": [], - "source": [ - "from jax import core\n", - "multiply_add_p = core.Primitive(\"multiply_add\") # Create the primitive\n", - "\n", - "@trace(\"multiply_add_prim\")\n", - "def multiply_add_prim(x, y, z):\n", - " \"\"\"The JAX-traceable way to use the JAX primitive.\n", - " \n", - " Note that the traced arguments must be passed as positional arguments\n", - " to `bind`. \n", - " \"\"\"\n", - " return multiply_add_p.bind(x, y, z)\n", - "\n", - "@trace(\"square_add_prim\")\n", - "def square_add_prim(a, b):\n", - " \"\"\"A square-add function implemented using the new JAX-primitive.\"\"\"\n", - " return multiply_add_prim(a, a, b)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LMzs5PAKGr-4" - }, - "source": [ - "If we try to call the newly defined functions we get an error, because\n", - "we have not yet told JAX anything about the semantics of the new primitive." 
- ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "_X3PAYxhGpWd", - "outputId": "90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(2.0, 10.0)\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " square_add_prim(2., 10.)\n", - " File \"\", line 47, in func_wrapper\n", - " res = func(*args)\n", - " File \"\", line 16, in square_add_prim\n", - " return multiply_add_prim(a, a, b)\n", - "NotImplementedError: Evaluation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " square_add_prim(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "elha0FdgHSEF" - }, - "source": [ - "### Primal evaluation rules" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "FT34FFAGHARU", - "outputId": "4c54f1c2-8a50-4788-90e1-06aee412c43b" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 6, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "@trace(\"multiply_add_impl\")\n", - "def multiply_add_impl(x, y, z):\n", - " \"\"\"Concrete implementation of the primitive.\n", - "\n", - " This function does not need to be JAX traceable.\n", - " Args:\n", - " x, y, z: the concrete arguments of the primitive. Will only be called with \n", - " concrete values.\n", - " Returns:\n", - " the concrete result of the primitive.\n", - " \"\"\"\n", - " # Note that we can use the original numpy, which is not JAX traceable\n", - " return np.add(np.multiply(x, y), z)\n", - "\n", - "# Now we register the primal implementation with JAX\n", - "multiply_add_p.def_impl(multiply_add_impl)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "G5bstKaeNAVV", - "outputId": "deb94d5b-dfea-4e6f-9ec2-70b416c996c5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(2.0, 10.0)\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - "|<- square_add_prim = 14.0\n" - ] - } - ], - "source": [ - "assert square_add_prim(2., 10.) == 14." 
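(An aside added in review, not text from the notebook being deleted here.) Since `multiply_add_impl` is plain NumPy, registering it with `def_impl` is already enough for eager evaluation on arrays of any matching shape, not just scalars; the batching section later in this file relies on exactly this. A minimal check, assuming the definitions above are in scope:

```python
import numpy as np

# Eager evaluation dispatches straight to multiply_add_impl, which applies
# np.multiply / np.add elementwise, so array arguments need no extra rules.
assert np.allclose(
    multiply_add_prim(np.array([2., 3.]), np.array([2., 3.]), np.array([10., 20.])),
    [14., 29.])
```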
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "upBf-uAuHhPJ" - }, - "source": [ - "### JIT\n", - "\n", - "If we now try to use `jit` we get a `NotImplementedError`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "QG-LULjiHk4b", - "outputId": "d4ef4406-8dae-4c96-97ca-b662340474ee" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jit(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n", - " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n", - " outs = primitive.impl(f, *args, **params)\n", - "NotImplementedError: Abstract evaluation for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " api.jit(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rHS1bAGHH44E" - }, - "source": [ - "#### Abstract evaluation rules\n", - "In order to JIT the function, and for other transformations as well, \n", - "JAX first evaluates it abstractly using only the \n", - "shape and type of the arguments. This abstract evaluation serves multiple\n", - "purposes:\n", - "\n", - " * Gets the sequence of JAX primitives that are used in the computation. This \n", - " sequence will be compiled. \n", - " * Computes the shape and type of all vectors and operations used in the computation. \n", - "\n", - "\n", - "For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. \n", - "In the latter case, JAX uses the actual concrete value wrapped as an abstract value." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "ctQmEeckIbdo", - "outputId": "e751d0cc-460e-4ffd-df2e-fdabf9cffdc2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 9, - "metadata": { - "tags": [] - }, - "output_type": "execute_result" - } - ], - "source": [ - "from jax import core\n", - "@trace(\"multiply_add_abstract_eval\")\n", - "def multiply_add_abstract_eval(xs, ys, zs):\n", - " \"\"\"Abstract evaluation of the primitive.\n", - "\n", - " This function does not need to be JAX traceable. It will be invoked with\n", - " abstractions of the actual arguments. 
\n", - " Args:\n", - " xs, ys, zs: abstractions of the arguments.\n", - " Result:\n", - " a ShapedArray for the result of the primitive.\n", - " \"\"\"\n", - " assert xs.shape == ys.shape\n", - " assert xs.shape == zs.shape\n", - " return core.ShapedArray(xs.shape, xs.dtype)\n", - "\n", - "# Now we register the abstract evaluation with JAX\n", - "multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "RPN88X6YI43A" - }, - "source": [ - "If we re-attempt to JIT, we see how the abstract evaluation proceeds, but\n", - "we get another error, about missing the actual XLA compilation rule:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "eOcNR92SI2h-", - "outputId": "356ef229-3703-4696-cc3d-7c05de405fb0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jit(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 149, in f_jitted\n", - " out = xla.xla_call(flat_fun, *args_flat, device_assignment=device_assignment, backend=backend)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/core.py\", line 569, in call_bind\n", - " outs = primitive.impl(f, *args, **params)\n", - "NotImplementedError: XLA translation rule for primitive 'multiply_add' not found\n" - ] - } - ], - "source": [ - "with expectNotImplementedError():\n", - " api.jit(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9IOV1R-fJMHp" - }, - "source": [ - "#### XLA Compilation rules\n", - "\n", - "JAX compilation works by compiling each primitive into a graph of XLA operations.\n", - "\n", - "This is the biggest hurdle to adding new functionality to JAX, because the \n", - "set of XLA operations is limited, and JAX already has pre-defined primitives\n", - "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FYQWSSjKJaWP" - }, - "outputs": [], - "source": [ - "from jax._src.lib.mlir.dialects import hlo\n", - "@trace(\"multiply_add_lowering\")\n", - "def multiply_add_lowering(ctx, xc, yc, zc):\n", - " \"\"\"The compilation to XLA of the primitive.\n", - "\n", - " Given an mlir.ir.Value for each argument, return the mlir.ir.Values for\n", - " the results of the function.\n", - "\n", - " Does not need to be a JAX-traceable function.\n", - " \"\"\"\n", - " return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]\n", - "\n", - "# Now we register the lowering rule with JAX\n", - "# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)\n", - "# TODO: TPU?\n", - "from jax.interpreters import mlir\n", - "mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "K98LX-VaJkFu" - }, - "source": [ - "Now we succeed to JIT. Notice below that JAX first evaluates the function\n", - "abstractly, which triggers the `multiply_add_abstract_eval` function, and \n", - "then compiles the set of primitives it has encountered, including `multiply_add`.\n", - "At this point JAX invokes `multiply_add_xla_translation`." - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "rj3TLsolJgEc", - "outputId": "e384bee4-1e9c-4344-f49c-d3b5ec08eb32" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Omrez-2_KFfo" - }, - "source": [ - "Below is another use of `jit` where we compile only\n", - "with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads\n", - "in the third argument to `multiply_add_abstract_eval` being \n", - "`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with\n", - "both `ShapedArray` and `ConcreteArray`." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "mPfTwIBoKOEK", - "outputId": "b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(10.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda x, y: square_add_prim(x, y), \n", - " static_argnums=1)(2., 10.) == 14." 
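(A further aside added in review.) Once the lowering rule is in place, the `lower(...)` / `as_text()` inspection used earlier in this patch can show what the new primitive becomes. A sketch assuming a CPU backend, since `multiply_add_lowering` above was registered only with `platform='cpu'`:

```python
import jax

# Lowering exercises multiply_add_lowering; in the emitted module the custom
# primitive should appear only as ordinary add/multiply operations.
lowered = jax.jit(square_add_prim).lower(2., 10.)
print(lowered.as_text())
```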
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_Ya3B5l4J1VA" - }, - "source": [ - "### Forward differentiation\n", - "\n", - "JAX implements forward differentiation in the form of\n", - "a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)).\n", - "\n", - "If we attempt now to compute the `jvp` function we get an\n", - "error because we have not yet told JAX how to differentiate\n", - "the `multiply_add` primitive." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "OxDx6NQnKwMI", - "outputId": "ce659ef3-c03c-4856-f252-49ec4b6eb964" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 217, in process_primitive\n", - " jvp = primitive_jvps[primitive]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.jvp(square_add_prim, (2., 10.), (1., 1.))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 978, in jvp\n", - " out_primals, out_tangents = ad.jvp(flat_fun).call_wrapped(ps_flat, ts_flat)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/linear_util.py\", line 165, in call_wrapped\n", - " ans = self.f(*args, **dict(self.params, **kwargs))\n", - "NotImplementedError: Forward-mode differentiation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# The second argument `(2., 10.)` are the argument values\n", - "# where we evaluate the Jacobian, and the third `(1., 1.)`\n", - "# are the values of the tangents for the arguments.\n", - "with expectNotImplementedError():\n", - " api.jvp(square_add_prim, (2., 10.), (1., 1.))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zxG24C1JMIMM" - }, - "outputs": [], - "source": [ - "from jax.interpreters import ad\n", - "\n", - "\n", - "@trace(\"multiply_add_value_and_jvp\")\n", - "def multiply_add_value_and_jvp(arg_values, arg_tangents):\n", - " \"\"\"Evaluates the primal output and the tangents (Jacobian-vector product).\n", - "\n", - " Given values of the arguments and perturbation of the arguments (tangents), \n", - " compute the output of the primitive and the perturbation of the output.\n", - "\n", - " This method must be JAX-traceable. JAX may invoke it with abstract values \n", - " for the arguments and tangents.\n", - "\n", - " Args:\n", - " arg_values: a tuple of arguments\n", - " arg_tangents: a tuple with the tangents of the arguments. The tuple has \n", - " the same length as the arg_values. Some of the tangents may also be the \n", - " special value ad.Zero to specify a zero tangent.\n", - " Returns:\n", - " a pair of the primal output and the tangent.\n", - " \"\"\"\n", - " x, y, z = arg_values\n", - " xt, yt, zt = arg_tangents\n", - " _trace(\"Primal evaluation:\")\n", - " # Now we have a JAX-traceable computation of the output. \n", - " # Normally, we can use the ma primitive itself to compute the primal output. 
\n", - " primal_out = multiply_add_prim(x, y, z)\n", - " \n", - " _trace(\"Tangent evaluation:\")\n", - " # We must use a JAX-traceable way to compute the tangent. It turns out that \n", - " # the output tangent can be computed as (xt * y + x * yt + zt),\n", - " # which we can implement in a JAX-traceable way using the same \"multiply_add_prim\" primitive.\n", - " \n", - " # We do need to deal specially with Zero. Here we just turn it into a \n", - " # proper tensor of 0s (of the same shape as 'x'). \n", - " # An alternative would be to check for Zero and perform algebraic \n", - " # simplification of the output tangent computation.\n", - " def make_zero(tan):\n", - " return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan \n", - " \n", - " output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt)))\n", - " return (primal_out, output_tangent)\n", - "\n", - "# Register the forward differentiation rule with JAX \n", - "ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "ma3KBkiAMfW1", - "outputId": "f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((2.0, 2.0, 10.0), (1.0, 1.0, 1.0))\n", - " Primal evaluation:\n", - " call multiply_add_prim(2.0, 2.0, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(2.0, 1.0, 1.0)\n", - " call multiply_add_impl(2.0, 1.0, 1.0)\n", - " |<- multiply_add_impl = 3.0\n", - " |<- multiply_add_prim = 3.0\n", - " call multiply_add_prim(1.0, 2.0, 3.0)\n", - " call multiply_add_impl(1.0, 2.0, 3.0)\n", - " |<- multiply_add_impl = 5.0\n", - " |<- multiply_add_prim = 5.0\n", - " |<- multiply_add_value_and_jvp = (14.0, 5.0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n" - ] - } - ], - "source": [ - "# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5.\n", - "assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69QsEcu-lP4u" - }, - "source": [ - "TO EXPLAIN: \n", - "\n", - " * Why is JAX using ConcreteArray in square_add_prim? 
There is no abstract evaluation going on here.\n", - " * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet\n", - " we do not call the multiply_add_abstract_eval.\n", - " * I think it would be useful to show the jaxpr here" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Sb6e3ZAHOPHv" - }, - "source": [ - "#### JIT of forward differentiation\n", - "\n", - "We can apply JIT to the forward differentiation function:" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "hg-hzVu-N-hv", - "outputId": "38d32067-e152-4046-ad80-7f95a31ba628" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Traced))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (Traced, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(lambda arg_values, arg_tangents: \n", - " api.jvp(square_add_prim, arg_values, arg_tangents))(\n", - " (2., 10.), (1., 1.)) == (14., 5.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jlZt1_v2mU88" - }, - "source": [ - "Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn\n", - "evaluates abstractly both the primal and the tangent evaluation (a total of \n", - "3 invocations of the `ma` primitive). Then we compile the 3 occurrences\n", - "of the primitive." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "555yt6ZIOePB" - }, - "source": [ - "### Reverse differentiation\n", - "\n", - "If we attempt now to use reverse differentiation we\n", - "see that JAX starts by using the `multiply_add_value_and_jvp` to \n", - "compute the forward differentiation for abstract values, but then runs\n", - "into a `NotImplementedError`. \n", - "\n", - "When computing the reverse differentiation JAX first does abstract evaluation\n", - "of the forward differentiation code `multiply_add_value_and_jvp` to obtain a \n", - "trace of primitives that compute the output tangent. \n", - "Observe that JAX performs this abstract evaluation with concrete values\n", - "for the differentiation point, and abstract values for the tangents. 
\n", - "Observe also that JAX uses the special abstract tangent value `Zero` for\n", - "the tangent corresponding to the 3rd argument of `ma`. This reflects the \n", - "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n", - "which flows to the 3rd argument to `multiply_add_prim`.\n", - "\n", - "Observe also that during the abstract evaluation of the tangent we pass the \n", - "value 0.0 as the tangent for the 3rd argument. This is due to the use\n", - "of the `make_zero` function in the definition of `multiply_add_value_and_jvp`." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "8eAVnexaOjBn", - "outputId": "e4ee89cf-ab4a-4505-9817-fa978a2865ab" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, 0.0)\n", - " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (14.0, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/ad.py\", line 198, in get_primitive_transpose\n", - " return primitive_transposes[p]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 2, in \n", - " api.grad(square_add_prim)(2., 10.)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 340, in grad_f\n", - " _, g = value_and_grad_f(*args, **kwargs)\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 398, in value_and_grad_f\n", - " g = vjp_py(np.ones((), dtype=dtype))\n", - "NotImplementedError: Reverse-mode differentiation rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# This is reverse differentiation w.r.t. the first argument of square_add_prim\n", - "with expectNotImplementedError():\n", - " api.grad(square_add_prim)(2., 10.)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fSHLUMDN26AY" - }, - "source": [ - "The above error is because there is a missing piece for JAX to be able\n", - "to use the forward differentiation code to compute reverse differentiation." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3ibDbGF-PjK9" - }, - "source": [ - "#### Transposition\n", - "\n", - "\n", - "As explained above, when computing reverse differentiation JAX obtains\n", - "a trace of primitives that compute the tangent using forward differentiation.\n", - "Then, **JAX interprets this trace abstractly backwards** and for each \n", - "primitive it applies a **transposition** rule.\n", - "\n", - "To understand what is going on, consider for now a simpler example of the function \"f(x, y) = x * y + y\". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`:\n", - "```\n", - " a = xt * 4.\n", - " b = 2. * yt\n", - " c = a + b\n", - " ft = c + yt\n", - "```\n", - "\n", - "By construction, the tangent calculation is always linear in the input tangents. \n", - "The only non-linear operator that may arise in the tangent calculation is multiplication,\n", - "but then one of the operands is constant.\n", - "\n", - "JAX will produce the reverse differentiation computation by processing the\n", - "JVP computation backwards. For each operation in the tangent computation,\n", - "it accumulates the cotangents\n", - "of the variables used by the operation, using the cotangent of the result\n", - "of the operation:\n", - "```\n", - " # Initialize cotangents of inputs and intermediate vars\n", - " xct = yct = act = bct = cct = 0.\n", - " # Initialize cotangent of the output\n", - " fct = 1.\n", - " # Process \"ft = c + yt\"\n", - " cct += fct\n", - " yct += fct\n", - " # Process \"c = a + b\"\n", - " act += cct\n", - " bct += cct\n", - " # Process \"b = 2. * yt\"\n", - " yct += 2. * bct\n", - " # Process \"a = xt * 4.\"\n", - " xct += act * 4.\n", - "```\n", - "\n", - "One can verify that this computation produces `xct = 4.` and `yct = 3.`, which \n", - "are the partial derivatives of the function `f`. \n", - "\n", - "JAX knows for each primitive that may appear in a JVP calculation how to transpose it. Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is:\n", - "```\n", - "p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz)\n", - "```\n", - "\n", - "Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other\n", - "arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned \n", - "for the constant arguments.\n", - "\n", - "In particular, \n", - "```\n", - " add_transpose(out_ct, _, _) = (out_ct, out_ct)\n", - " mult_transpose(out_ct, x, _) = (None, x * out_ct)\n", - " mult_transpose(out_ct, _, y) = (out_ct * y, None)\n", - "```" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "JaHxFdkRO42r" - }, - "outputs": [], - "source": [ - "@trace(\"multiply_add_transpose\")\n", - "def multiply_add_transpose(ct, x, y, z):\n", - " \"\"\"Evaluates the transpose of a linear primitive.\n", - "\n", - " This method is only used when computing the backward gradient following \n", - " value_and_jvp, and is only needed for primitives that are used in the JVP \n", - " calculation for some other primitive. 
We need transposition for multiply_add_prim, \n", - " because we have used multiply_add_prim in the computation of the output_tangent in \n", - " multiply_add_value_and_jvp.\n", - "\n", - " In our case, multiply_add is not a linear primitive. However, it is used linearly \n", - " w.r.t. tangents in multiply_add_value_and_jvp:\n", - " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n", - " \n", - " Always one of the first two multiplicative arguments is a constant.\n", - "\n", - " Args:\n", - " ct: the cotangent of the output of the primitive.\n", - " x, y, z: values of the arguments. The arguments that are used linearly\n", - " get an ad.UndefinedPrimal value. The other arguments get a constant\n", - " value.\n", - " Returns:\n", - " a tuple with the cotangent of the inputs, with the value None\n", - " corresponding to the constant arguments.\n", - " \"\"\"\n", - " if not ad.is_undefined_primal(x):\n", - " # This use of multiply_add is with a constant \"x\"\n", - " assert ad.is_undefined_primal(y)\n", - " ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x))\n", - " res = None, ct_y, ct\n", - " else:\n", - " # This use of multiply_add is with a constant \"y\"\n", - " assert ad.is_undefined_primal(x)\n", - " ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y))\n", - " res = ct_x, None, ct\n", - " return res\n", - "\n", - "\n", - "ad.primitive_transposes[multiply_add_p] = multiply_add_transpose" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PpChox-Jp7wb" - }, - "source": [ - "Now we can complete the run of the `grad`:" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": { - "id": "PogPKS4MPevd", - "outputId": "d33328d4-3e87-45b5-9b31-21ad624b67af" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, 10.0)\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_value_and_jvp((Traced, Traced, 10.0), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, 10.0)\n", - " call multiply_add_impl(2.0, 2.0, 10.0)\n", - " |<- multiply_add_impl = 14.0\n", - " |<- multiply_add_prim = 14.0\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, 0.0)\n", - " call multiply_add_abstract_eval(ConcreteArray(2.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(2.0), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (14.0, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_transpose(1.0, _, 2.0, _)\n", - " call multiply_add_prim(1.0, 2.0, 0.0)\n", - " call multiply_add_impl(1.0, 2.0, 0.0)\n", - " |<- multiply_add_impl = 2.0\n", - " |<- multiply_add_prim = 2.0\n", - "|<- multiply_add_transpose = (2.0, None, 1.0)\n", - "call multiply_add_transpose(1.0, 2.0, _, 0.0)\n", - " call multiply_add_prim(2.0, 1.0, 0.0)\n", - " call multiply_add_impl(2.0, 1.0, 0.0)\n", - " |<- multiply_add_impl = 2.0\n", - " |<- multiply_add_prim = 2.0\n", - "|<- multiply_add_transpose = (None, 2.0, 1.0)\n" - ] - } - ], - 
"source": [ - "assert api.grad(square_add_prim)(2., 10.) == 4." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8M1xLCXW4fK7" - }, - "source": [ - "Notice the two calls to `multiply_add_transpose`. They correspond to the two\n", - "uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the \n", - "last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EIJs6FYmPg6c" - }, - "source": [ - "#### JIT of reverse differentiation \n", - "\n", - "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n", - "abstract values, while in the absence of JIT we used `ConcreteArray`." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": { - "id": "FZ-JGbWZPq2-", - "outputId": "e42b5222-9c3e-4853-e13a-874f6605d178" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_value_and_jvp((Traced, Traced, Traced), (Traced, Traced, Zero))\n", - " Primal evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " Tangent evaluation:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ShapedArray(float32[]), ShapedArray(float32[]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_value_and_jvp = (Traced, Traced)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_transpose(1.0, _, Traced, _)\n", - " call multiply_add_prim(1.0, Traced, Traced)\n", - " call multiply_add_abstract_eval(ConcreteArray(1.0), ShapedArray(float32[]), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- multiply_add_transpose = (Traced, None, 1.0)\n", - "call multiply_add_transpose(1.0, Traced, _, Traced)\n", - " call multiply_add_prim(Traced, 1.0, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[]), ConcreteArray(1.0), ConcreteArray(0.0))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[])\n", - " |<- multiply_add_prim = Traced\n", - "|<- multiply_add_transpose = (None, Traced, 1.0)\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert api.jit(api.grad(square_add_prim))(2., 10.) == 4." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-3lqPkdQPvl5" - }, - "source": [ - "### Batching\n", - "\n", - "The batching transformation takes a point-wise computation and turns it\n", - "into a computation on vectors. 
If we try it right now, we get a `NotImplementedError`:" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "hFvBR3I9Pzh3", - "outputId": "434608bc-281f-4d3b-83bd-eaaf3b51b1cd" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - "\n", - "Found expected exception:\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Traceback (most recent call last):\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 163, in get_primitive_batcher\n", - " return primitive_batchers[p]\n", - "KeyError: multiply_add\n", - "\n", - "During handling of the above exception, another exception occurred:\n", - "\n", - "Traceback (most recent call last):\n", - " File \"\", line 3, in \n", - " np.array([10., 20.]))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/api.py\", line 611, in batched_fun\n", - " lambda: _flatten_axes(out_tree(), out_axes))\n", - " File \"/usr/local/lib/python3.6/dist-packages/jax/interpreters/batching.py\", line 41, in batch\n", - " out_vals, out_dims = batch2(fun, in_vals, in_dims)\n", - "NotImplementedError: Batching rule for 'multiply_add' not implemented\n" - ] - } - ], - "source": [ - "# The arguments are two vectors instead of two scalars\n", - "with expectNotImplementedError():\n", - " api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]),\n", - " np.array([10., 20.]))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gILasMiP6elR" - }, - "source": [ - "We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "KQfeqRIrP7zg" - }, - "outputs": [], - "source": [ - "from jax.interpreters import batching\n", - "\n", - "\n", - "@trace(\"multiply_add_batch\")\n", - "def multiply_add_batch(vector_arg_values, batch_axes):\n", - " \"\"\"Computes the batched version of the primitive.\n", - " \n", - " This must be a JAX-traceable function.\n", - " \n", - " Since the multiply_add primitive already operates pointwise on arbitrary\n", - " dimension tensors, to batch it we can use the primitive itself. This works as\n", - " long as both the inputs have the same dimensions and are batched along the\n", - " same axes. The result is batched along the axis that the inputs are batched.\n", - " \n", - " Args:\n", - " vector_arg_values: a tuple of two arguments, each being a tensor of matching\n", - " shape.\n", - " batch_axes: the axes that are being batched. See vmap documentation.\n", - " Returns:\n", - " a tuple of the result, and the result axis that was batched. 
\n", - " \"\"\"\n", - " assert batch_axes[0] == batch_axes[1]\n", - " assert batch_axes[0] == batch_axes[2]\n", - " _trace(\"Using multiply_add to compute the batch:\")\n", - " res = multiply_add_prim(*vector_arg_values)\n", - " return res, batch_axes[0]\n", - "\n", - "\n", - "batching.primitive_batchers[multiply_add_p] = multiply_add_batch" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "id": "VwxNk869P_YG", - "outputId": "9d22c921-5803-4d33-9e88-b6e439ba9738" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_batch(([2. 3.], [2. 3.], [10. 20.]), (0, 0, 0))\n", - " Using multiply_add to compute the batch:\n", - " call multiply_add_prim([2. 3.], [2. 3.], [10. 20.])\n", - " call multiply_add_impl([2. 3.], [2. 3.], [10. 20.])\n", - " |<- multiply_add_impl = [14. 29.]\n", - " |<- multiply_add_prim = [14. 29.]\n", - " |<- multiply_add_batch = ([14. 29.], 0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n" - ] - } - ], - "source": [ - "assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)(\n", - " np.array([2., 3.]),\n", - " np.array([10., 20.])),\n", - " [14., 29.])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NmqLlV1TQDCC" - }, - "source": [ - "#### JIT of batching" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "id": "xqEdXVUgQCTt", - "outputId": "9c22fd9c-919c-491d-bbeb-32c241b808fa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "call square_add_prim(Traced, Traced)\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_batch((Traced, Traced, Traced), (0, 0, 0))\n", - " Using multiply_add to compute the batch:\n", - " call multiply_add_prim(Traced, Traced, Traced)\n", - " call multiply_add_abstract_eval(ShapedArray(float32[2]), ShapedArray(float32[2]), ShapedArray(float32[2]))\n", - " |<- multiply_add_abstract_eval = ShapedArray(float32[2])\n", - " |<- multiply_add_prim = Traced\n", - " |<- multiply_add_batch = (Traced, 0)\n", - " |<- multiply_add_prim = Traced\n", - "|<- square_add_prim = Traced\n", - "call multiply_add_xla_translation(, , , )\n", - "|<- multiply_add_xla_translation = \n" - ] - } - ], - "source": [ - "assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0))\n", - " (np.array([2., 3.]),\n", - " np.array([10., 20.])),\n", - " [14., 29.])" - ] - } - ], - "metadata": { - "colab": { - "collapsed_sections": [], - "name": "How JAX primitives work.ipynb", - "provenance": [], - "toc_visible": true - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.6" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/How_JAX_primitives_work.md b/docs/notebooks/How_JAX_primitives_work.md deleted file mode 100644 index 0ebf202f2258..000000000000 --- a/docs/notebooks/How_JAX_primitives_work.md +++ /dev/null @@ -1,771 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.1 -kernelspec: - 
display_name: Python 3 - name: python3 ---- - -+++ {"id": "vfxqky4PCUnh"} - -# How JAX primitives work - - - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/How_JAX_primitives_work.ipynb) - -*necula@google.com*, October 2019. - -JAX implements certain transformations of Python functions, e.g., `jit`, `grad`, -`vmap`, or `pmap`. The Python functions to be transformed must be JAX-traceable, -which means that as the Python function executes -the only operations it applies to the data are either inspections of data -attributes such as shape or type, or special operations called JAX primitives. -In particular, a JAX-traceable function is sometimes invoked by JAX with -abstract arguments. An example of a JAX abstract value is `ShapedArray(float32[2,2])`, -which captures the type and the shape of values, but not the concrete data values. -JAX primitives know how to operate on both concrete data -values and on the JAX abstract values. - - -The JAX-transformed functions must themselves be JAX-traceable functions, -to ensure that these transformations -can be composed, e.g., `jit(jacfwd(grad(f)))`. - -There are pre-defined JAX primitives corresponding to most XLA operations, -e.g., add, matmul, sin, cos, indexing. -JAX comes with an implementation of numpy functions in terms of JAX primitives, which means that Python programs -using JAX’s implementation of numpy are JAX-traceable and therefore transformable. -Other libraries can be made JAX-traceable by implementing them in terms of JAX primitives. - -The set of JAX primitives is extensible. Instead of reimplementing a function in terms of pre-defined JAX primitives, -one can define a new primitive that encapsulates the behavior of the function. - -**The goal of this document is to explain the interface that a JAX primitive must support in order to allow JAX to perform all its transformations.** - -Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically -as "multiply_add(x, y, z) = x * y + z". -This function operates on 3 identically-shaped tensors of floating point -values and performs the operations pointwise. - -+++ {"id": "HIJYIHNTD1yI"} - -## Using existing primitives - -The easiest way to define new functions is to write them in terms of JAX primitives, or in terms of other -functions that are themselves written using JAX primitives, e.g., those -defined in the `jax.lax` module: - -```{code-cell} ipython3 -:id: tbOF0LB0EMne -:outputId: 3fb1c8a7-7a4c-4a3a-f7ff-37b7dc740528 - -from jax import lax -from jax._src import api - -def multiply_add_lax(x, y, z): - """Implementation of multiply-add using the jax.lax primitives.""" - return lax.add(lax.mul(x, y), z) - - -def square_add_lax(a, b): - """A square-add function using the newly defined multiply-add.""" - return multiply_add_lax(a, a, b) - -print("square_add_lax = ", square_add_lax(2., 10.)) -# Differentiate w.r.t. the first argument -print("grad(square_add_lax) = ", api.grad(square_add_lax, argnums=0)(2.0, 10.)) -``` - -+++ {"id": "Cgv60Wm3E_D5"} - -In order to understand how JAX is internally using the primitives, -we add some helpers for tracing function calls. 
- -```{code-cell} ipython3 -:cellView: form -:id: mQRQGEGiE53K - -#@title Helper functions (execute this cell) -import functools -import traceback - -_indentation = 0 -def _trace(msg=None): - """Print a message at current indentation.""" - if msg is not None: - print(" " * _indentation + msg) - -def _trace_indent(msg=None): - """Print a message and then indent the rest.""" - global _indentation - _trace(msg) - _indentation = 1 + _indentation - -def _trace_unindent(msg=None): - """Unindent then print a message.""" - global _indentation - _indentation = _indentation - 1 - _trace(msg) - -def trace(name): - """A decorator for functions to trace arguments and results.""" - - def trace_func(func): # pylint: disable=missing-docstring - def pp(v): - """Print certain values more succinctly""" - vtype = str(type(v)) - if "jax._src.xla_bridge._JaxComputationBuilder" in vtype: - return "" - elif "jaxlib.xla_extension.XlaOp" in vtype: - return "".format(id(v)) - elif ("partial_eval.JaxprTracer" in vtype or - "batching.BatchTracer" in vtype or - "ad.JVPTracer" in vtype): - return "Traced<{}>".format(v.aval) - elif isinstance(v, tuple): - return "({})".format(pp_values(v)) - else: - return str(v) - def pp_values(args): - return ", ".join([pp(arg) for arg in args]) - - @functools.wraps(func) - def func_wrapper(*args): - _trace_indent("call {}({})".format(name, pp_values(args))) - res = func(*args) - _trace_unindent("|<- {} = {}".format(name, pp(res))) - return res - - return func_wrapper - - return trace_func - -class expectNotImplementedError(object): - """Context manager to check for NotImplementedError.""" - def __enter__(self): pass - def __exit__(self, type, value, tb): - global _indentation - _indentation = 0 - if type is NotImplementedError: - print("\nFound expected exception:") - traceback.print_exc(limit=3) - return True - elif type is None: # No exception - assert False, "Expected NotImplementedError" - else: - return False -``` - -+++ {"id": "Qf4eLrLCFYDl"} - -Instead of using `jax.lax` primitives directly, we can use other functions -that are already written in terms of those primitives, such as those in `jax.numpy`: - -```{code-cell} ipython3 -:id: QhKorz6cFRJb -:outputId: aba3cef3-6bcc-4eb3-c7b3-34e405f2f82a - -import jax.numpy as jnp -import numpy as np - -@trace("multiply_add_numpy") -def multiply_add_numpy(x, y, z): - return jnp.add(jnp.multiply(x, y), z) - -@trace("square_add_numpy") -def square_add_numpy(a, b): - return multiply_add_numpy(a, a, b) - -print("\nNormal evaluation:") -print("square_add_numpy = ", square_add_numpy(2., 10.)) -print("\nGradient evaluation:") -print("grad(square_add_numpy) = ", api.grad(square_add_numpy)(2.0, 10.)) -``` - -+++ {"id": "Sg-D8EdeFn4a"} - -Notice that in the process of computing `grad`, JAX invokes `square_add_numpy` and -`multiply_add_numpy` with special arguments `ConcreteArray(...)` (described further -below in this colab). -It is important to remember that a JAX-traceable function must be able to -operate not only on concrete arguments but also on special abstract arguments -that JAX may use to abstract the function execution. - -The JAX traceability property is satisfied as long as the function is written -in terms of JAX primitives. - -+++ {"id": "WxrQO7-XGLcg"} - -## Defining new JAX primitives - -The right way to add support for multiply-add is in terms of existing -JAX primitives, as shown above. 
However, in order to demonstrate how JAX -primitives work let us pretend that we want to add a new primitive to -JAX for the multiply-add functionality. - -```{code-cell} ipython3 -:id: cPqAH1XOGTN4 - -from jax import core -multiply_add_p = core.Primitive("multiply_add") # Create the primitive - -@trace("multiply_add_prim") -def multiply_add_prim(x, y, z): - """The JAX-traceable way to use the JAX primitive. - - Note that the traced arguments must be passed as positional arguments - to `bind`. - """ - return multiply_add_p.bind(x, y, z) - -@trace("square_add_prim") -def square_add_prim(a, b): - """A square-add function implemented using the new JAX-primitive.""" - return multiply_add_prim(a, a, b) -``` - -+++ {"id": "LMzs5PAKGr-4"} - -If we try to call the newly defined functions we get an error, because -we have not yet told JAX anything about the semantics of the new primitive. - -```{code-cell} ipython3 -:id: _X3PAYxhGpWd -:outputId: 90ea2c6a-9ef3-40ea-e9a3-3ab1cfc59fc8 - -with expectNotImplementedError(): - square_add_prim(2., 10.) -``` - -+++ {"id": "elha0FdgHSEF"} - -### Primal evaluation rules - -```{code-cell} ipython3 -:id: FT34FFAGHARU -:outputId: 4c54f1c2-8a50-4788-90e1-06aee412c43b - -@trace("multiply_add_impl") -def multiply_add_impl(x, y, z): - """Concrete implementation of the primitive. - - This function does not need to be JAX traceable. - Args: - x, y, z: the concrete arguments of the primitive. Will only be called with - concrete values. - Returns: - the concrete result of the primitive. - """ - # Note that we can use the original numpy, which is not JAX traceable - return np.add(np.multiply(x, y), z) - -# Now we register the primal implementation with JAX -multiply_add_p.def_impl(multiply_add_impl) -``` - -```{code-cell} ipython3 -:id: G5bstKaeNAVV -:outputId: deb94d5b-dfea-4e6f-9ec2-70b416c996c5 - -assert square_add_prim(2., 10.) == 14. -``` - -+++ {"id": "upBf-uAuHhPJ"} - -### JIT - -If we now try to use `jit` we get a `NotImplementedError`: - -```{code-cell} ipython3 -:id: QG-LULjiHk4b -:outputId: d4ef4406-8dae-4c96-97ca-b662340474ee - -with expectNotImplementedError(): - api.jit(square_add_prim)(2., 10.) -``` - -+++ {"id": "rHS1bAGHH44E"} - -#### Abstract evaluation rules -In order to JIT the function, and for other transformations as well, -JAX first evaluates it abstractly using only the -shape and type of the arguments. This abstract evaluation serves multiple -purposes: - - * Gets the sequence of JAX primitives that are used in the computation. This - sequence will be compiled. - * Computes the shape and type of all vectors and operations used in the computation. - - -For example, the abstraction of a vector with 3 elements may be `ShapedArray(float32[3])`, or `ConcreteArray([1., 2., 3.])`. -In the latter case, JAX uses the actual concrete value wrapped as an abstract value. - -```{code-cell} ipython3 -:id: ctQmEeckIbdo -:outputId: e751d0cc-460e-4ffd-df2e-fdabf9cffdc2 - -from jax import core -@trace("multiply_add_abstract_eval") -def multiply_add_abstract_eval(xs, ys, zs): - """Abstract evaluation of the primitive. - - This function does not need to be JAX traceable. It will be invoked with - abstractions of the actual arguments. - Args: - xs, ys, zs: abstractions of the arguments. - Result: - a ShapedArray for the result of the primitive. 
- """ - assert xs.shape == ys.shape - assert xs.shape == zs.shape - return core.ShapedArray(xs.shape, xs.dtype) - -# Now we register the abstract evaluation with JAX -multiply_add_p.def_abstract_eval(multiply_add_abstract_eval) -``` - -+++ {"id": "RPN88X6YI43A"} - -If we re-attempt to JIT, we see how the abstract evaluation proceeds, but -we get another error, about missing the actual XLA compilation rule: - -```{code-cell} ipython3 -:id: eOcNR92SI2h- -:outputId: 356ef229-3703-4696-cc3d-7c05de405fb0 - -with expectNotImplementedError(): - api.jit(square_add_prim)(2., 10.) -``` - -+++ {"id": "9IOV1R-fJMHp"} - -#### XLA Compilation rules - -JAX compilation works by compiling each primitive into a graph of XLA operations. - -This is the biggest hurdle to adding new functionality to JAX, because the -set of XLA operations is limited, and JAX already has pre-defined primitives -for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++. - -```{code-cell} ipython3 -:id: FYQWSSjKJaWP - -from jax._src.lib.mlir.dialects import hlo -@trace("multiply_add_lowering") -def multiply_add_lowering(ctx, xc, yc, zc): - """The compilation to XLA of the primitive. - - Given an mlir.ir.Value for each argument, return the mlir.ir.Values for - the results of the function. - - Does not need to be a JAX-traceable function. - """ - return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result] - -# Now we register the lowering rule with JAX -# For GPU see the [Custom operations for GPUs](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html) -# TODO: TPU? -from jax.interpreters import mlir -mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu') -``` - -+++ {"id": "K98LX-VaJkFu"} - -Now we succeed to JIT. Notice below that JAX first evaluates the function -abstractly, which triggers the `multiply_add_abstract_eval` function, and -then compiles the set of primitives it has encountered, including `multiply_add`. -At this point JAX invokes `multiply_add_xla_translation`. - -```{code-cell} ipython3 -:id: rj3TLsolJgEc -:outputId: e384bee4-1e9c-4344-f49c-d3b5ec08eb32 - -assert api.jit(lambda x, y: square_add_prim(x, y))(2., 10.) == 14. -``` - -+++ {"id": "Omrez-2_KFfo"} - -Below is another use of `jit` where we compile only -with respect to the first argument. Notice how the second argument to `square_add_prim` is concrete, which leads -in the third argument to `multiply_add_abstract_eval` being -`ConcreteArray`. We see that `multiply_add_abstract_eval` may be used with -both `ShapedArray` and `ConcreteArray`. - -```{code-cell} ipython3 -:id: mPfTwIBoKOEK -:outputId: b293b9b6-a2f9-48f5-f7eb-d4f99c3d905b - -assert api.jit(lambda x, y: square_add_prim(x, y), - static_argnums=1)(2., 10.) == 14. -``` - -+++ {"id": "_Ya3B5l4J1VA"} - -### Forward differentiation - -JAX implements forward differentiation in the form of -a Jacobian-vector product (see the [JAX autodiff cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Jacobian-Matrix-and-Matrix-Jacobian-products)). - -If we attempt now to compute the `jvp` function we get an -error because we have not yet told JAX how to differentiate -the `multiply_add` primitive. - -```{code-cell} ipython3 -:id: OxDx6NQnKwMI -:outputId: ce659ef3-c03c-4856-f252-49ec4b6eb964 - -# The second argument `(2., 10.)` are the argument values -# where we evaluate the Jacobian, and the third `(1., 1.)` -# are the values of the tangents for the arguments. 
-with expectNotImplementedError(): - api.jvp(square_add_prim, (2., 10.), (1., 1.)) -``` - -```{code-cell} ipython3 -:id: zxG24C1JMIMM - -from jax.interpreters import ad - - -@trace("multiply_add_value_and_jvp") -def multiply_add_value_and_jvp(arg_values, arg_tangents): - """Evaluates the primal output and the tangents (Jacobian-vector product). - - Given values of the arguments and perturbation of the arguments (tangents), - compute the output of the primitive and the perturbation of the output. - - This method must be JAX-traceable. JAX may invoke it with abstract values - for the arguments and tangents. - - Args: - arg_values: a tuple of arguments - arg_tangents: a tuple with the tangents of the arguments. The tuple has - the same length as the arg_values. Some of the tangents may also be the - special value ad.Zero to specify a zero tangent. - Returns: - a pair of the primal output and the tangent. - """ - x, y, z = arg_values - xt, yt, zt = arg_tangents - _trace("Primal evaluation:") - # Now we have a JAX-traceable computation of the output. - # Normally, we can use the ma primitive itself to compute the primal output. - primal_out = multiply_add_prim(x, y, z) - - _trace("Tangent evaluation:") - # We must use a JAX-traceable way to compute the tangent. It turns out that - # the output tangent can be computed as (xt * y + x * yt + zt), - # which we can implement in a JAX-traceable way using the same "multiply_add_prim" primitive. - - # We do need to deal specially with Zero. Here we just turn it into a - # proper tensor of 0s (of the same shape as 'x'). - # An alternative would be to check for Zero and perform algebraic - # simplification of the output tangent computation. - def make_zero(tan): - return lax.zeros_like_array(x) if type(tan) is ad.Zero else tan - - output_tangent = multiply_add_prim(make_zero(xt), y, multiply_add_prim(x, make_zero(yt), make_zero(zt))) - return (primal_out, output_tangent) - -# Register the forward differentiation rule with JAX -ad.primitive_jvps[multiply_add_p] = multiply_add_value_and_jvp -``` - -```{code-cell} ipython3 -:id: ma3KBkiAMfW1 -:outputId: f34cbbc6-20d9-48ca-9a9a-b5d91a972cdd - -# Tangent is: xt*y + x*yt + zt = 1.*2. + 2.*1. + 1. = 5. -assert api.jvp(square_add_prim, (2., 10.), (1., 1.)) == (14., 5.) -``` - -+++ {"id": "69QsEcu-lP4u"} - -TO EXPLAIN: - - * Why is JAX using ConcreteArray in square_add_prim? There is no abstract evaluation going on here. - * Not sure how to explain that multiply_add_prim is invoked with ConcreteValue, yet - we do not call the multiply_add_abstract_eval. - * I think it would be useful to show the jaxpr here - -+++ {"id": "Sb6e3ZAHOPHv"} - -#### JIT of forward differentiation - -We can apply JIT to the forward differentiation function: - -```{code-cell} ipython3 -:id: hg-hzVu-N-hv -:outputId: 38d32067-e152-4046-ad80-7f95a31ba628 - -assert api.jit(lambda arg_values, arg_tangents: - api.jvp(square_add_prim, arg_values, arg_tangents))( - (2., 10.), (1., 1.)) == (14., 5.) -``` - -+++ {"id": "jlZt1_v2mU88"} - -Notice that first we evaluate `multiply_add_value_and_jvp` abstractly, which in turn -evaluates abstractly both the primal and the tangent evaluation (a total of -3 invocations of the `ma` primitive). Then we compile the 3 occurrences -of the primitive. 
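The passage above describes how `jit` of the JVP first builds an abstract trace of primitive applications (three uses of the `ma` primitive) and then compiles them. Below is a rough, standalone sketch of how such a trace can be inspected with `jax.make_jaxpr`; it uses a stand-in `multiply_add` built from `jax.lax` ops rather than the notebook's custom `multiply_add_p`, so the primitive names in the printed jaxpr will differ:

```python
# Standalone sketch: inspect the primitives traced for a JVP. The stand-in
# multiply_add below is built from jax.lax ops, not the custom multiply_add_p
# primitive defined in the notebook.
import jax
from jax import lax

def multiply_add(x, y, z):
  return lax.add(lax.mul(x, y), z)

def square_add(a, b):
  return multiply_add(a, a, b)

# The jaxpr of the JVP lists the primal and tangent equations that jit
# would go on to compile.
print(jax.make_jaxpr(
    lambda primals, tangents: jax.jvp(square_add, primals, tangents))(
    (2., 10.), (1., 1.)))
```

Printing the jaxpr makes it easy to count how many times each primitive appears before compilation.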
- -+++ {"id": "555yt6ZIOePB"} - -### Reverse differentiation - -If we attempt now to use reverse differentiation we -see that JAX starts by using the `multiply_add_value_and_jvp` to -compute the forward differentiation for abstract values, but then runs -into a `NotImplementedError`. - -When computing the reverse differentiation JAX first does abstract evaluation -of the forward differentiation code `multiply_add_value_and_jvp` to obtain a -trace of primitives that compute the output tangent. -Observe that JAX performs this abstract evaluation with concrete values -for the differentiation point, and abstract values for the tangents. -Observe also that JAX uses the special abstract tangent value `Zero` for -the tangent corresponding to the 3rd argument of `ma`. This reflects the -fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`, -which flows to the 3rd argument to `multiply_add_prim`. - -Observe also that during the abstract evaluation of the tangent we pass the -value 0.0 as the tangent for the 3rd argument. This is due to the use -of the `make_zero` function in the definition of `multiply_add_value_and_jvp`. - -```{code-cell} ipython3 -:id: 8eAVnexaOjBn -:outputId: e4ee89cf-ab4a-4505-9817-fa978a2865ab - -# This is reverse differentiation w.r.t. the first argument of square_add_prim -with expectNotImplementedError(): - api.grad(square_add_prim)(2., 10.) -``` - -+++ {"id": "fSHLUMDN26AY"} - -The above error is because there is a missing piece for JAX to be able -to use the forward differentiation code to compute reverse differentiation. - -+++ {"id": "3ibDbGF-PjK9"} - -#### Transposition - - -As explained above, when computing reverse differentiation JAX obtains -a trace of primitives that compute the tangent using forward differentiation. -Then, **JAX interprets this trace abstractly backwards** and for each -primitive it applies a **transposition** rule. - -To understand what is going on, consider for now a simpler example of the function "f(x, y) = x * y + y". Assume we need to differentiate at the point `(2., 4.)`. JAX will produce the following JVP tangent calculation of `ft` from the tangents of the input `xt` and `yt`: -``` - a = xt * 4. - b = 2. * yt - c = a + b - ft = c + yt -``` - -By construction, the tangent calculation is always linear in the input tangents. -The only non-linear operator that may arise in the tangent calculation is multiplication, -but then one of the operands is constant. - -JAX will produce the reverse differentiation computation by processing the -JVP computation backwards. For each operation in the tangent computation, -it accumulates the cotangents -of the variables used by the operation, using the cotangent of the result -of the operation: -``` - # Initialize cotangents of inputs and intermediate vars - xct = yct = act = bct = cct = 0. - # Initialize cotangent of the output - fct = 1. - # Process "ft = c + yt" - cct += fct - yct += fct - # Process "c = a + b" - act += cct - bct += cct - # Process "b = 2. * yt" - yct += 2. * bct - # Process "a = xt * 4." - xct += act * 4. -``` - -One can verify that this computation produces `xct = 4.` and `yct = 3.`, which -are the partial derivatives of the function `f`. - -JAX knows for each primitive that may appear in a JVP calculation how to transpose it. 
Conceptually, if the primitive `p(x, y, z)` is linear in the arguments `y` and `z` for a constant value of `x`, e.g., `p(x, y, z) = y*cy + z*cz`, then the transposition of the primitive is: -``` -p_transpose(out_ct, x, _, _) = (None, out_ct*cy, out_ct*cz) -``` - -Notice that `p_transpose` takes the cotangent of the output of the primitive and a value corresponding to each argument of the primitive. For the linear arguments, the transposition gets an undefined `_` value, and for the other -arguments it gets the actual constants. The transposition returns a cotangent value for each argument of the primitive, with the value `None` returned -for the constant arguments. - -In particular, -``` - add_transpose(out_ct, _, _) = (out_ct, out_ct) - mult_transpose(out_ct, x, _) = (None, x * out_ct) - mult_transpose(out_ct, _, y) = (out_ct * y, None) -``` - -```{code-cell} ipython3 -:id: JaHxFdkRO42r - -@trace("multiply_add_transpose") -def multiply_add_transpose(ct, x, y, z): - """Evaluates the transpose of a linear primitive. - - This method is only used when computing the backward gradient following - value_and_jvp, and is only needed for primitives that are used in the JVP - calculation for some other primitive. We need transposition for multiply_add_prim, - because we have used multiply_add_prim in the computation of the output_tangent in - multiply_add_value_and_jvp. - - In our case, multiply_add is not a linear primitive. However, it is used linearly - w.r.t. tangents in multiply_add_value_and_jvp: - output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt)) - - Always one of the first two multiplicative arguments is a constant. - - Args: - ct: the cotangent of the output of the primitive. - x, y, z: values of the arguments. The arguments that are used linearly - get an ad.UndefinedPrimal value. The other arguments get a constant - value. - Returns: - a tuple with the cotangent of the inputs, with the value None - corresponding to the constant arguments. - """ - if not ad.is_undefined_primal(x): - # This use of multiply_add is with a constant "x" - assert ad.is_undefined_primal(y) - ct_y = ad.Zero(y.aval) if type(ct) is ad.Zero else multiply_add_prim(x, ct, lax.zeros_like_array(x)) - res = None, ct_y, ct - else: - # This use of multiply_add is with a constant "y" - assert ad.is_undefined_primal(x) - ct_x = ad.Zero(x.aval) if type(ct) is ad.Zero else multiply_add_prim(ct, y, lax.zeros_like_array(y)) - res = ct_x, None, ct - return res - - -ad.primitive_transposes[multiply_add_p] = multiply_add_transpose -``` - -+++ {"id": "PpChox-Jp7wb"} - -Now we can complete the run of the `grad`: - -```{code-cell} ipython3 -:id: PogPKS4MPevd -:outputId: d33328d4-3e87-45b5-9b31-21ad624b67af - -assert api.grad(square_add_prim)(2., 10.) == 4. -``` - -+++ {"id": "8M1xLCXW4fK7"} - -Notice the two calls to `multiply_add_transpose`. They correspond to the two -uses of `multiply_add_prim` in the computation of the `output_tangent` in `multiply_add_value_and_jvp`. The first call to transpose corresponds to the -last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is the constant 2.0. - -+++ {"id": "EIJs6FYmPg6c"} - -#### JIT of reverse differentiation - -Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only -abstract values, while in the absence of JIT we used `ConcreteArray`. - -```{code-cell} ipython3 -:id: FZ-JGbWZPq2- -:outputId: e42b5222-9c3e-4853-e13a-874f6605d178 - -assert api.jit(api.grad(square_add_prim))(2., 10.) == 4. 
-``` - -+++ {"id": "-3lqPkdQPvl5"} - -### Batching - -The batching transformation takes a point-wise computation and turns it -into a computation on vectors. If we try it right now, we get a `NotImplementedError`: - -```{code-cell} ipython3 -:id: hFvBR3I9Pzh3 -:outputId: 434608bc-281f-4d3b-83bd-eaaf3b51b1cd - -# The arguments are two vectors instead of two scalars -with expectNotImplementedError(): - api.vmap(square_add_prim, in_axes=0, out_axes=0)(np.array([2., 3.]), - np.array([10., 20.])) -``` - -+++ {"id": "gILasMiP6elR"} - -We need to tell JAX how to evaluate the batched version of the primitive. In this particular case, the `multiply_add_prim` already operates pointwise for any dimension of input vectors. So the batched version can use the same `multiply_add_prim` implementation. - -```{code-cell} ipython3 -:id: KQfeqRIrP7zg - -from jax.interpreters import batching - - -@trace("multiply_add_batch") -def multiply_add_batch(vector_arg_values, batch_axes): - """Computes the batched version of the primitive. - - This must be a JAX-traceable function. - - Since the multiply_add primitive already operates pointwise on arbitrary - dimension tensors, to batch it we can use the primitive itself. This works as - long as both the inputs have the same dimensions and are batched along the - same axes. The result is batched along the axis that the inputs are batched. - - Args: - vector_arg_values: a tuple of two arguments, each being a tensor of matching - shape. - batch_axes: the axes that are being batched. See vmap documentation. - Returns: - a tuple of the result, and the result axis that was batched. - """ - assert batch_axes[0] == batch_axes[1] - assert batch_axes[0] == batch_axes[2] - _trace("Using multiply_add to compute the batch:") - res = multiply_add_prim(*vector_arg_values) - return res, batch_axes[0] - - -batching.primitive_batchers[multiply_add_p] = multiply_add_batch -``` - -```{code-cell} ipython3 -:id: VwxNk869P_YG -:outputId: 9d22c921-5803-4d33-9e88-b6e439ba9738 - -assert np.allclose(api.vmap(square_add_prim, in_axes=0, out_axes=0)( - np.array([2., 3.]), - np.array([10., 20.])), - [14., 29.]) -``` - -+++ {"id": "NmqLlV1TQDCC"} - -#### JIT of batching - -```{code-cell} ipython3 -:id: xqEdXVUgQCTt -:outputId: 9c22fd9c-919c-491d-bbeb-32c241b808fa - -assert np.allclose(api.jit(api.vmap(square_add_prim, in_axes=0, out_axes=0)) - (np.array([2., 3.]), - np.array([10., 20.])), - [14., 29.]) -``` diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb index f0c157655790..a7ef2a017048 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.ipynb +++ b/docs/notebooks/Neural_Network_and_Data_Loading.ipynb @@ -6,11 +6,11 @@ "id": "18AF5Ab4p6VL" }, "source": [ - "# Training a Simple Neural Network, with PyTorch Data Loading\n", + "# Training a simple neural network, with PyTorch data loading\n", "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in 
Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)\n", "\n", "**Copyright 2018 The JAX Authors.**\n", "\n", @@ -32,9 +32,9 @@ "id": "B_XlLLpcWjkA" }, "source": [ - "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", + "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", - "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", + "Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n", "\n", "Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model." ] @@ -119,7 +119,7 @@ " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", " activations = relu(outputs)\n", - " \n", + "\n", " final_w, final_b = params[-1]\n", " logits = jnp.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" @@ -238,7 +238,7 @@ "def one_hot(x, k, dtype=jnp.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - " \n", + "\n", "def accuracy(params, images, targets):\n", " target_class = jnp.argmax(targets, axis=1)\n", " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", @@ -261,7 +261,7 @@ "id": "umJJGZCC2oKl" }, "source": [ - "## Data Loading with PyTorch\n", + "## Data loading with PyTorch\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays." 
] @@ -494,7 +494,7 @@ "id": "xxPd6Qw3Z98v" }, "source": [ - "## Training Loop" + "## Training loop" ] }, { diff --git a/docs/notebooks/Neural_Network_and_Data_Loading.md b/docs/notebooks/Neural_Network_and_Data_Loading.md index 2c53bb1e4ab5..cd98022e7421 100644 --- a/docs/notebooks/Neural_Network_and_Data_Loading.md +++ b/docs/notebooks/Neural_Network_and_Data_Loading.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -14,11 +14,11 @@ kernelspec: +++ {"id": "18AF5Ab4p6VL"} -# Training a Simple Neural Network, with PyTorch Data Loading +# Training a simple neural network, with PyTorch data loading -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) **Copyright 2018 The JAX Authors.** @@ -35,9 +35,9 @@ limitations under the License. +++ {"id": "B_XlLLpcWjkA"} -![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) +![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) -Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). +Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library). Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model. 
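The introduction above mentions loading MNIST batches through PyTorch's data loading API. Here is a minimal sketch of the kind of shim that makes a PyTorch `DataLoader` yield NumPy arrays for JAX, assuming `torch` is installed; the names `numpy_collate` and `NumpyLoader` are illustrative, and the notebook's own definitions (which sit outside the hunks shown here) may differ:

```python
# Illustrative shim: a collate function so torch.utils.data.DataLoader
# returns NumPy arrays instead of torch Tensors.
import numpy as np
from torch.utils.data import DataLoader

def numpy_collate(batch):
  """Recursively stack a batch of samples into NumPy arrays."""
  if isinstance(batch[0], np.ndarray):
    return np.stack(batch)
  elif isinstance(batch[0], (tuple, list)):
    return [numpy_collate(samples) for samples in zip(*batch)]
  else:
    return np.array(batch)

class NumpyLoader(DataLoader):
  """A DataLoader that yields NumPy arrays, ready to feed into jnp code."""
  def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
    super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
                     collate_fn=numpy_collate, **kwargs)
```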
@@ -96,7 +96,7 @@ def predict(params, image): for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b activations = relu(outputs) - + final_w, final_b = params[-1] logits = jnp.dot(final_w, activations) + final_b return logits - logsumexp(logits) @@ -156,7 +156,7 @@ At this point, we have all the ingredients we need to define our neural network def one_hot(x, k, dtype=jnp.float32): """Create a one-hot encoding of x of size k.""" return jnp.array(x[:, None] == jnp.arange(k), dtype) - + def accuracy(params, images, targets): target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(batched_predict(params, images), axis=1) @@ -175,7 +175,7 @@ def update(params, x, y): +++ {"id": "umJJGZCC2oKl"} -## Data Loading with PyTorch +## Data loading with PyTorch JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll grab PyTorch's data loader, and make a tiny shim to make it work with NumPy arrays. @@ -245,7 +245,7 @@ test_labels = one_hot(np.array(mnist_dataset_test.test_labels), n_targets) +++ {"id": "xxPd6Qw3Z98v"} -## Training Loop +## Training loop ```{code-cell} ipython3 :id: X2DnZo3iYj18 diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb index 7e65aefe359c..00ba9186eeec 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb @@ -10,7 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb)" ] }, { @@ -35,7 +35,6 @@ }, "outputs": [], "source": [ - "import numpy as np\n", "import jax\n", "import jax.numpy as jnp\n", "from jax import jit, grad, vmap\n", @@ -80,7 +79,7 @@ "id": "gA8V51wZdsjh" }, "source": [ - "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/google/jax#how-it-works) section in the README." + "When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the [\"How it works\"](https://github.com/jax-ml/jax#how-it-works) section in the README." 
] }, { @@ -214,7 +213,6 @@ "outputs": [], "source": [ "# Importing Jax functions useful for tracing/interpreting.\n", - "import numpy as np\n", "from functools import wraps\n", "\n", "from jax import core\n", @@ -273,7 +271,7 @@ "def eval_jaxpr(jaxpr, consts, *args):\n", " # Mapping from variable -> value\n", " env = {}\n", - " \n", + "\n", " def read(var):\n", " # Literals are values baked into the Jaxpr\n", " if type(var) is core.Literal:\n", @@ -290,16 +288,16 @@ " # Loop through equations and evaluate primitives using `bind`\n", " for eqn in jaxpr.eqns:\n", " # Read inputs to equation from environment\n", - " invals = safe_map(read, eqn.invars) \n", + " invals = safe_map(read, eqn.invars)\n", " # `bind` is how a primitive is called\n", " outvals = eqn.primitive.bind(*invals, **eqn.params)\n", " # Primitives may return multiple outputs or not\n", - " if not eqn.primitive.multiple_results: \n", + " if not eqn.primitive.multiple_results:\n", " outvals = [outvals]\n", " # Write the results of the primitive into the environment\n", - " safe_map(write, eqn.outvars, outvals) \n", + " safe_map(write, eqn.outvars, outvals)\n", " # Read the final result of the Jaxpr from the environment\n", - " return safe_map(read, jaxpr.outvars) " + " return safe_map(read, jaxpr.outvars)" ] }, { @@ -322,7 +320,7 @@ "source": [ "Notice that `eval_jaxpr` will always return a flat list even if the original function does not.\n", "\n", - "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." + "Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover." ] }, { @@ -335,7 +333,7 @@ "\n", "An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n", "\n", - "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)." + "It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234)." 
] }, { @@ -417,7 +415,7 @@ "source": [ "def inverse_jaxpr(jaxpr, consts, *args):\n", " env = {}\n", - " \n", + "\n", " def read(var):\n", " if type(var) is core.Literal:\n", " return var.val\n", @@ -431,12 +429,12 @@ "\n", " # Looping backward\n", " for eqn in jaxpr.eqns[::-1]:\n", - " # outvars are now invars \n", + " # outvars are now invars\n", " invals = safe_map(read, eqn.outvars)\n", " if eqn.primitive not in inverse_registry:\n", " raise NotImplementedError(\n", " f\"{eqn.primitive} does not have registered inverse.\")\n", - " # Assuming a unary function \n", + " # Assuming a unary function\n", " outval = inverse_registry[eqn.primitive](*invals)\n", " safe_map(write, eqn.invars, [outval])\n", " return safe_map(read, jaxpr.invars)" diff --git a/docs/notebooks/Writing_custom_interpreters_in_Jax.md b/docs/notebooks/Writing_custom_interpreters_in_Jax.md index e52c6a5f8742..10c4e7cb6e3b 100644 --- a/docs/notebooks/Writing_custom_interpreters_in_Jax.md +++ b/docs/notebooks/Writing_custom_interpreters_in_Jax.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -18,7 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/Writing_custom_interpreters_in_Jax.ipynb) +++ {"id": "r-3vMiKRYXPJ"} @@ -32,7 +32,6 @@ Here we show how to add your own function transformations to the system, by writ ```{code-cell} ipython3 :id: s27RDKvKXFL8 -import numpy as np import jax import jax.numpy as jnp from jax import jit, grad, vmap @@ -58,7 +57,7 @@ fast_f = jit(f) +++ {"id": "gA8V51wZdsjh"} -When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/google/jax#how-it-works) section in the README. +When we call `fast_f`, what happens? JAX traces the function and constructs an XLA computation graph. The graph is then JIT-compiled and executed. Other transformations work similarly in that they first trace the function and handle the output trace in some way. To learn more about Jax's tracing machinery, you can refer to the ["How it works"](https://github.com/jax-ml/jax#how-it-works) section in the README. +++ {"id": "2Th1vYLVaFBz"} @@ -146,7 +145,6 @@ Let's use `make_jaxpr` to trace a function into a Jaxpr. :id: BHkg_3P1pXJj # Importing Jax functions useful for tracing/interpreting. 
-import numpy as np from functools import wraps from jax import core @@ -185,7 +183,7 @@ To do this, we first create an environment to store the values for each of the v def eval_jaxpr(jaxpr, consts, *args): # Mapping from variable -> value env = {} - + def read(var): # Literals are values baked into the Jaxpr if type(var) is core.Literal: @@ -202,16 +200,16 @@ def eval_jaxpr(jaxpr, consts, *args): # Loop through equations and evaluate primitives using `bind` for eqn in jaxpr.eqns: # Read inputs to equation from environment - invals = safe_map(read, eqn.invars) + invals = safe_map(read, eqn.invars) # `bind` is how a primitive is called outvals = eqn.primitive.bind(*invals, **eqn.params) # Primitives may return multiple outputs or not - if not eqn.primitive.multiple_results: + if not eqn.primitive.multiple_results: outvals = [outvals] # Write the results of the primitive into the environment - safe_map(write, eqn.outvars, outvals) + safe_map(write, eqn.outvars, outvals) # Read the final result of the Jaxpr from the environment - return safe_map(read, jaxpr.outvars) + return safe_map(read, jaxpr.outvars) ``` ```{code-cell} ipython3 @@ -225,7 +223,7 @@ eval_jaxpr(closed_jaxpr.jaxpr, closed_jaxpr.literals, jnp.ones(5)) Notice that `eval_jaxpr` will always return a flat list even if the original function does not. -Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/google/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. +Furthermore, this interpreter does not handle higher-order primitives (like `jit` and `pmap`), which we will not cover in this guide. You can refer to `core.eval_jaxpr` ([link](https://github.com/jax-ml/jax/blob/main/jax/core.py)) to see the edge cases that this interpreter does not cover. +++ {"id": "0vb2ZoGrCMM4"} @@ -233,7 +231,7 @@ Furthermore, this interpreter does not handle higher-order primitives (like `jit An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry. -It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234). +It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/jax-ml/jax/blob/main/jax/interpreters/ad.py#L164-L234). 
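The paragraph above says the inverse interpreter starts from a registry mapping primitives to their inverses. A small sketch of what such a registry can look like; the particular entries (`exp`/`log`, `tanh`/`arctanh`) are illustrative and not taken from the hunks shown here:

```python
# Illustrative primitive-to-inverse registry, consumed by an interpreter such
# as inverse_jaxpr while walking the jaxpr equations backwards.
import jax.numpy as jnp
from jax import lax

inverse_registry = {
    lax.exp_p: jnp.log,       # inverse of exp
    lax.tanh_p: jnp.arctanh,  # inverse of tanh
}
```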
```{code-cell} ipython3 :id: gSMIT2z1vUpO @@ -279,7 +277,7 @@ Now we just need to define `inverse_jaxpr`, which will walk through the Jaxpr ba def inverse_jaxpr(jaxpr, consts, *args): env = {} - + def read(var): if type(var) is core.Literal: return var.val @@ -293,12 +291,12 @@ def inverse_jaxpr(jaxpr, consts, *args): # Looping backward for eqn in jaxpr.eqns[::-1]: - # outvars are now invars + # outvars are now invars invals = safe_map(read, eqn.outvars) if eqn.primitive not in inverse_registry: raise NotImplementedError( f"{eqn.primitive} does not have registered inverse.") - # Assuming a unary function + # Assuming a unary function outval = inverse_registry[eqn.primitive](*invals) safe_map(write, eqn.invars, [outval]) return safe_map(read, jaxpr.invars) diff --git a/docs/notebooks/autodiff_cookbook.ipynb b/docs/notebooks/autodiff_cookbook.ipynb index edfd0d4535f8..5538b70dac93 100644 --- a/docs/notebooks/autodiff_cookbook.ipynb +++ b/docs/notebooks/autodiff_cookbook.ipynb @@ -10,9 +10,7 @@ "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", - "\n", - "*alexbw@, mattjj@* \n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb)\n", "\n", "JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics." ] @@ -257,7 +255,7 @@ "id": "cJ2NxiN58bfI" }, "source": [ - "You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." + "You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.)." ] }, { @@ -487,7 +485,7 @@ "id": "iZDL-n_AvgBt" }, "source": [ - "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`." + "These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for \"tall\" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode, which is more efficient for \"wide\" Jacobian matrices (more inputs than outputs). For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`." 
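A small self-contained check of the rule of thumb stated in the revised sentence above (not taken from the notebook): both transforms compute the same Jacobian, and the tall/wide distinction only determines which one is cheaper.

```python
# f_tall maps R^3 -> R^100 (tall Jacobian: jacfwd is the cheaper choice);
# f_wide maps R^100 -> R^3 (wide Jacobian: jacrev is the cheaper choice).
import jax
import jax.numpy as jnp

W_tall = jnp.ones((100, 3))
W_wide = jnp.ones((3, 100))
f_tall = lambda x: jnp.tanh(W_tall @ x)
f_wide = lambda x: jnp.tanh(W_wide @ x)

x3, x100 = jnp.ones(3), jnp.ones(100)

# Same values either way, up to floating-point error.
assert jnp.allclose(jax.jacfwd(f_tall)(x3), jax.jacrev(f_tall)(x3), atol=1e-5)
assert jnp.allclose(jax.jacfwd(f_wide)(x100), jax.jacrev(f_wide)(x100), atol=1e-5)
```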
] }, { @@ -1017,7 +1015,7 @@ "source": [ "### Jacobian-Matrix and Matrix-Jacobian products\n", "\n", - "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." + "Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products." ] }, { @@ -1148,7 +1146,7 @@ " y, vjp_fun = vjp(f, x)\n", " # Use vmap to do a matrix-Jacobian product.\n", " # Here, the matrix is the Euclidean basis, so we get all\n", - " # entries in the Jacobian at once. \n", + " # entries in the Jacobian at once.\n", " J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y)))\n", " return J\n", " return jacfun\n", @@ -1169,7 +1167,7 @@ "def our_jacfwd(f):\n", " def jacfun(x):\n", " _jvp = lambda s: jvp(f, (x,), (s,))[1]\n", - " Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n", + " Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x)))\n", " return jnp.transpose(Jt)\n", " return jacfun\n", "\n", diff --git a/docs/notebooks/autodiff_cookbook.md b/docs/notebooks/autodiff_cookbook.md index c24d05c0e7c9..db6fde8051d1 100644 --- a/docs/notebooks/autodiff_cookbook.md +++ b/docs/notebooks/autodiff_cookbook.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -18,9 +18,7 @@ kernelspec: -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) - -*alexbw@, mattjj@* +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/autodiff_cookbook.ipynb) JAX has a pretty general automatic differentiation system. In this notebook, we'll go through a whole bunch of neat autodiff ideas that you can cherry pick for your own work, starting with the basics. @@ -153,7 +151,7 @@ print(grad(loss2)({'W': W, 'b': b})) +++ {"id": "cJ2NxiN58bfI"} -You can [register your own container types](https://github.com/google/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). +You can [register your own container types](https://github.com/jax-ml/jax/issues/446#issuecomment-467105048) to work with not just `grad` but all the JAX transformations (`jit`, `vmap`, etc.). 
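A minimal, self-contained example of the container registration the line above refers to, using `jax.tree_util.register_pytree_node`; the `Point` class and helper names are illustrative and not taken from the notebook or the linked issue:

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class Point:
  def __init__(self, x, y):
    self.x = x
    self.y = y

def point_flatten(p):
  # children (differentiable leaves) and static auxiliary data
  return (p.x, p.y), None

def point_unflatten(aux, children):
  return Point(*children)

tree_util.register_pytree_node(Point, point_flatten, point_unflatten)

def loss(p):
  return p.x ** 2 + p.y ** 2

# grad (and jit, vmap, ...) now understand Point; gradients come back
# wrapped in the same container type.
g = jax.grad(loss)(Point(jnp.array(1.0), jnp.array(2.0)))
print(g.x, g.y)  # 2.0 4.0
```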
+++ {"id": "PaCHzAtGruBz"} @@ -276,7 +274,7 @@ print(J) +++ {"id": "iZDL-n_AvgBt"} -These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices, while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices. For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`. +These two functions compute the same values (up to machine numerics), but differ in their implementation: `jacfwd` uses forward-mode automatic differentiation, which is more efficient for "tall" Jacobian matrices (more outputs than inputs), while `jacrev` uses reverse-mode, which is more efficient for "wide" Jacobian matrices (more inputs than outputs). For matrices that are near-square, `jacfwd` probably has an edge over `jacrev`. +++ {"id": "zeKlr7Xz8bfm"} @@ -594,7 +592,7 @@ print("Naive full Hessian materialization") ### Jacobian-Matrix and Matrix-Jacobian products -Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/google/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. +Now that we have `jvp` and `vjp` transformations that give us functions to push-forward or pull-back single vectors at a time, we can use JAX's `vmap` [transformation](https://github.com/jax-ml/jax#auto-vectorization-with-vmap) to push and pull entire bases at once. In particular, we can use that to write fast matrix-Jacobian and Jacobian-matrix products. ```{code-cell} ipython3 :id: asAWvxVaCmsx @@ -675,7 +673,7 @@ def our_jacrev(f): y, vjp_fun = vjp(f, x) # Use vmap to do a matrix-Jacobian product. # Here, the matrix is the Euclidean basis, so we get all - # entries in the Jacobian at once. + # entries in the Jacobian at once. 
J, = vmap(vjp_fun, in_axes=0)(jnp.eye(len(y))) return J return jacfun @@ -691,7 +689,7 @@ from jax import jacfwd as builtin_jacfwd def our_jacfwd(f): def jacfun(x): _jvp = lambda s: jvp(f, (x,), (s,))[1] - Jt =vmap(_jvp, in_axes=1)(jnp.eye(len(x))) + Jt = vmap(_jvp, in_axes=1)(jnp.eye(len(x))) return jnp.transpose(Jt) return jacfun diff --git a/docs/notebooks/autodiff_remat.ipynb b/docs/notebooks/autodiff_remat.ipynb index f0552e52688f..82381838a5aa 100644 --- a/docs/notebooks/autodiff_remat.ipynb +++ b/docs/notebooks/autodiff_remat.ipynb @@ -27,7 +27,7 @@ "id": "qaIsQSh1XoKF" }, "source": [ - "### TL;DR\n", + "### Summary\n", "\n", "Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs.\n", "\n", @@ -739,8 +739,6 @@ "metadata": {}, "outputs": [], "source": [ - "from jax.ad_checkpoint import checkpoint_name\n", - "\n", "def predict(params, x):\n", " *Ws, Wlast = params\n", " for i, W in enumerate(Ws):\n", diff --git a/docs/notebooks/autodiff_remat.md b/docs/notebooks/autodiff_remat.md index b31e093b6f91..0a6c84b2d88f 100644 --- a/docs/notebooks/autodiff_remat.md +++ b/docs/notebooks/autodiff_remat.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -24,7 +24,7 @@ import jax.numpy as jnp +++ {"id": "qaIsQSh1XoKF"} -### TL;DR +### Summary Use the `jax.checkpoint` decorator (aliased as `jax.remat`) with `jax.grad` to control which intermediates are saved on the forward pass versus recomputed on the backward pass, trading off memory and FLOPs. @@ -370,8 +370,6 @@ Notice also that by providing a policy, we didn't need to edit the code defining Some policies can refer to values named with `jax.ad_checkpoint.checkpoint_name`: ```{code-cell} -from jax.ad_checkpoint import checkpoint_name - def predict(params, x): *Ws, Wlast = params for i, W in enumerate(Ws): diff --git a/docs/notebooks/convolutions.ipynb b/docs/notebooks/convolutions.ipynb index 0a823353068b..9d91804b6021 100644 --- a/docs/notebooks/convolutions.ipynb +++ b/docs/notebooks/convolutions.ipynb @@ -6,11 +6,11 @@ "id": "TVT_MVvc02AA" }, "source": [ - "# Generalized Convolutions in JAX\n", + "# Generalized convolutions in JAX\n", "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb)\n", "\n", "JAX provides a number of interfaces to compute convolutions across data, including:\n", "\n", @@ -28,7 +28,7 @@ "id": "ewZEn2X12-Ng" }, "source": [ - "## Basic One-dimensional Convolution\n", + "## Basic one-dimensional convolution\n", "\n", "Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. 
Here is a simple example of 1D smoothing implemented via a convolution:" ] @@ -91,7 +91,7 @@ "id": "5ndvLDIH4rv6" }, "source": [ - "## Basic N-dimensional Convolution\n", + "## Basic N-dimensional convolution\n", "\n", "For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions.\n", "\n", @@ -160,7 +160,7 @@ "id": "bxuUjFVG-v1h" }, "source": [ - "## General Convolutions" + "## General convolutions" ] }, { @@ -410,7 +410,7 @@ ], "source": [ "dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape\n", - " kernel.shape, # only ndim matters, not shape \n", + " kernel.shape, # only ndim matters, not shape\n", " ('NHWC', 'HWIO', 'NHWC')) # the important bit\n", "print(dn)" ] @@ -806,8 +806,8 @@ ], "source": [ "# 1D kernel - WIO layout\n", - "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], \n", - " [[1, 1, 1], [-1, -1, -1]]], \n", + "kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]],\n", + " [[1, 1, 1], [-1, -1, -1]]],\n", " dtype=jnp.float32).transpose([2,1,0])\n", "# 1D data - NWC layout\n", "data = np.zeros((1, 200, 2), dtype=jnp.float32)\n", @@ -895,8 +895,8 @@ "# Random 3D kernel - HWDIO layout\n", "kernel = jnp.array([\n", " [[0, 0, 0], [0, 1, 0], [0, 0, 0]],\n", - " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], \n", - " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], \n", + " [[0, -1, 0], [-1, 0, -1], [0, -1, 0]],\n", + " [[0, 0, 0], [0, 1, 0], [0, 0, 0]]],\n", " dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis]\n", "\n", "# 3D data - NHWDC layout\n", @@ -919,7 +919,6 @@ "print(\"out shape: \", out.shape)\n", "\n", "# Make some simple 3d density plots:\n", - "from mpl_toolkits.mplot3d import Axes3D\n", "def make_alpha(cmap):\n", " my_cmap = cmap(jnp.arange(cmap.N))\n", " my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3\n", diff --git a/docs/notebooks/convolutions.md b/docs/notebooks/convolutions.md index 3de8f261aa5b..b98099aa9571 100644 --- a/docs/notebooks/convolutions.md +++ b/docs/notebooks/convolutions.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -14,11 +14,11 @@ kernelspec: +++ {"id": "TVT_MVvc02AA"} -# Generalized Convolutions in JAX +# Generalized convolutions in JAX -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/convolutions.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/convolutions.ipynb) JAX provides a number of interfaces to compute convolutions across data, including: @@ -31,7 +31,7 @@ For basic convolution operations, the `jax.numpy` and `jax.scipy` operations are +++ {"id": "ewZEn2X12-Ng"} -## Basic One-dimensional Convolution +## Basic one-dimensional convolution Basic one-dimensional convolution is implemented by {func}`jax.numpy.convolve`, which provides a JAX interface for {func}`numpy.convolve`. 
Here is a simple example of 1D smoothing implemented via a convolution: @@ -65,7 +65,7 @@ For more information, see the {func}`jax.numpy.convolve` documentation, or the d +++ {"id": "5ndvLDIH4rv6"} -## Basic N-dimensional Convolution +## Basic N-dimensional convolution For *N*-dimensional convolution, {func}`jax.scipy.signal.convolve` provides a similar interface to that of {func}`jax.numpy.convolve`, generalized to *N* dimensions. @@ -105,7 +105,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l +++ {"id": "bxuUjFVG-v1h"} -## General Convolutions +## General convolutions +++ {"id": "0pcn2LeS-03b"} @@ -210,7 +210,7 @@ The important argument is the 3-tuple of axis layout arguments: :outputId: d5a569b3-febc-4832-f725-1d5e8fd31b9b dn = lax.conv_dimension_numbers(img.shape, # only ndim matters, not shape - kernel.shape, # only ndim matters, not shape + kernel.shape, # only ndim matters, not shape ('NHWC', 'HWIO', 'NHWC')) # the important bit print(dn) ``` @@ -363,8 +363,8 @@ You aren't limited to 2D convolutions, a simple 1D demo is below: :outputId: 67c46ace-6adc-4c47-c1c7-1f185be5fd4b # 1D kernel - WIO layout -kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], - [[1, 1, 1], [-1, -1, -1]]], +kernel = jnp.array([[[1, 0, -1], [-1, 0, 1]], + [[1, 1, 1], [-1, -1, -1]]], dtype=jnp.float32).transpose([2,1,0]) # 1D data - NWC layout data = np.zeros((1, 200, 2), dtype=jnp.float32) @@ -406,8 +406,8 @@ import matplotlib as mpl # Random 3D kernel - HWDIO layout kernel = jnp.array([ [[0, 0, 0], [0, 1, 0], [0, 0, 0]], - [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], - [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], + [[0, -1, 0], [-1, 0, -1], [0, -1, 0]], + [[0, 0, 0], [0, 1, 0], [0, 0, 0]]], dtype=jnp.float32)[:, :, :, jnp.newaxis, jnp.newaxis] # 3D data - NHWDC layout @@ -430,7 +430,6 @@ out = lax.conv_general_dilated(data, # lhs = image tensor print("out shape: ", out.shape) # Make some simple 3d density plots: -from mpl_toolkits.mplot3d import Axes3D def make_alpha(cmap): my_cmap = cmap(jnp.arange(cmap.N)) my_cmap[:,-1] = jnp.linspace(0, 1, cmap.N)**3 diff --git a/docs/notebooks/external_callbacks.ipynb b/docs/notebooks/external_callbacks.ipynb deleted file mode 100644 index bdf71004c01b..000000000000 --- a/docs/notebooks/external_callbacks.ipynb +++ /dev/null @@ -1,1122 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "7XNMxdTwURqI" - }, - "source": [ - "# External Callbacks in JAX\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h6lXo6bSUYGq" - }, - "source": [ - "This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xi_nhfpnlmbm" - }, - "source": [ - "## Why callbacks?\n", - "\n", - "A callback routine is a way to perform **host-side** execution of code at runtime.\n", - "As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation.\n", - "Using a simple Python `print` statement, it looks like this:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "id": "lz8rEL1Amb4r", - "outputId": "bbd37102-19f2-46d2-b794-3d4952c6fe97" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "intermediate value: Tracedwith\n" - ] - } - ], - "source": [ - "import jax\n", - "\n", - "@jax.jit\n", - "def f(x):\n", - " y = x + 1\n", - " print(\"intermediate value: {}\".format(y))\n", - " return y * 2\n", - "\n", - "result = f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "yEy41sFAmxOp" - }, - "source": [ - "What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)).\n", - "\n", - "To print the value at runtime we need a callback, for example `jax.debug.print`:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "id": "wFfHmoQxnKDF", - "outputId": "6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "intermediate value: 3\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def f(x):\n", - " y = x + 1\n", - " jax.debug.print(\"intermediate value: {}\", y)\n", - " return y * 2\n", - "\n", - "result = f(2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "CvWv3pudn9X5" - }, - "source": [ - "This works by passing the runtime value represented by `y` back to the host process, where the host can print the value." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "X0vR078znuT-" - }, - "source": [ - "## Flavors of Callback\n", - "\n", - "In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations:\n", - "\n", - "- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect.\n", - "- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. 
functions which read or write data to disk.\n", - "- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler.\n", - "\n", - "(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`).\n", - "\n", - "From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow.\n", - "\n", - "|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution |\n", - "|-------------------------------------|----|----|----|----|----|----|\n", - "|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ |\n", - "|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ |\n", - "|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ |\n", - "\n", - "¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff\n", - "\n", - "² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`.\n", - "\n", - "³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hE_M8DaPvoym" - }, - "source": [ - "### Exploring `jax.pure_callback`\n", - "\n", - "`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.).\n", - "\n", - "The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "4lQDzXy6t_-k", - "outputId": "279e4daf-0540-4eab-f535-d3bcbac74c44" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import numpy as np\n", - "\n", - "def f_host(x):\n", - " # call a numpy (not jax.numpy) operation:\n", - " return np.sin(x).astype(x.dtype)\n", - "\n", - "def f(x):\n", - " result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " return jax.pure_callback(f_host, result_shape, x)\n", - "\n", - "x = jnp.arange(5.0)\n", - "f(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q7YCIr8qMrDs" - }, - "source": [ - "Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:\"" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": { - "id": "bgoZ0fxsuoWV", - "outputId": "901443bd-5cb4-4923-ce53-6f832ac22ca9" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.jit(f)(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": { - "id": "ajBRGWGfupu2", - "outputId": "b28e31ee-7457-4b92-872b-52d819f53ddf" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. 
, 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(f)(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "xe7AOGexvC13", - "outputId": "8fa77977-1f2b-41c5-cc5e-11993ee5aa3e" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array([ 0. , 0.841471 , 0.9092974, 0.14112 , -0.7568025], dtype=float32)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def body_fun(_, x):\n", - " return _, f(x)\n", - "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tMzAVs2VNj5G" - }, - "source": [ - "However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": { - "id": "4QAF4VhUu5bb", - "outputId": "f8a06d02-47e9-4240-8077-d7be81e5a480" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Exception reporting mode: Minimal\n" - ] - } - ], - "source": [ - "%xmode minimal" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "qUpKPxlOurfY", - "outputId": "11a665e8-40eb-4b0e-dc2e-a544a25fc57e", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients.\n" - ] - } - ], - "source": [ - "jax.grad(f)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "y9DAibV4Nwpo" - }, - "source": [ - "For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LrvdAloMZbIe" - }, - "source": [ - "By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "id": "mmFc_zawZrBq", - "outputId": "a4df7568-3f64-4b2f-9a2c-7adb2e0815e0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "printing something\n" - ] - } - ], - "source": [ - "def print_something():\n", - " print('printing something')\n", - " return np.int32(0)\n", - "\n", - "@jax.jit\n", - "def f1():\n", - " return jax.pure_callback(print_something, np.int32(0))\n", - "f1();" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "id": "tTwE4kpmaNei" - }, - "outputs": [], - "source": [ - "@jax.jit\n", - "def f2():\n", - " jax.pure_callback(print_something, np.int32(0))\n", - " return 1.0\n", - "f2();" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qfyGYbw4Z5U3" - }, - "source": [ - "In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output.\n", - "In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects." 
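[Editor's sketch, not part of the original notebook: if you do need a `pure_callback` to run for its side effect, one option is to make the rest of the computation depend on its output; another is to use `io_callback` (described in the next section), whose execution is guaranteed. The name `f3` below is made up for illustration.]

```python
# Sketch only (f3 is an illustrative name, not from the notebook): consuming the
# callback's output keeps the compiler from eliminating it, unlike f2 above.
@jax.jit
def f3():
    token = jax.pure_callback(print_something, np.int32(0))
    return token + 1  # the result feeds the return value, so the callback executes

f3();
```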
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JHcJybr7OEBM" - }, - "source": [ - "### Exploring `jax.experimental.io_callback`\n", - "\n", - "In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects.\n", - "\n", - "As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!)." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "id": "eAg5xIhrOiWV", - "outputId": "e3cfec21-d843-4852-a49d-69a69fba9fc1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[5]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.6369617 , 0.26978672, 0.04097353, 0.01652764, 0.8132702 ], dtype=float32)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from jax.experimental import io_callback\n", - "from functools import partial\n", - "\n", - "global_rng = np.random.default_rng(0)\n", - "\n", - "def host_side_random_like(x):\n", - " \"\"\"Generate a random array like x using the global_rng state\"\"\"\n", - " # We have two side-effects here:\n", - " # - printing the shape and dtype\n", - " # - calling global_rng, thus updating its state\n", - " print(f'generating {x.dtype}{list(x.shape)}')\n", - " return global_rng.uniform(size=x.shape).astype(x.dtype)\n", - "\n", - "@jax.jit\n", - "def numpy_random_like(x):\n", - " return io_callback(host_side_random_like, x, x)\n", - "\n", - "x = jnp.zeros(5)\n", - "numpy_random_like(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mAIF31MlXj33" - }, - "source": [ - "The `io_callback` is compatible with `vmap` by default:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": { - "id": "NY3o5dG6Vg6u", - "outputId": "a67a8a98-214e-40ca-ad98-a930cd3db85e" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.91275555, 0.60663575, 0.72949654, 0.543625 , 0.9350724 ], dtype=float32)" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.vmap(numpy_random_like)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XXvSeeOXXquZ" - }, - "source": [ - "Note, however, that this may execute the mapped callbacks in any order. 
So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run.\n", - "\n", - "If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "id": "3aNmRsDrX3-2", - "outputId": "a8ff4b77-f4cb-442f-8cfb-ea7251c66274", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: Cannot `vmap` ordered IO callback.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m Cannot `vmap` ordered IO callback.\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def numpy_random_like_ordered(x):\n", - " return io_callback(host_side_random_like, x, x, ordered=True)\n", - "\n", - "jax.vmap(numpy_random_like_ordered)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fD2FTHlUYAZH" - }, - "source": [ - "On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced:" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "id": "lMVzZlIEWL7F", - "outputId": "f9741c18-a30d-4d46-b706-8102849286b5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n", - "generating float32[]\n" - ] - }, - { - "data": { - "text/plain": [ - "Array([0.81585354, 0.0027385 , 0.8574043 , 0.03358557, 0.72965544], dtype=float32)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "def body_fun(_, x):\n", - " return _, numpy_random_like_ordered(x)\n", - "jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w_sf8mCbbo8K" - }, - "source": [ - "Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": { - "id": "Cn6_RG4JcKZm", - "outputId": "336ae5d2-e35b-4fe5-cbfb-14a7aef28c07", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31mJaxStackTraceBeforeTransformation\u001b[0m\u001b[0;31m:\u001b[0m ValueError: IO callbacks do not support JVP.\n\nThe preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.\n\n--------------------\n", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m\u001b[0;31m:\u001b[0m IO callbacks do not support JVP.\n" - ] - } - ], - "source": [ - "jax.grad(numpy_random_like)(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "plvfn9lWcKu4" - }, - "source": [ - "However, if the callback is not dependent on a differentiated variable, it will execute:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": { - "id": "wxgfDmDfb5bx", - "outputId": 
"d8c0285c-cd04-4b4d-d15a-1b07f778882d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "hello\n" - ] - } - ], - "source": [ - "@jax.jit\n", - "def f(x):\n", - " io_callback(lambda: print('hello'), None)\n", - " return x\n", - "\n", - "jax.grad(f)(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "STLI40EZcVIY" - }, - "source": [ - "Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pkkM1ZmqclV-" - }, - "source": [ - "### Exploring `debug.callback`\n", - "\n", - "Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program." - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": { - "id": "74TdWyu9eqBa", - "outputId": "d8551dab-2e61-492e-9ac3-dc3db51b2c18" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 1.0\n" - ] - } - ], - "source": [ - "from jax import debug\n", - "\n", - "def log_value(x):\n", - " # This could be an actual logging call; we'll use\n", - " # print() for demonstration\n", - " print(\"log:\", x)\n", - "\n", - "@jax.jit\n", - "def f(x):\n", - " debug.callback(log_value, x)\n", - " return x\n", - "\n", - "f(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P848STlsfzmW" - }, - "source": [ - "The debug callback is compatible with `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": { - "id": "2sSNsPB-fGVI", - "outputId": "fff58575-d94c-48fb-b88a-c1c395595fd0" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 0.0\n", - "log: 1.0\n", - "log: 2.0\n", - "log: 3.0\n", - "log: 4.0\n" - ] - } - ], - "source": [ - "x = jnp.arange(5.0)\n", - "jax.vmap(f)(x);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VDMacqpXf3La" - }, - "source": [ - "And is also compatible with `grad` and other autodiff transformations" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "id": "wkFRle-tfTDe", - "outputId": "4e8a81d0-5012-4c51-d843-3fbdc498df31" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "log: 1.0\n" - ] - } - ], - "source": [ - "jax.grad(f)(1.0);" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "w8t-SDZ3gRzE" - }, - "source": [ - "This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dF7hoWGQUneJ" - }, - "source": [ - "## Example: `pure_callback` with `custom_jvp`\n", - "\n", - "One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`).\n", - "Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers.\n", - "\n", - "Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`.\n", - "We can start by defining a straightforward `pure_callback`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Ge4fNPZdVSJY" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "import scipy.special\n", - "\n", - "def jv(v, z):\n", - " v, z = jnp.asarray(v), jnp.asarray(z)\n", - "\n", - " # Require the order v to be integer type: this simplifies\n", - " # the JVP rule below.\n", - " assert jnp.issubdtype(v.dtype, jnp.integer)\n", - "\n", - " # Promote the input to inexact (float/complex).\n", - " # Note that jnp.result_type() accounts for the enable_x64 flag.\n", - " z = z.astype(jnp.result_type(float, z.dtype))\n", - "\n", - " # Wrap scipy function to return the expected dtype.\n", - " _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype)\n", - "\n", - " # Define the expected shape & dtype of output.\n", - " result_shape_dtype = jax.ShapeDtypeStruct(\n", - " shape=jnp.broadcast_shapes(v.shape, z.shape),\n", - " dtype=z.dtype)\n", - "\n", - " # We use vectorize=True because scipy.special.jv handles broadcasted inputs.\n", - " return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "vyjQj-0QVuoN" - }, - "source": [ - "This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f4e46670f4e4" - }, - "outputs": [], - "source": [ - "from functools import partial\n", - "j1 = partial(jv, 1)\n", - "z = jnp.arange(5.0)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6svImqFHWBwj", - "outputId": "bc8c778a-6c10-443b-9be2-c0f28e2ac1a9" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(j1(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d48eb4f2d48e" - }, - "source": [ - "Here is the same result with `jit`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "txvRqR9DWGdC", - "outputId": "d25f3476-23b1-48e4-dda1-3c06d32c3b87" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 
0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(jax.jit(j1)(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "d861a472d861" - }, - "source": [ - "And here is the same result again with `vmap`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BS-Ve5u_WU0C", - "outputId": "08cecd1f-6953-4853-e9db-25a03eb5b000" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 0. 0.44005057 0.5767248 0.33905897 -0.06604332]\n" - ] - } - ], - "source": [ - "print(jax.vmap(j1)(z))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SCH2ii_dWXP6" - }, - "source": [ - "However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "q3qh_4DrWxdQ", - "outputId": "c46b0bfa-96f3-4629-b9af-a4d4f3ccb870", - "tags": [ - "raises-exception" - ] - }, - "outputs": [ - { - "ename": "ValueError", - "evalue": "ignored", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mgrad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1090\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgrad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1091\u001b[0;31m \u001b[0m_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue_and_grad_f\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1092\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/traceback_util.py\u001b[0m in 
\u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvalue_and_grad_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 1166\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1167\u001b[0;31m \u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvjp_py\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_vjp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf_partial\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mdyn_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreduce_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreduce_axes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1168\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_vjp\u001b[0;34m(fun, has_aux, reduce_axes, *primals)\u001b[0m\n\u001b[1;32m 2655\u001b[0m \u001b[0mflat_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mflatten_fun_nokwargs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2656\u001b[0;31m out_primal, out_vjp = ad.vjp(\n\u001b[0m\u001b[1;32m 2657\u001b[0m flat_fun, primals_flat, reduce_axes=reduce_axes)\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mvjp\u001b[0;34m(traceable, primals, has_aux, reduce_axes)\u001b[0m\n\u001b[1;32m 134\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mhas_aux\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 135\u001b[0;31m \u001b[0mout_primals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinearize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraceable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mprimals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 136\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mlinearize\u001b[0;34m(traceable, *primals, **kwargs)\u001b[0m\n\u001b[1;32m 123\u001b[0m \u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mflatten_fun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_tree\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 124\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpe\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrace_to_jaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjvpfun_flat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 125\u001b[0m \u001b[0mout_primals_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tangents_pvals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_unflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_pvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/profiler.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 313\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mTraceAnnotation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdecorator_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 314\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 315\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_nounits\u001b[0;34m(fun, pvals, instantiate)\u001b[0m\n\u001b[1;32m 766\u001b[0m \u001b[0mfun\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_nounits\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minstantiate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 767\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mout_pvals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpvals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 768\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0menv\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 167\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 168\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, *args, **kwargs)\u001b[0m\n\u001b[1;32m 3425\u001b[0m \"\"\"\n\u001b[0;32m-> 3426\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjcb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcallback\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtypes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3427\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m 328\u001b[0m all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args\n\u001b[0;32m--> 329\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfind_top_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 330\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m 331\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind_with_trace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 332\u001b[0;31m \u001b[0mout\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_primitive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 333\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mfull_lower\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/interpreters/ad.py\u001b[0m in \u001b[0;36mprocess_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mraise\u001b[0m \u001b[0mNotImplementedError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m \u001b[0mprimal_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangent_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjvp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mprimals_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtangents_in\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 311\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmultiple_results\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: Pure callbacks do not support JVP. 
Please use `jax.custom_jvp` to use callbacks while taking gradients.\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------", - "\nThe above exception was the direct cause of the following exception:\n", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mj1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m\u001b[0m in \u001b[0;36mjv\u001b[0;34m(v, z)\u001b[0m\n\u001b[1;32m 23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;31m# We use vectorize=True because scipy.special.jv handles broadcasted inputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 25\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpure_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_scipy_jv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult_shape_dtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mz\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvectorized\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback\u001b[0;34m(callback, result_shape_dtypes, vectorized, *args, **kwargs)\u001b[0m\n\u001b[1;32m 129\u001b[0m lambda x: core.ShapedArray(x.shape, x.dtype), result_shape_dtypes)\n\u001b[1;32m 130\u001b[0m \u001b[0mflat_result_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_tree\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 131\u001b[0;31m out_flat = pure_callback_p.bind(\n\u001b[0m\u001b[1;32m 132\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mflat_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcallback\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0m_flat_callback\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 133\u001b[0m result_avals=tuple(flat_result_avals), vectorized=vectorized)\n", - "\u001b[0;32m/usr/local/lib/python3.8/dist-packages/jax/_src/callback.py\u001b[0m in \u001b[0;36mpure_callback_jvp_rule\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 53\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mpure_callback_jvp_rule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 55\u001b[0;31m raise ValueError(\n\u001b[0m\u001b[1;32m 56\u001b[0m \u001b[0;34m\"Pure callbacks do not support JVP. 
\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \"Please use `jax.custom_jvp` to use callbacks while taking gradients.\")\n", - "\u001b[0;31mValueError\u001b[0m: Pure callbacks do not support JVP. Please use `jax.custom_jvp` to use callbacks while taking gradients." - ] - } - ], - "source": [ - "jax.grad(j1)(z)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "PtYeJ_xUW09v" - }, - "source": [ - "Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`:\n", - "\n", - "$$\n", - "d J_\\nu(z) = \\left\\{\n", - "\\begin{eqnarray}\n", - "-J_1(z),\\ &\\nu=0\\\\\n", - "[J_{\\nu - 1}(z) - J_{\\nu + 1}(z)]/2,\\ &\\nu\\ne 0\n", - "\\end{eqnarray}\\right.\n", - "$$\n", - "\n", - "The gradient with respect to $\\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example.\n", - "\n", - "We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "BOVQnt05XvLs" - }, - "outputs": [], - "source": [ - "jv = jax.custom_jvp(jv)\n", - "\n", - "@jv.defjvp\n", - "def _jv_jvp(primals, tangents):\n", - " v, z = primals\n", - " _, z_dot = tangents # Note: v_dot is always 0 because v is integer.\n", - " jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z)\n", - " djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1))\n", - " return jv(v, z), z_dot * djv_dz" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W1SxcvQSX44c" - }, - "source": [ - "Now computing the gradient of our function will work correctly:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sCGceBs-X8nL", - "outputId": "71c5589f-f996-44a0-f09a-ca8bb40c167a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-0.06447162\n" - ] - } - ], - "source": [ - "j1 = partial(jv, 1)\n", - "print(jax.grad(j1)(2.0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gWQ4phN5YB26" - }, - "source": [ - "Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "QTe5mRAvYQBh", - "outputId": "d58ecff3-9419-422a-fd0e-14a7d9cf2cc3" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Array(-0.4003078, dtype=float32, weak_type=True)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.hessian(j1)(2.0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QEXGxU4uYZii" - }, - "source": [ - "Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device.\n", - "When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called.\n", - "However, if you are running JAX on a single CPU (where the \"host\" and 
\"device\" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern is a relatively straightforward way extend JAX's capabilities." - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "name": "python", - "version": "3.10.0" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/notebooks/external_callbacks.md b/docs/notebooks/external_callbacks.md deleted file mode 100644 index 857eef42e2b3..000000000000 --- a/docs/notebooks/external_callbacks.md +++ /dev/null @@ -1,516 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.16.1 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "7XNMxdTwURqI"} - -# External Callbacks in JAX - - - -+++ {"id": "h6lXo6bSUYGq"} - -This guide outlines the uses of various callback functions, which allow JAX runtimes to execute Python code on the host, even while running under `jit`, `vmap`, `grad`, or another transformation. - -+++ {"id": "Xi_nhfpnlmbm"} - -## Why callbacks? - -A callback routine is a way to perform **host-side** execution of code at runtime. -As a simple example, suppose you'd like to print the *value* of some variable during the course of a computation. -Using a simple Python `print` statement, it looks like this: - -```{code-cell} -:id: lz8rEL1Amb4r -:outputId: bbd37102-19f2-46d2-b794-3d4952c6fe97 - -import jax - -@jax.jit -def f(x): - y = x + 1 - print("intermediate value: {}".format(y)) - return y * 2 - -result = f(2) -``` - -+++ {"id": "yEy41sFAmxOp"} - -What is printed is not the runtime value, but the trace-time abstract value (if you're not famililar with *tracing* in JAX, a good primer can be found in [How To Think In JAX](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html)). - -To print the value at runtime we need a callback, for example `jax.debug.print`: - -```{code-cell} -:id: wFfHmoQxnKDF -:outputId: 6bea21d9-9bb1-4d4d-f3ec-fcf1c691a46a - -@jax.jit -def f(x): - y = x + 1 - jax.debug.print("intermediate value: {}", y) - return y * 2 - -result = f(2) -``` - -+++ {"id": "CvWv3pudn9X5"} - -This works by passing the runtime value represented by `y` back to the host process, where the host can print the value. - -+++ {"id": "X0vR078znuT-"} - -## Flavors of Callback - -In earlier versions of JAX, there was only one kind of callback available, implemented in `jax.experimental.host_callback`. The `host_callback` routines had some deficiencies, and are now deprecated in favor of several callbacks designed for different situations: - -- {func}`jax.pure_callback`: appropriate for pure functions: i.e. functions with no side effect. -- {func}`jax.experimental.io_callback`: appropriate for impure functions: e.g. functions which read or write data to disk. -- {func}`jax.debug.callback`: appropriate for functions that should reflect the execution behavior of the compiler. - -(The {func}`jax.debug.print` function we used above is a wrapper around {func}`jax.debug.callback`). - -From the user perspective, these three flavors of callback are mainly distinguished by what transformations and compiler optimizations they allow. 
- -|callback function | supports return value | `jit` | `vmap` | `grad` | `scan`/`while_loop` | guaranteed execution | -|-------------------------------------|----|----|----|----|----|----| -|`jax.pure_callback` | ✅ | ✅ | ✅ | ❌¹ | ✅ | ❌ | -|`jax.experimental.io_callback` | ✅ | ✅ | ✅/❌² | ❌ | ✅³ | ✅ | -|`jax.debug.callback` | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | - -¹ `jax.pure_callback` can be used with `custom_jvp` to make it compatible with autodiff - -² `jax.experimental.io_callback` is compatible with `vmap` only if `ordered=False`. - -³ Note that `vmap` of `scan`/`while_loop` of `io_callback` has complicated semantics, and its behavior may change in future releases. - -+++ {"id": "hE_M8DaPvoym"} - -### Exploring `jax.pure_callback` - -`jax.pure_callback` is generally the callback function you should reach for when you want host-side execution of a pure function: i.e. a function that has no side-effects (such as printing values, reading data from disk, updating a global state, etc.). - -The function you pass to `jax.pure_callback` need not actually be pure, but it will be assumed pure by JAX's transformations and higher-order functions, which means that it may be silently elided or called multiple times. - -```{code-cell} -:id: 4lQDzXy6t_-k -:outputId: 279e4daf-0540-4eab-f535-d3bcbac74c44 - -import jax -import jax.numpy as jnp -import numpy as np - -def f_host(x): - # call a numpy (not jax.numpy) operation: - return np.sin(x).astype(x.dtype) - -def f(x): - result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype) - return jax.pure_callback(f_host, result_shape, x) - -x = jnp.arange(5.0) -f(x) -``` - -+++ {"id": "q7YCIr8qMrDs"} - -Because `pure_callback` can be elided or duplicated, it is compatible out-of-the-box with transformations like `jit` and `vmap`, as well as higher-order primitives like `scan` and `while_loop`:" - -```{code-cell} -:id: bgoZ0fxsuoWV -:outputId: 901443bd-5cb4-4923-ce53-6f832ac22ca9 - -jax.jit(f)(x) -``` - -```{code-cell} -:id: ajBRGWGfupu2 -:outputId: b28e31ee-7457-4b92-872b-52d819f53ddf - -jax.vmap(f)(x) -``` - -```{code-cell} -:id: xe7AOGexvC13 -:outputId: 8fa77977-1f2b-41c5-cc5e-11993ee5aa3e - -def body_fun(_, x): - return _, f(x) -jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] -``` - -+++ {"id": "tMzAVs2VNj5G"} - -However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics: - -```{code-cell} -:id: 4QAF4VhUu5bb -:outputId: f8a06d02-47e9-4240-8077-d7be81e5a480 - -%xmode minimal -``` - -```{code-cell} -:id: qUpKPxlOurfY -:outputId: 11a665e8-40eb-4b0e-dc2e-a544a25fc57e -:tags: [raises-exception] - -jax.grad(f)(x) -``` - -+++ {"id": "y9DAibV4Nwpo"} - -For an example of using `pure_callback` with `jax.custom_jvp`, see *Example: `pure_callback` with `custom_jvp`* below. 
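[Editor's sketch, not in the original document: `result_shape_dtypes` may be a pytree, so a single host call can return several outputs. The `minmax_host`/`minmax` names below are made up for illustration.]

```python
# Sketch: a host-side NumPy function returning two outputs through pure_callback.
# The tuple of ShapeDtypeStructs mirrors the tuple the host function returns.
import jax
import jax.numpy as jnp
import numpy as np

def minmax_host(x):
    return np.min(x), np.max(x)   # plain NumPy, runs on the host

def minmax(x):
    out_shapes = (jax.ShapeDtypeStruct((), x.dtype),
                  jax.ShapeDtypeStruct((), x.dtype))
    return jax.pure_callback(minmax_host, out_shapes, x)

print(jax.jit(minmax)(jnp.arange(5.0)))  # expect roughly (0.0, 4.0)
```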
- -+++ {"id": "LrvdAloMZbIe"} - -By design functions passed to `pure_callback` are treated as if they have no side-effects: one consequence of this is that if the output of the function is not used, the compiler may eliminate the callback entirely: - -```{code-cell} -:id: mmFc_zawZrBq -:outputId: a4df7568-3f64-4b2f-9a2c-7adb2e0815e0 - -def print_something(): - print('printing something') - return np.int32(0) - -@jax.jit -def f1(): - return jax.pure_callback(print_something, np.int32(0)) -f1(); -``` - -```{code-cell} -:id: tTwE4kpmaNei - -@jax.jit -def f2(): - jax.pure_callback(print_something, np.int32(0)) - return 1.0 -f2(); -``` - -+++ {"id": "qfyGYbw4Z5U3"} - -In `f1`, the output of the callback is used in the return value of the function, so the callback is executed and we see the printed output. -In `f2` on the other hand, the output of the callback is unused, and so the compiler notices this and eliminates the function call. These are the correct semantics for a callback to a function with no side-effects. - -+++ {"id": "JHcJybr7OEBM"} - -### Exploring `jax.experimental.io_callback` - -In contrast to {func}`jax.pure_callback`, {func}`jax.experimental.io_callback` is explicitly meant to be used with impure functions, i.e. functions that do have side-effects. - -As an example, here is a callback to a global host-side numpy random generator. This is an impure operation because a side-effect of generating a random number in numpy is that the random state is updated (Please note that this is meant as a toy example of `io_callback` and not necessarily a recommended way of generating random numbers in JAX!). - -```{code-cell} -:id: eAg5xIhrOiWV -:outputId: e3cfec21-d843-4852-a49d-69a69fba9fc1 - -from jax.experimental import io_callback -from functools import partial - -global_rng = np.random.default_rng(0) - -def host_side_random_like(x): - """Generate a random array like x using the global_rng state""" - # We have two side-effects here: - # - printing the shape and dtype - # - calling global_rng, thus updating its state - print(f'generating {x.dtype}{list(x.shape)}') - return global_rng.uniform(size=x.shape).astype(x.dtype) - -@jax.jit -def numpy_random_like(x): - return io_callback(host_side_random_like, x, x) - -x = jnp.zeros(5) -numpy_random_like(x) -``` - -+++ {"id": "mAIF31MlXj33"} - -The `io_callback` is compatible with `vmap` by default: - -```{code-cell} -:id: NY3o5dG6Vg6u -:outputId: a67a8a98-214e-40ca-ad98-a930cd3db85e - -jax.vmap(numpy_random_like)(x) -``` - -+++ {"id": "XXvSeeOXXquZ"} - -Note, however, that this may execute the mapped callbacks in any order. So, for example, if you ran this on a GPU, the order of the mapped outputs might differ from run to run. 
- -If it is important that the order of callbacks be preserved, you can set `ordered=True`, in which case attempting to `vmap` will raise an error: - -```{code-cell} -:id: 3aNmRsDrX3-2 -:outputId: a8ff4b77-f4cb-442f-8cfb-ea7251c66274 -:tags: [raises-exception] - -@jax.jit -def numpy_random_like_ordered(x): - return io_callback(host_side_random_like, x, x, ordered=True) - -jax.vmap(numpy_random_like_ordered)(x) -``` - -+++ {"id": "fD2FTHlUYAZH"} - -On the other hand, `scan` and `while_loop` work with `io_callback` regardless of whether ordering is enforced: - -```{code-cell} -:id: lMVzZlIEWL7F -:outputId: f9741c18-a30d-4d46-b706-8102849286b5 - -def body_fun(_, x): - return _, numpy_random_like_ordered(x) -jax.lax.scan(body_fun, None, jnp.arange(5.0))[1] -``` - -+++ {"id": "w_sf8mCbbo8K"} - -Like `pure_callback`, `io_callback` fails under automatic differentiation if it is passed a differentiated variable: - -```{code-cell} -:id: Cn6_RG4JcKZm -:outputId: 336ae5d2-e35b-4fe5-cbfb-14a7aef28c07 -:tags: [raises-exception] - -jax.grad(numpy_random_like)(x) -``` - -+++ {"id": "plvfn9lWcKu4"} - -However, if the callback is not dependent on a differentiated variable, it will execute: - -```{code-cell} -:id: wxgfDmDfb5bx -:outputId: d8c0285c-cd04-4b4d-d15a-1b07f778882d - -@jax.jit -def f(x): - io_callback(lambda: print('hello'), None) - return x - -jax.grad(f)(1.0); -``` - -+++ {"id": "STLI40EZcVIY"} - -Unlike `pure_callback`, the compiler will not remove the callback execution in this case, even though the output of the callback is unused in the subsequent computation. - -+++ {"id": "pkkM1ZmqclV-"} - -### Exploring `debug.callback` - -Both `pure_callback` and `io_callback` enforce some assumptions about the purity of the function they're calling, and limit in various ways what JAX transforms and compilation machinery may do. `debug.callback` essentially assumes *nothing* about the callback function, such that the action of the callback reflects exactly what JAX is doing during the course of a program. Further, `debug.callback` *cannot* return any value to the program. - -```{code-cell} -:id: 74TdWyu9eqBa -:outputId: d8551dab-2e61-492e-9ac3-dc3db51b2c18 - -from jax import debug - -def log_value(x): - # This could be an actual logging call; we'll use - # print() for demonstration - print("log:", x) - -@jax.jit -def f(x): - debug.callback(log_value, x) - return x - -f(1.0); -``` - -+++ {"id": "P848STlsfzmW"} - -The debug callback is compatible with `vmap`: - -```{code-cell} -:id: 2sSNsPB-fGVI -:outputId: fff58575-d94c-48fb-b88a-c1c395595fd0 - -x = jnp.arange(5.0) -jax.vmap(f)(x); -``` - -+++ {"id": "VDMacqpXf3La"} - -And is also compatible with `grad` and other autodiff transformations - -```{code-cell} -:id: wkFRle-tfTDe -:outputId: 4e8a81d0-5012-4c51-d843-3fbdc498df31 - -jax.grad(f)(1.0); -``` - -+++ {"id": "w8t-SDZ3gRzE"} - -This can make `debug.callback` more useful for general-purpose debugging than either `pure_callback` or `io_callback`. - -+++ {"id": "dF7hoWGQUneJ"} - -## Example: `pure_callback` with `custom_jvp` - -One powerful way to take advantage of {func}`jax.pure_callback` is to combine it with {class}`jax.custom_jvp` (see [Custom derivative rules](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html) for more details on `custom_jvp`). -Suppose we want to create a JAX-compatible wrapper for a scipy or numpy function that is not yet available in the `jax.scipy` or `jax.numpy` wrappers. 
- -Here, we'll consider creating a wrapper for the Bessel function of the first kind, implemented in `scipy.special.jv`. -We can start by defining a straightforward `pure_callback`: - -```{code-cell} -:id: Ge4fNPZdVSJY - -import jax -import jax.numpy as jnp -import scipy.special - -def jv(v, z): - v, z = jnp.asarray(v), jnp.asarray(z) - - # Require the order v to be integer type: this simplifies - # the JVP rule below. - assert jnp.issubdtype(v.dtype, jnp.integer) - - # Promote the input to inexact (float/complex). - # Note that jnp.result_type() accounts for the enable_x64 flag. - z = z.astype(jnp.result_type(float, z.dtype)) - - # Wrap scipy function to return the expected dtype. - _scipy_jv = lambda v, z: scipy.special.jv(v, z).astype(z.dtype) - - # Define the expected shape & dtype of output. - result_shape_dtype = jax.ShapeDtypeStruct( - shape=jnp.broadcast_shapes(v.shape, z.shape), - dtype=z.dtype) - - # We use vectorize=True because scipy.special.jv handles broadcasted inputs. - return jax.pure_callback(_scipy_jv, result_shape_dtype, v, z, vectorized=True) -``` - -+++ {"id": "vyjQj-0QVuoN"} - -This lets us call into `scipy.special.jv` from transformed JAX code, including when transformed by `jit` and `vmap`: - -```{code-cell} -:id: f4e46670f4e4 - -from functools import partial -j1 = partial(jv, 1) -z = jnp.arange(5.0) -``` - -```{code-cell} -:id: 6svImqFHWBwj -:outputId: bc8c778a-6c10-443b-9be2-c0f28e2ac1a9 - -print(j1(z)) -``` - -+++ {"id": "d48eb4f2d48e"} - -Here is the same result with `jit`: - -```{code-cell} -:id: txvRqR9DWGdC -:outputId: d25f3476-23b1-48e4-dda1-3c06d32c3b87 - -print(jax.jit(j1)(z)) -``` - -+++ {"id": "d861a472d861"} - -And here is the same result again with `vmap`: - -```{code-cell} -:id: BS-Ve5u_WU0C -:outputId: 08cecd1f-6953-4853-e9db-25a03eb5b000 - -print(jax.vmap(j1)(z)) -``` - -+++ {"id": "SCH2ii_dWXP6"} - -However, if we call `jax.grad`, we see an error because there is no autodiff rule defined for this function: - -```{code-cell} -:id: q3qh_4DrWxdQ -:outputId: c46b0bfa-96f3-4629-b9af-a4d4f3ccb870 -:tags: [raises-exception] - -jax.grad(j1)(z) -``` - -+++ {"id": "PtYeJ_xUW09v"} - -Let's define a custom gradient rule for this. Looking at the definition of the [Bessel Function of the First Kind](https://en.wikipedia.org/?title=Bessel_function_of_the_first_kind), we find that there is a relatively straightforward recurrence relationship for the derivative with respect to the argument `z`: - -$$ -d J_\nu(z) = \left\{ -\begin{eqnarray} --J_1(z),\ &\nu=0\\ -[J_{\nu - 1}(z) - J_{\nu + 1}(z)]/2,\ &\nu\ne 0 -\end{eqnarray}\right. -$$ - -The gradient with respect to $\nu$ is more complicated, but since we've restricted the `v` argument to integer types we don't need to worry about its gradient for the sake of this example. - -We can use `jax.custom_jvp` to define this automatic differentiation rule for our callback function: - -```{code-cell} -:id: BOVQnt05XvLs - -jv = jax.custom_jvp(jv) - -@jv.defjvp -def _jv_jvp(primals, tangents): - v, z = primals - _, z_dot = tangents # Note: v_dot is always 0 because v is integer. 
- jv_minus_1, jv_plus_1 = jv(v - 1, z), jv(v + 1, z) - djv_dz = jnp.where(v == 0, -jv_plus_1, 0.5 * (jv_minus_1 - jv_plus_1)) - return jv(v, z), z_dot * djv_dz -``` - -+++ {"id": "W1SxcvQSX44c"} - -Now computing the gradient of our function will work correctly: - -```{code-cell} -:id: sCGceBs-X8nL -:outputId: 71c5589f-f996-44a0-f09a-ca8bb40c167a - -j1 = partial(jv, 1) -print(jax.grad(j1)(2.0)) -``` - -+++ {"id": "gWQ4phN5YB26"} - -Further, since we've defined our gradient in terms of `jv` itself, JAX's architecture means that we get second-order and higher derivatives for free: - -```{code-cell} -:id: QTe5mRAvYQBh -:outputId: d58ecff3-9419-422a-fd0e-14a7d9cf2cc3 - -jax.hessian(j1)(2.0) -``` - -+++ {"id": "QEXGxU4uYZii"} - -Keep in mind that although this all works correctly with JAX, each call to our callback-based `jv` function will result in passing the input data from the device to the host, and passing the output of `scipy.special.jv` from the host back to the device. -When running on accelerators like GPU or TPU, this data movement and host synchronization can lead to significant overhead each time `jv` is called. -However, if you are running JAX on a single CPU (where the "host" and "device" are on the same hardware), JAX will generally do this data transfer in a fast, zero-copy fashion, making this pattern a relatively straightforward way to extend JAX's capabilities. diff --git a/docs/notebooks/neural_network_with_tfds_data.ipynb b/docs/notebooks/neural_network_with_tfds_data.ipynb index 95c00bf1e689..c31a99746866 100644 --- a/docs/notebooks/neural_network_with_tfds_data.ipynb +++ b/docs/notebooks/neural_network_with_tfds_data.ipynb @@ -36,15 +36,15 @@ "id": "B_XlLLpcWjkA" }, "source": [ - "# Training a Simple Neural Network, with tensorflow/datasets Data Loading\n", + "# Training a simple neural network, with tensorflow/datasets data loading\n", "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)\n", "\n", "_Forked from_ `neural_network_and_data_loading.ipynb`\n", "\n", - "![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n", + "![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png)\n", "\n", "Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. 
We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n", "\n", @@ -132,7 +132,7 @@ " for w, b in params[:-1]:\n", " outputs = jnp.dot(w, activations) + b\n", " activations = relu(outputs)\n", - " \n", + "\n", " final_w, final_b = params[-1]\n", " logits = jnp.dot(final_w, activations) + final_b\n", " return logits - logsumexp(logits)" @@ -251,7 +251,7 @@ "def one_hot(x, k, dtype=jnp.float32):\n", " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - " \n", + "\n", "def accuracy(params, images, targets):\n", " target_class = jnp.argmax(targets, axis=1)\n", " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", @@ -274,7 +274,7 @@ "id": "umJJGZCC2oKl" }, "source": [ - "## Data Loading with `tensorflow/datasets`\n", + "## Data loading with `tensorflow/datasets`\n", "\n", "JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader." ] @@ -344,7 +344,7 @@ "id": "xxPd6Qw3Z98v" }, "source": [ - "## Training Loop" + "## Training loop" ] }, { diff --git a/docs/notebooks/neural_network_with_tfds_data.md b/docs/notebooks/neural_network_with_tfds_data.md index 8f795484d5b9..53b7d47358c2 100644 --- a/docs/notebooks/neural_network_with_tfds_data.md +++ b/docs/notebooks/neural_network_with_tfds_data.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -34,15 +34,15 @@ limitations under the License. +++ {"id": "B_XlLLpcWjkA"} -# Training a Simple Neural Network, with tensorflow/datasets Data Loading +# Training a simple neural network, with tensorflow/datasets data loading -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb) _Forked from_ `neural_network_and_data_loading.ipynb` -![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png) +![JAX](https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png) Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P). 
@@ -104,7 +104,7 @@ def predict(params, image): for w, b in params[:-1]: outputs = jnp.dot(w, activations) + b activations = relu(outputs) - + final_w, final_b = params[-1] logits = jnp.dot(final_w, activations) + final_b return logits - logsumexp(logits) @@ -164,7 +164,7 @@ At this point, we have all the ingredients we need to define our neural network def one_hot(x, k, dtype=jnp.float32): """Create a one-hot encoding of x of size k.""" return jnp.array(x[:, None] == jnp.arange(k), dtype) - + def accuracy(params, images, targets): target_class = jnp.argmax(targets, axis=1) predicted_class = jnp.argmax(batched_predict(params, images), axis=1) @@ -183,7 +183,7 @@ def update(params, x, y): +++ {"id": "umJJGZCC2oKl"} -## Data Loading with `tensorflow/datasets` +## Data loading with `tensorflow/datasets` JAX is laser-focused on program transformations and accelerator-backed NumPy, so we don't include data loading or munging in the JAX library. There are already a lot of great data loaders out there, so let's just use them instead of reinventing anything. We'll use the `tensorflow/datasets` data loader. @@ -229,7 +229,7 @@ print('Test:', test_images.shape, test_labels.shape) +++ {"id": "xxPd6Qw3Z98v"} -## Training Loop +## Training loop ```{code-cell} ipython3 :id: X2DnZo3iYj18 diff --git a/docs/notebooks/shard_map.ipynb b/docs/notebooks/shard_map.ipynb index 919690d230ab..1315783c340c 100644 --- a/docs/notebooks/shard_map.ipynb +++ b/docs/notebooks/shard_map.ipynb @@ -5,10 +5,12 @@ "id": "41a7e222", "metadata": {}, "source": [ - "# SPMD multi-device parallelism with `shard_map`\n", + "# Manual parallelism with `shard_map`\n", "\n", "\n", "\n", + "## Overview\n", + "\n", "`shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations.\n", "\n", "`shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed.\n", @@ -36,7 +38,7 @@ "id": "97c57a94", "metadata": {}, "source": [ - "## So, let's see a `shard_map`!\n", + "### So, let's see a `shard_map`!\n", "\n", "Without further ado, here's a toy example:" ] @@ -189,9 +191,9 @@ "id": "532fe5f6", "metadata": {}, "source": [ - "## Slow down, start with the basics!\n", + "### Slow down, start with the basics!\n", "\n", - "### Rank-reducing vs rank-preserving maps\n", + "#### Rank-reducing vs rank-preserving maps\n", "\n", "We can think of `vmap` and `pmap` as unstacking each array input along an axis\n", "(e.g. 
unpacking a 2D matrix into its 1D rows), applying its body function to\n", @@ -274,7 +276,7 @@ "over 4 devices) then semantically we get 4 logical applications of the\n", "function, corresponding to the 4 devices physically computing them.\n", "\n", - "### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n", + "#### Controlling how each input is split (unconcatenated) and tiled with `in_specs`\n", "\n", "Each of the `in_specs` identifies some of the corresponding input array's axes\n", "with mesh axes by name using `PartitionSpec`s, representing how to split (or\n", @@ -354,7 +356,7 @@ "Physical data movement is possible on inputs, as each device needs to have a\n", "copy of the appropriate data.\n", "\n", - "### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n", + "#### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs`\n", "\n", "Analogously to the input side, each of the `out_specs` identifies some of the\n", "corresponding output array's axes with mesh axes by name, representing how the\n", @@ -482,7 +484,7 @@ "`Array`s, or physically how to interpret the buffers across devices as the\n", "physical layout of a single logical `Array`.\n", "\n", - "# API Specification\n", + "## API Specification\n", "\n", "```python\n", "from jax.sharding import Mesh\n", @@ -508,7 +510,7 @@ "the corresponding `PartitionSpec` `spec` as roughly\n", "`tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`.\n", "\n", - "# Collectives tutorial\n", + "## Collectives tutorial\n", "\n", "A `shard_map` need not be a pure map: function applications can communicate\n", "with each other via _collectives_, using axis names defined in the `mesh`\n", @@ -529,7 +531,7 @@ "\n", "```python\n", "def f_shmapped_ref(x):\n", - " x_blocks = jnp.array_split(x, mesh.shape[0])\n", + " x_blocks = jnp.array_split(x, mesh.shape['i'])\n", " y_blocks = [f(x_blk) for x_blk in x_blocks]\n", " return jnp.concatenate(y_blocks)\n", "```\n", @@ -572,7 +574,7 @@ "means communication across devices. Exactly what communication happens, and\n", "what values are computed, depend on the collective.\n", "\n", - "## `psum`\n", + "### `psum`\n", "\n", "The simplest collective may be `jax.lax.psum`, which computes an\n", "all-reduce-sum along a device mesh axis (or multiple axes).\n", @@ -714,7 +716,7 @@ "In the sequel, we'll see how `psum` can be implemented in terms of other\n", "primitives, which gives some intuition about its communication cost.\n", "\n", - "## `all_gather`\n", + "### `all_gather`\n", "\n", "Another fundamental operation is gathering array shards along an axis, so that\n", "each function application has a full copy of the data along that axis:\n", @@ -796,7 +798,7 @@ "In deep learning, we might use `all_gather`s on parameters in fully sharded\n", "data parallelism (FSDP).\n", "\n", - "## `psum_scatter`\n", + "### `psum_scatter`\n", "\n", "The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like\n", "`psum` except each function instance gets only one shard of the result:\n", @@ -871,7 +873,7 @@ "multiplies or fully-sharded data parallel gradient accumulation, as shown in\n", "the examples to follow.\n", "\n", - "## `ppermute`\n", + "### `ppermute`\n", "\n", "The `jax.lax.ppermute` collective provides the most direct way for\n", "function instances to send data to one another. 
Given a mesh axis and a\n", @@ -998,7 +1000,7 @@ "spatial axes and thus devices must communicate \"halos\" to each other. Or it\n", "may be used under-the-hood in tensor-parallel matrix multiplies.\n", "\n", - "## `all_to_all`\n", + "### `all_to_all`\n", "\n", "A final collective is `all_to_all`, which is essentially a block matrix\n", "transpose operating along one positional axis and one cross-device axis:\n", @@ -1059,12 +1061,12 @@ "where we first sort our local batch of examples according to which expert they\n", "should go to, then apply an `all_to_all` to redistribute examples to experts.\n", "\n", - "# Toy examples\n", + "## Toy examples\n", "\n", "How might we use `shard_map` and collective communication in practice? These\n", "examples, while simple, give some idea.\n", "\n", - "## Matrix multiplies\n", + "### Matrix multiplies\n", "\n", "Parallelizing matrix multiplication is central in scaling up deep learning\n", "models, both for training and for inference. When `jax.jit` automatically\n", @@ -1107,7 +1109,7 @@ "id": "2e2b33b9", "metadata": {}, "source": [ - "### Example 1: `all-gather` on one side\n", + "#### Example 1: `all-gather` on one side\n", "\n", "Consider performing a matrix multiplication where we shard the left-hand side\n", "argument (can think: parameters) on its leading (non-contracting) dimension:" @@ -1301,7 +1303,7 @@ "`jax.lax.fori_loop`. We might also have additional axes of parallelism\n", "involved.\n", "\n", - "### Example 2: `psum_scatter` the result\n", + "#### Example 2: `psum_scatter` the result\n", "\n", "Another sharding we might start with has both `lhs` and `rhs` sharded along\n", "their contracting dimensions, with the output sharded like `rhs` again:" @@ -1446,7 +1448,7 @@ "id": "60c2d2bc", "metadata": {}, "source": [ - "## Neural networks\n", + "### Neural networks\n", "\n", "We can use `shard_map` to parallelize computation in neural networks, either by\n", "itself or in combination with the automatic partitioning in `jax.jit`. 
This\n", @@ -1483,20 +1485,20 @@ "outputs": [], "source": [ "def init_layer(key, n_in, n_out):\n", - " k1, k2 = jax.random.split(key)\n", - " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", - " b = jax.random.normal(k2, (n_out,))\n", - " return W, b\n", + " k1, k2 = jax.random.split(key)\n", + " W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in)\n", + " b = jax.random.normal(k2, (n_out,))\n", + " return W, b\n", "\n", "def init(key, layer_sizes, batch_size):\n", - " key, *keys = jax.random.split(key, len(layer_sizes))\n", - " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", + " key, *keys = jax.random.split(key, len(layer_sizes))\n", + " params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:]))\n", "\n", - " key, *keys = jax.random.split(key, 3)\n", - " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", - " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", + " key, *keys = jax.random.split(key, 3)\n", + " inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0]))\n", + " targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1]))\n", "\n", - " return params, (inputs, targets)" + " return params, (inputs, targets)" ] }, { @@ -1509,7 +1511,7 @@ "layer_sizes = [784, 128, 128, 128, 128, 128, 8]\n", "batch_size = 32\n", "\n", - "params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size)" + "params, batch = init(jax.random.key(0), layer_sizes, batch_size)" ] }, { @@ -1524,7 +1526,7 @@ "functions to use different parallelization strategies, with `shard_map` we\n", "often do.\n", "\n", - "### 8-way batch data parallelism\n", + "#### 8-way batch data parallelism\n", "\n", "The simplest multi-device parallelism strategy is to shard the batch of inputs\n", "and targets over multiple devices, replicate the parameters over those devices,\n", @@ -1608,7 +1610,7 @@ "end of the forward pass to compute the loss value, and in the backward pass to\n", "compute the total parameter gradients.\n", "\n", - "### 8-way fully sharded data parallelism (FSDP)\n", + "#### 8-way fully sharded data parallelism (FSDP)\n", "\n", "Another strategy is to additionally shard the parameters over the devices,\n", "all-gathering each one when the full value is needed for the `jnp.dot` or bias\n", @@ -1697,7 +1699,7 @@ "id": "f88ddefe", "metadata": {}, "source": [ - "### 8-way tensor parallelism (TP)\n", + "#### 8-way tensor parallelism (TP)\n", "\n", "Usually we don't use tensor model parallelism by itself, but seeing it in\n", "isolation is a good warmup on parallel matrix multiplication. It's also a good\n", @@ -1750,7 +1752,7 @@ "id": "cf59d537", "metadata": {}, "source": [ - "### FSDP + TP, with `shard_map` at the top level\n", + "#### FSDP + TP, with `shard_map` at the top level\n", "\n", "We can compose these strategies together, using multiple axes of parallelism." ] @@ -1821,7 +1823,7 @@ "id": "94a352ca", "metadata": {}, "source": [ - "### SPMD pipeline parallelism (PP)\n", + "#### SPMD pipeline parallelism (PP)\n", "\n", "With pipeline parallelism we aim to parallelize the evaluation of layers at\n", "different depths in our network. 
For example, one device might compute the\n", diff --git a/docs/notebooks/shard_map.md b/docs/notebooks/shard_map.md index 6f9dfbb659e1..96667e709ac6 100644 --- a/docs/notebooks/shard_map.md +++ b/docs/notebooks/shard_map.md @@ -7,17 +7,19 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python name: python3 --- -# SPMD multi-device parallelism with `shard_map` +# Manual parallelism with `shard_map` +## Overview + `shard_map` is a single-program multiple-data (SPMD) multi-device parallelism API to map a function over shards of data. Mapped function applications, or _instances_, communicate with each other via explicit collective communication operations. `shard_map` is complementary to, and composable with, the automatic compiler-based parallelization built into `jit`. With `jit` you write code as if for a single device, and [the compiler can automatically partition computation over multiple devices](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html), generating per-device code and communication collectives behind the scenes. With `shard_map` you take control, writing your own partitioned code and explicit collectives. Or you can do a bit of both: take manual control across groups of devices while leaving within-group device partitioning up to the compiler. The two approaches can be mixed, matched, and composed as needed. @@ -33,7 +35,7 @@ import os os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8' # Use 8 CPU devices ``` -## So, let's see a `shard_map`! +### So, let's see a `shard_map`! Without further ado, here's a toy example: @@ -120,9 +122,9 @@ print('b blocks:'); jax.debug.visualize_array_sharding(b) print('c blocks:'); jax.debug.visualize_array_sharding(c) ``` -## Slow down, start with the basics! +### Slow down, start with the basics! -### Rank-reducing vs rank-preserving maps +#### Rank-reducing vs rank-preserving maps We can think of `vmap` and `pmap` as unstacking each array input along an axis (e.g. unpacking a 2D matrix into its 1D rows), applying its body function to @@ -181,7 +183,7 @@ by any input axis size: for example, if we have a mesh of total size 4 (i.e. over 4 devices) then semantically we get 4 logical applications of the function, corresponding to the 4 devices physically computing them. -### Controlling how each input is split (unconcatenated) and tiled with `in_specs` +#### Controlling how each input is split (unconcatenated) and tiled with `in_specs` Each of the `in_specs` identifies some of the corresponding input array's axes with mesh axes by name using `PartitionSpec`s, representing how to split (or @@ -237,7 +239,7 @@ along the first axis, and used the pspec `P(('j', 'i'), None)`. Physical data movement is possible on inputs, as each device needs to have a copy of the appropriate data. -### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs` +#### Controlling how each output assembled by concatenation, block transposition, and untiling using `out_specs` Analogously to the input side, each of the `out_specs` identifies some of the corresponding output array's axes with mesh axes by name, representing how the @@ -329,7 +331,7 @@ Instead, `out_specs` just encodes how to assemble the block outputs into `Array`s, or physically how to interpret the buffers across devices as the physical layout of a single logical `Array`. 
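To make the role of `out_specs` concrete, here is a small sketch (added for illustration, not part of the original tutorial) that returns each per-device block unchanged under two different `out_specs`, assuming the 8-CPU-device setup configured at the top of this tutorial and using four of those devices:

```python
from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()[:4]), ('i',))
x = jnp.arange(16.).reshape(4, 4)   # each device holds a (1, 4) block

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P('i', None))
def reassemble_rows(block):
  return block   # blocks concatenated back along axis 0

@partial(shard_map, mesh=mesh, in_specs=P('i', None), out_specs=P(None, 'i'))
def reassemble_cols(block):
  return block   # same blocks, concatenated along axis 1 instead

print(reassemble_rows(x).shape)  # (4, 4)
print(reassemble_cols(x).shape)  # (1, 16)
```

Both functions apply the identity to identical `(1, 4)` blocks; only the rule for assembling those blocks back into a logical `Array` differs.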
-# API Specification +## API Specification ```python from jax.sharding import Mesh @@ -355,7 +357,7 @@ from the shape `shape` of the corresponding argument to `shard_map`-of-`f` and the corresponding `PartitionSpec` `spec` as roughly `tuple(sz // (1 if n is None else mesh.shape[n]) for sz, n in zip(shape, spec))`. -# Collectives tutorial +## Collectives tutorial A `shard_map` need not be a pure map: function applications can communicate with each other via _collectives_, using axis names defined in the `mesh` @@ -376,7 +378,7 @@ values, as this reference function: ```python def f_shmapped_ref(x): - x_blocks = jnp.array_split(x, mesh.shape[0]) + x_blocks = jnp.array_split(x, mesh.shape['i']) y_blocks = [f(x_blk) for x_blk in x_blocks] return jnp.concatenate(y_blocks) ``` @@ -419,7 +421,7 @@ collective introduces some amount of cross-block dependence. Physically, that means communication across devices. Exactly what communication happens, and what values are computed, depend on the collective. -## `psum` +### `psum` The simplest collective may be `jax.lax.psum`, which computes an all-reduce-sum along a device mesh axis (or multiple axes). @@ -513,7 +515,7 @@ have a `grad` inside the `shard_map`ped function body, total gradients. In the sequel, we'll see how `psum` can be implemented in terms of other primitives, which gives some intuition about its communication cost. -## `all_gather` +### `all_gather` Another fundamental operation is gathering array shards along an axis, so that each function application has a full copy of the data along that axis: @@ -571,7 +573,7 @@ def all_gather_ref(_, x_blocks, *, tiled=False): In deep learning, we might use `all_gather`s on parameters in fully sharded data parallelism (FSDP). -## `psum_scatter` +### `psum_scatter` The `jax.lax.psum_scatter` collective is a bit less intuitive. It's like `psum` except each function instance gets only one shard of the result: @@ -634,7 +636,7 @@ machine learning, `psum_scatter` can be used in tensor-parallel matrix multiplies or fully-sharded data parallel gradient accumulation, as shown in the examples to follow. -## `ppermute` +### `ppermute` The `jax.lax.ppermute` collective provides the most direct way for function instances to send data to one another. Given a mesh axis and a @@ -731,7 +733,7 @@ parallelizing the evaluation of convolutional layers, where we shard over spatial axes and thus devices must communicate "halos" to each other. Or it may be used under-the-hood in tensor-parallel matrix multiplies. -## `all_to_all` +### `all_to_all` A final collective is `all_to_all`, which is essentially a block matrix transpose operating along one positional axis and one cross-device axis: @@ -780,12 +782,12 @@ In deep learning, we might use `all_to_all` in mixture-of-expert routing, where we first sort our local batch of examples according to which expert they should go to, then apply an `all_to_all` to redistribute examples to experts. -# Toy examples +## Toy examples How might we use `shard_map` and collective communication in practice? These examples, while simple, give some idea. -## Matrix multiplies +### Matrix multiplies Parallelizing matrix multiplication is central in scaling up deep learning models, both for training and for inference. 
When `jax.jit` automatically @@ -810,7 +812,7 @@ def device_put(x, pspec): return jax.device_put(x, NamedSharding(mesh, pspec)) ``` -### Example 1: `all-gather` on one side +#### Example 1: `all-gather` on one side Consider performing a matrix multiplication where we shard the left-hand side argument (can think: parameters) on its leading (non-contracting) dimension: @@ -926,7 +928,7 @@ In practice, to reduce compile times we would probably roll this into a `jax.lax.fori_loop`. We might also have additional axes of parallelism involved. -### Example 2: `psum_scatter` the result +#### Example 2: `psum_scatter` the result Another sharding we might start with has both `lhs` and `rhs` sharded along their contracting dimensions, with the output sharded like `rhs` again: @@ -1011,7 +1013,7 @@ out = matmul_psumscatter_overlapped_bidi(lhs, rhs) print(jnp.allclose(out, lhs @ rhs, atol=1e-3, rtol=1e-3)) ``` -## Neural networks +### Neural networks We can use `shard_map` to parallelize computation in neural networks, either by itself or in combination with the automatic partitioning in `jax.jit`. This @@ -1035,27 +1037,27 @@ def loss(params, batch): ```{code-cell} def init_layer(key, n_in, n_out): - k1, k2 = jax.random.split(key) - W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) - b = jax.random.normal(k2, (n_out,)) - return W, b + k1, k2 = jax.random.split(key) + W = jax.random.normal(k1, (n_in, n_out)) / jnp.sqrt(n_in) + b = jax.random.normal(k2, (n_out,)) + return W, b def init(key, layer_sizes, batch_size): - key, *keys = jax.random.split(key, len(layer_sizes)) - params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) + key, *keys = jax.random.split(key, len(layer_sizes)) + params = list(map(init_layer, keys, layer_sizes[:-1], layer_sizes[1:])) - key, *keys = jax.random.split(key, 3) - inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) - targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) + key, *keys = jax.random.split(key, 3) + inputs = jax.random.normal(keys[0], (batch_size, layer_sizes[0])) + targets = jax.random.normal(keys[1], (batch_size, layer_sizes[-1])) - return params, (inputs, targets) + return params, (inputs, targets) ``` ```{code-cell} layer_sizes = [784, 128, 128, 128, 128, 128, 8] batch_size = 32 -params, batch = init(jax.random.PRNGKey(0), layer_sizes, batch_size) +params, batch = init(jax.random.key(0), layer_sizes, batch_size) ``` Compare these examples with the purely [automatic partitioning examples in the @@ -1065,7 +1067,7 @@ While in those automatic partitioning examples we don't need to edit the model functions to use different parallelization strategies, with `shard_map` we often do. -### 8-way batch data parallelism +#### 8-way batch data parallelism The simplest multi-device parallelism strategy is to shard the batch of inputs and targets over multiple devices, replicate the parameters over those devices, @@ -1119,7 +1121,7 @@ that the collective all-reduce-sum operations happen where we'd expect: at the end of the forward pass to compute the loss value, and in the backward pass to compute the total parameter gradients. 
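For reference, a data-parallel loss along these lines might look like the following rough sketch (the name `loss_dp_sketch` and this mesh construction are ours rather than the notebook's exact code; it reuses the `loss` function defined above and assumes the 8-device setup):

```python
from functools import partial
import numpy as np
import jax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ('batch',))

@partial(shard_map, mesh=mesh, in_specs=(P(), P('batch', None)), out_specs=P())
def loss_dp_sketch(params, batch):
  # Params are replicated; each instance sees a 1/8 shard of (inputs, targets).
  # Averaging the per-shard mean losses across 'batch' places the all-reduce
  # where described above: at the end of the forward pass, and again in its
  # transpose during the backward pass.
  return jax.lax.pmean(loss(params, batch), 'batch')
```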
-### 8-way fully sharded data parallelism (FSDP) +#### 8-way fully sharded data parallelism (FSDP) Another strategy is to additionally shard the parameters over the devices, all-gathering each one when the full value is needed for the `jnp.dot` or bias @@ -1184,7 +1186,7 @@ print(allclose(jax.jit(jax.grad(loss))(params, batch), jax.jit(jax.grad(loss_fsdp))(params, batch))) ``` -### 8-way tensor parallelism (TP) +#### 8-way tensor parallelism (TP) Usually we don't use tensor model parallelism by itself, but seeing it in isolation is a good warmup on parallel matrix multiplication. It's also a good @@ -1225,7 +1227,7 @@ def loss_tp(params, batch): return jnp.mean(jnp.sum((predictions - targets) ** 2, axis=-1)) # NOTE psum! ``` -### FSDP + TP, with `shard_map` at the top level +#### FSDP + TP, with `shard_map` at the top level We can compose these strategies together, using multiple axes of parallelism. @@ -1272,7 +1274,7 @@ print(allclose(jax.jit(jax.grad(loss))(params, batch), jax.jit(jax.grad(loss_fsdp_tp))(params, batch))) ``` -### SPMD pipeline parallelism (PP) +#### SPMD pipeline parallelism (PP) With pipeline parallelism we aim to parallelize the evaluation of layers at different depths in our network. For example, one device might compute the diff --git a/docs/notebooks/thinking_in_jax.ipynb b/docs/notebooks/thinking_in_jax.ipynb index 1c1c9729b654..b5f8074c0f3e 100644 --- a/docs/notebooks/thinking_in_jax.ipynb +++ b/docs/notebooks/thinking_in_jax.ipynb @@ -6,11 +6,11 @@ "id": "LQHmwePqryRU" }, "source": [ - "# How to Think in JAX\n", + "# How to think in JAX\n", "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb)\n", "\n", "JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively." ] @@ -23,7 +23,7 @@ "source": [ "## JAX vs. 
NumPy\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- JAX provides a NumPy-inspired interface for convenience.\n", "- Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays.\n", @@ -282,7 +282,7 @@ "source": [ "## NumPy, lax & XLA: JAX API layering\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- `jax.numpy` is a high-level wrapper that provides a familiar interface.\n", "- `jax.lax` is a lower-level API that is stricter and often more powerful.\n", @@ -475,7 +475,7 @@ "source": [ "## To JIT or not to JIT\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- By default JAX executes operations one at a time, in sequence.\n", "- Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.\n", @@ -675,7 +675,7 @@ "source": [ "## JIT mechanics: tracing and static variables\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type.\n", "\n", @@ -932,9 +932,9 @@ "id": "r-RCl_wD5lI7" }, "source": [ - "## Static vs Traced Operations\n", + "## Static vs traced operations\n", "\n", - "**Key Concepts:**\n", + "**Key concepts:**\n", "\n", "- Just as values can be either static or traced, operations can be static or traced.\n", "\n", diff --git a/docs/notebooks/thinking_in_jax.md b/docs/notebooks/thinking_in_jax.md index 14089fa36e32..b3672b90e653 100644 --- a/docs/notebooks/thinking_in_jax.md +++ b/docs/notebooks/thinking_in_jax.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 @@ -13,11 +13,11 @@ kernelspec: +++ {"id": "LQHmwePqryRU"} -# How to Think in JAX +# How to think in JAX -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/thinking_in_jax.ipynb) JAX provides a simple and powerful API for writing accelerated numerical code, but working effectively in JAX sometimes requires extra consideration. This document is meant to help build a ground-up understanding of how JAX operates, so that you can use it more effectively. @@ -25,7 +25,7 @@ JAX provides a simple and powerful API for writing accelerated numerical code, b ## JAX vs. NumPy -**Key Concepts:** +**Key concepts:** - JAX provides a NumPy-inspired interface for convenience. - Through duck-typing, JAX arrays can often be used as drop-in replacements of NumPy arrays. @@ -132,7 +132,7 @@ print(y) ## NumPy, lax & XLA: JAX API layering -**Key Concepts:** +**Key concepts:** - `jax.numpy` is a high-level wrapper that provides a familiar interface. - `jax.lax` is a lower-level API that is stricter and often more powerful. 
@@ -215,7 +215,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op ## To JIT or not to JIT -**Key Concepts:** +**Key concepts:** - By default JAX executes operations one at a time, in sequence. - Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once. @@ -308,7 +308,7 @@ This is because the function generates an array whose shape is not known at comp ## JIT mechanics: tracing and static variables -**Key Concepts:** +**Key concepts:** - JIT and other JAX transforms work by *tracing* a function to determine its effect on inputs of a specific shape and type. @@ -417,9 +417,9 @@ Understanding which values and operations will be static and which will be trace +++ {"id": "r-RCl_wD5lI7"} -## Static vs Traced Operations +## Static vs traced operations -**Key Concepts:** +**Key concepts:** - Just as values can be either static or traced, operations can be static or traced. diff --git a/docs/notebooks/vmapped_log_probs.ipynb b/docs/notebooks/vmapped_log_probs.ipynb index 96b334296667..dccc83168ac0 100644 --- a/docs/notebooks/vmapped_log_probs.ipynb +++ b/docs/notebooks/vmapped_log_probs.ipynb @@ -6,11 +6,11 @@ "id": "6umP1IKf4Dg6" }, "source": [ - "# Autobatching for Bayesian Inference\n", + "# Autobatching for Bayesian inference\n", "\n", "\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb)\n", "\n", "This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs.\n", "\n", @@ -25,17 +25,10 @@ }, "outputs": [], "source": [ - "import functools\n", - "import itertools\n", - "import re\n", - "import sys\n", - "import time\n", - "\n", - "from matplotlib.pyplot import *\n", + "import matplotlib.pyplot as plt\n", "\n", "import jax\n", "\n", - "from jax import lax\n", "import jax.numpy as jnp\n", "import jax.scipy as jsp\n", "from jax import random\n", @@ -348,7 +341,7 @@ "def elbo(beta_loc, beta_log_scale, epsilon):\n", " beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon\n", " return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi))\n", - " \n", + "\n", "elbo = jax.jit(elbo)\n", "elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1)))" ] @@ -548,25 +541,16 @@ } ], "source": [ - "figure(figsize=(7, 7))\n", - "plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", - "plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\\sigma$ Error Bars')\n", - "plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n", + "plt.figure(figsize=(7, 7))\n", + "plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means')\n", + "plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', 
label=r'Approximated Posterior $2\\sigma$ Error Bars')\n", + "plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.')\n", "plot_scale = 3\n", - "plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", - "xlabel('True beta')\n", - "ylabel('Estimated beta')\n", - "legend(loc='best')" + "plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k')\n", + "plt.xlabel('True beta')\n", + "plt.ylabel('Estimated beta')\n", + "plt.legend(loc='best')" ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "id": "_bXdOlvUEJl0" - }, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/docs/notebooks/vmapped_log_probs.md b/docs/notebooks/vmapped_log_probs.md index ea8b4fce2f70..3f836e680e88 100644 --- a/docs/notebooks/vmapped_log_probs.md +++ b/docs/notebooks/vmapped_log_probs.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -14,11 +14,11 @@ kernelspec: +++ {"id": "6umP1IKf4Dg6"} -# Autobatching for Bayesian Inference +# Autobatching for Bayesian inference -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax/blob/main/docs/notebooks/vmapped_log_probs.ipynb) This notebook demonstrates a simple Bayesian inference example where autobatching makes user code easier to write, easier to read, and less likely to include bugs. @@ -27,17 +27,10 @@ Inspired by a notebook by @davmre. 
```{code-cell} ipython3 :id: 8RZDkfbV3zdR -import functools -import itertools -import re -import sys -import time - -from matplotlib.pyplot import * +import matplotlib.pyplot as plt import jax -from jax import lax import jax.numpy as jnp import jax.scipy as jsp from jax import random @@ -192,7 +185,7 @@ batched_log_joint = jax.jit(jax.vmap(log_joint)) def elbo(beta_loc, beta_log_scale, epsilon): beta_sample = beta_loc + jnp.exp(beta_log_scale) * epsilon return jnp.mean(batched_log_joint(beta_sample), 0) + jnp.sum(beta_log_scale - 0.5 * np.log(2*np.pi)) - + elbo = jax.jit(elbo) elbo_val_and_grad = jax.jit(jax.value_and_grad(elbo, argnums=(0, 1))) ``` @@ -240,19 +233,13 @@ Coverage isn't quite as good as we might like, but it's not bad, and nobody said :id: zt1NBLoVHtOG :outputId: fb159795-e6e7-497c-e501-9933ec761af4 -figure(figsize=(7, 7)) -plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') -plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label='Approximated Posterior $2\sigma$ Error Bars') -plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') +plt.figure(figsize=(7, 7)) +plt.plot(true_beta, beta_loc, '.', label='Approximated Posterior Means') +plt.plot(true_beta, beta_loc + 2*jnp.exp(beta_log_scale), 'r.', label=r'Approximated Posterior $2\sigma$ Error Bars') +plt.plot(true_beta, beta_loc - 2*jnp.exp(beta_log_scale), 'r.') plot_scale = 3 -plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') -xlabel('True beta') -ylabel('Estimated beta') -legend(loc='best') -``` - -```{code-cell} ipython3 -:id: _bXdOlvUEJl0 - - +plt.plot([-plot_scale, plot_scale], [-plot_scale, plot_scale], 'k') +plt.xlabel('True beta') +plt.ylabel('Estimated beta') +plt.legend(loc='best') ``` diff --git a/docs/pallas/CHANGELOG.md b/docs/pallas/CHANGELOG.md index e43a178db50e..43ba3ebd6afb 100644 --- a/docs/pallas/CHANGELOG.md +++ b/docs/pallas/CHANGELOG.md @@ -11,15 +11,31 @@ For the overall JAX change log see [here](https://jax.readthedocs.io/en/latest/c Remember to align the itemized text with the first line of an item within a list. --> -## Released with jax 0.4.32 +## Released with jax 0.4.34 * Changes - * The kernel function is not allowed to close over constants. Instead, all the needed arrays - must be passed as inputs, with proper block specs ({jax-issue}`#22746`). + + * {func}`jax.experimental.pallas.debug_print` no longer requires all arguments + to be scalars. The restrictions on the arguments are backend-specific: + Non-scalar arguments are currently only supported on GPU, when using Triton. * Deprecations -* New functionality: +* New functionality + + * {func}`jax.experimental.pallas.pallas_call` now accepts `scratch_shapes`, + a PyTree specifying backend-specific temporary objects needed by the + kernel, for example, buffers, synchronization primitives etc. + +## Released with jax 0.4.33 (September 16, 2024) + +## Released with jax 0.4.32 (September 11, 2024) + +* Changes + * The kernel function is not allowed to close over constants. Instead, all the needed arrays + must be passed as inputs, with proper block specs ({jax-issue}`#22746`). + +* New functionality * Improved error messages for mistakes in the signature of the index map functions, to include the name and source location of the index map. @@ -36,18 +52,14 @@ Remember to align the itemized text with the first line of an item within a list * The method `compute_index` of {class}`jax.experimental.pallas.GridSpec` has been removed because it is private. 
Similarly, the `get_grid_mapping` and `unzip_dynamic_bounds` have been removed from `BlockSpec` ({jax-issue}`#22593`). - * Fixed the interpreter mode to work with BlockSpec that involve padding + * Fixed the interpret mode to work with BlockSpec that involve padding ({jax-issue}`#22275`). - Padding in interpreter mode will be with NaN, to help debug out-of-bounds + Padding in interpret mode will be with NaN, to help debug out-of-bounds errors, but this behavior is not present when running in custom kernel mode, and should not be depended on. * Previously it was possible to import many APIs that are meant to be private, as `jax.experimental.pallas.pallas`. This is not possible anymore. - -* Deprecations - - * New Functionality * Added documentation for BlockSpec: {ref}`pallas_grids_and_blockspecs`. * Improved error messages for the {func}`jax.experimental.pallas.pallas_call` @@ -73,7 +85,3 @@ Remember to align the itemized text with the first line of an item within a list * Added checkify support for {func}`jax.experimental.pallas.pallas_call` in interpret mode ({jax-issue}`#21862`). * Improved support for PRNG keys for TPU kernels ({jax-issue}`#21773`). - - - - diff --git a/docs/pallas/async_note.md b/docs/pallas/async_note.md new file mode 100644 index 000000000000..42e32a074fd7 --- /dev/null +++ b/docs/pallas/async_note.md @@ -0,0 +1,675 @@ +# Pallas Async Operations + +## Background \+ Motivation + +We’d like to expose APIs in Pallas to explicitly overlap computation and communication *across multiple kernels*. + +### XLA Async Decomposition + +As motivation, consider the following JAX pseudocode: + +```py +def f(x): + y = ppermute(x) + z = x + 1 + return y, z +``` + +In this function, we could perform the `ppermute` at the same time as the `x + 1`. This is an optimization XLA does automatically by: + +1. decomposing `ppermute` into a `ppermute_start` and `ppermute_done` op, which are connected via a future. +2. scheduling the `x + 1` between the `ppermute_start` and `ppermute_done`, + +resulting in the following program: + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 # happens at the same time as ppermute + y = ppermute_done(fut) + return y, z +``` + +### Async ops inside kernels + +Now imagine we aren’t using XLA’s `ppermute` but have our own custom Pallas `ppermute`. + +```py +def ppermute_kernel(x_ref, y_ref, send_sem, recv_sem): + right_neighbor = ... + descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor.start() + descriptor.wait_send() + descriptor.wait_recv() + +def ppermute(x): + return pl.pallas_call(ppermute_kernel, out_shape=x, ...)(x) +``` + +Currently, we cannot decompose `ppermute` into a `start/done` pair as XLA does, so instead we explicitly **fuse** the `x + 1` into the kernel. + +```py +def add_one(x_ref, z_ref): + z_ref[...] = x_ref[...] + 1 + +def ppermute_add_one_kernel(x_ref, y_ref, z_ref, send_sem, recv_sem): + right_neighbor = ... 
+ descriptor = pltpu.make_async_remote_copy(x_ref, y_ref, send_sem, recv_sem, device_id=right_neighbor) + descriptor.start() + + # Explicitly schedule inner kernel between start/wait + pltpu.emit_pipeline(add_one)(x_ref, z_ref) + + descriptor.wait_send() + descriptor.wait_recv() + +def ppermute_and_add_one(x): + return pl.pallas_call(ppermute_add_one_kernel, out_shape=(x, x), ...)(x) + +``` + +The goal is to enable writing separate kernels for starting the `ppermute` and waiting on it to complete, so that we can use a regular old `x + 1` in between (or whatever compute we want). This makes the code more readable, maintainable, and less bug-prone. + +## How do we implement decomposed Pallas async operations (on TPU)? + +The main thing to figure out when implementing decomposed async operations in Pallas is what the `future` that is passed between them contains. Specifically, it must contain some important state about the operation happening in the background. + +If we look at the Pallas code, we can see that we need a “descriptor” to both start and wait on a remote copy. Can we plumb this descriptor out of the Pallas kernel, and then pass it into another one? Well kinda. The underlying TPU hardware tracks async op progress via a pair of semaphores: `send_sem` enables us to wait on when a device is done sending data to its neighbor and `recv_sem` tracks the data transfer sent to a device from their neighbor. If we imagine writing a start kernel and a done kernel, all we’d need to pass from the start to the done would be the semaphores and some information about how much to wait on those semaphores. + +We can do this via extending Pallas to support returning semaphores from kernels. + +```py +def ppermute_start_kernel( + in_ref, send_sem, recv_sem, out_ref, *, axis_name, +): + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor + ).start() + +def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array]: + send_sem, recv_sem, out = pl.pallas_call( + functools.partial(ppermute_start_kernel, axis_name=axis_name), + out_shape=( + pltpu.SemaphoreType.DMA(()), + pltpu.SemaphoreType.DMA(()), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.ANY), + ), + )(x) + return send_sem, recv_sem, out +``` + +Note that something subtle is happening here. Pallas is telling XLA that it would like some outputs to be semaphores (a.k.a. sync flags) and XLA will treat them as “reserved” (e.g. while they are alive in the XLA program, those sync flags cannot be allocated by other kernels). They behave similarly to barrier semaphores, which are reserved semaphores managed by XLA. + +Another thing to notice is that we return the output buffer `out` from the start kernel *while it’s being actively copied into*. + +Now we write the `done` kernel that performs the blocking operation. We pass `out` into the kernel to compute the shape needed to block on the semaphore. 
+ +```py +def ppermute_done_kernel(ref, send_sem, recv_sem, _): + pltpu.make_async_copy(ref, ref, send_sem).wait() + pltpu.make_async_copy(ref, ref, recv_sem).wait() + +def ppermute_done(send_sem, recv_sem, out) ->Array: + out = pl.pallas_call( + ppermute_done_kernel, + out_shape=( + jax.ShapeDtypeStruct( + out.shape, + dtype=out.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0:0} + )(out, send_sem, recv_sem) + return out +``` + +Note: we i/o alias the output buffer here to guarantee that the consumers are downstream of the `ppermute_done`. + +We now can implement the decomposed collective permute. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 # happens at the same time as ppermute + y = ppermute_done(fut) + return y, z +``` + +***OR CAN WE?*** + +## Why *doesn’t* this work? + +There are three remaining issues with this, each of which exists outside of Pallas to some degree. Here they are at a high level. + +1. Scheduling \- just because we write `ppermute_start`, then `x + 1`, then `ppermute_done` doesn’t guarantee that they will happen in that order. XLA is responsible for scheduling, so when we write JAX programs, we are setting up data dependencies that XLA will respect but XLA will not respect the specific order of operations written in JAX. +2. Lifetimes \- XLA assumes that once a value is out of scope in the dependency graph, its memory can be freed for use by other values. If we have an op that asynchronously copies x \-\> y, we need to ensure that x is alive until the copy is complete, otherwise we will be copying from garbage memory. +3. Defensive copies \- XLA reserves the right to create copies of values. We need to make sure we don’t introduce unnecessary copies to a) avoid unnecessary runtime overhead and b) ensure correctness. + +We will go over these issues one by one and suggest fixes. + +### Scheduling + +How do we explicitly force ops to happen in a particular order in JAX? Note that this is not a Pallas specific problem, and if we had async ops implemented using an alternative method, we’d still run into this. + +One way is to introduce an optimization barrier into the XLA program. The optimization barrier will prevent XLA moving ops around it. + +Here’s our original code: + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +XLA could choose to execute `x + 1` in any of three places: + +```py +def f(x): + z = x + 1 + fut = ppermute_start(x) + y = ppermute_done(fut) + return y, z + +# OR + +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z + +# OR + +def f(x): + fut = ppermute_start(x) + y = ppermute_done(fut) + z = x + 1 + return y, z +``` + +To force the `x + 1` to happen between the `ppermute` ops, we can use `optimization_barrier`, which is semantically the identity function (i.e. `lambda x: x`) but introduces an explicit data dependency between values. Specifically, if we make the `x` that is used in `x + 1` dependent on the `fut` returned by `ppermute_start`, it must happen after `ppermute_start`. + +We also introduce a dependency that forces the output value `y` to depend on `z`. 
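+
+As a minimal standalone sketch of the primitive itself (in JAX it is exposed as
+`jax.lax.optimization_barrier`; the function below is purely illustrative), the
+barrier returns its operands unchanged but makes every output depend on every
+input:
+
+```py
+import jax
+import jax.numpy as jnp
+
+def g(x):
+  a = jnp.sin(x)
+  b = jnp.cos(x)
+  # Identity on (a, b), but consumers of `a` now also depend on `b`, so the
+  # compiler cannot schedule them before `b` has been computed.
+  a, b = jax.lax.optimization_barrier((a, b))
+  return a + b
+
+print(jax.jit(g)(1.0))
+```
+
+Applied to our example, we thread `fut` through the barrier so that the `x + 1`
+is pinned between the two halves of the `ppermute`: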
+ +```py +def f(x): + fut = ppermute_start(x) + x, fut = optimization_barrier((x, fut)) # x now depends on fut + z = x + 1 + z, fut = optimization_barrier((z, fut)) # fut now depends on z + y = ppermute_done(fut) + return y, z +``` + +`optimization_barrier` is a good enough hammer for us to explicitly write out schedules. + +### Lifetimes + +Let’s look at our original code again and assume the ops are happening in the correct order. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + y = ppermute_done(fut) + return y, z +``` + +Let’s look at which point in the program XLA believes it is okay to free the buffer for `x`. It would be the point after which `x` is no longer used, specifically after `z = x + 1`. + +```py +def f(x): + fut = ppermute_start(x) + z = x + 1 + # XLA can free x here! + y = ppermute_done(fut) + return y, z +``` + +If XLA frees `x` after `z = x + 1` has completed, we run into a very bad problem. The `ppermute` could still be actively copying `x` to the neighbor after `z = x + 1` which means if `x` is freed, the `ppermute` will be reading from garbage memory\! + +How do we extend `x`’s lifetime to the `ppermute_done`? Well we can introduce a data dependency\! We need to modify our kernels a little bit to make this happen. + +First, we rewrite `ppermute_start` to return `x`, aliasing it through the kernel. + +```py +def ppermute_start_kernel( + in_ref, send_sem, recv_sem, out_ref, _, *, axis_name, +): + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem(jax.lax.axis_index(axis_name) + 1, axis_size) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + in_ref, out_ref, send_sem, recv_sem, device_id=right_neighbor + ).start() + +def ppermute_start(x, *, axis_name) -> tuple[Semaphore, Semaphore, Array, Array]: + send_sem, recv_sem, x, out = pl.pallas_call( + functools.partial(ppermute_start_kernel, axis_name=axis_name), + out_shape=( + pltpu.SemaphoreType.DMA(()), + pltpu.SemaphoreType.DMA(()), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + jax.ShapeDtypeStruct( + x.shape, + dtype=x.dtype, + ), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + ), + input_output_aliases={0:2} + )(x) + return send_sem, recv_sem, x, out +``` + +We then have `ppermute_done` take in `x` and do nothing with it. 
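+
+As an aside, the `input_output_aliases` argument used in these calls maps
+*input* indices to *output* indices, telling Pallas (and XLA’s buffer
+assignment) that the two share a buffer. A minimal sketch of the mechanism,
+using a hypothetical `copy_kernel` unrelated to the ppermute example:
+
+```py
+import jax
+import jax.numpy as jnp
+from jax.experimental import pallas as pl
+
+def copy_kernel(x_ref, o_ref):
+  o_ref[...] = x_ref[...]
+
+# Input 0 shares its buffer with output 0.
+aliased_copy = pl.pallas_call(
+    copy_kernel,
+    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
+    input_output_aliases={0: 0},
+)
+y = aliased_copy(jnp.ones((8, 128), jnp.float32))
+```
+
+With that convention in mind, here is the `ppermute_done` that accepts `x`
+purely to extend its lifetime: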
+
+```py
+def ppermute_done_kernel(x_ref, ref, send_sem, recv_sem, _):
+  pltpu.make_async_copy(ref, ref, send_sem).wait()
+  pltpu.make_async_copy(ref, ref, recv_sem).wait()
+
+def ppermute_done(send_sem, recv_sem, x, out) -> Array:
+  out = pl.pallas_call(
+      ppermute_done_kernel,
+      out_shape=(
+          jax.ShapeDtypeStruct(
+              out.shape,
+              dtype=out.dtype,
+          ),
+      ),
+      in_specs=[
+          pl.BlockSpec(memory_space=pltpu.ANY),
+          pl.BlockSpec(memory_space=pltpu.ANY),
+          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+          pl.BlockSpec(memory_space=pltpu.SEMAPHORE),
+      ],
+      out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+      input_output_aliases={1: 0},
+  )(x, out, send_sem, recv_sem)
+  return out
+```
+
+Now when we write
+
+```py
+def f(x):
+  *sems, x, out = ppermute_start(x)
+  z = x + 1
+  y = ppermute_done(*sems, x, out)
+  return y, z
+```
+
+XLA can no longer free `x` because it is an input to `ppermute_done`\! This means that `x`’s lifetime is tied to the `ppermute` and this code is now correct.
+
+### Defensive copies
+
+XLA, in its buffer assignment pass, analyzes which buffers are aliased to each other and inserts copies whenever an operation that aliases one of its inputs is not the final consumer of that input.
+
+#### Background
+
+Here’s a simple example. Let’s say we have an op `add_one_inplace` which takes in an array and adds one, but promises to do it in-place.
+
+The following code would be legal.
+
+```py
+def f():
+  x = jnp.arange(...)
+  y = add_one_inplace(x)
+  return y
+```
+
+However, if `x` had a separate consumer as well, the program may not execute correctly.
+
+```py
+def f():
+  x = jnp.arange(...)
+  y = add_one_inplace(x)
+  return y, x * 2  # another x consumer!
+```
+
+This is because `x * 2` operates on the original `x` but `add_one_inplace` clobbers the value in `x`. `x * 2` needs to read the original values of `x`, not the ones after we’ve incremented it by 1\. XLA notices this and inserts a `copy` op (which is semantically the identity but the input and output buffers will be different).
+
+```py
+def f(x):
+  x2 = copy(x)
+  y = add_one_inplace(x2)
+  return y, x * 2
+```
+
+This pass in XLA ensures correctness in the presence of ops that perform in-place updates by forcing them to effectively be out-of-place with `copy` ops.
+
+#### Copies with downstream ops
+
+Let’s revisit our example where we add 1 while `ppermute`ing.
+
+```py
+def f(x):
+  fut = ppermute_start(x)
+  z = x + 1
+  y = ppermute_done(fut)
+  return y, z
+```
+
+If we unpack the future into its components, we’ll see the aliasing patterns:
+
+```py
+def f(x):
+  *sems, x2, y = ppermute_start(x)
+  z = x + 1
+  y = ppermute_done((*sems, x2, y))
+  return y, z
+```
+
+We know that `x` is left unchanged by `ppermute_start` (that is, `x` is identical to `x2`), but XLA does not. In fact, to XLA this looks just like our `add_one_inplace` example: it conservatively assumes that `ppermute_start` mutated `x` and that `x2` is the new aliased result. Therefore, when we do `z = x + 1`, we run into a consumer of the original buffer. XLA therefore introduces a copy\!
+
+```py
+def f(x):
+  x2 = copy(x)
+  *sems, x2, y = ppermute_start(x2)
+  z = x + 1
+  y = ppermute_done((*sems, x2, y))
+  return y, z
+```
+
+This copy is unnecessary because we know that `x2` is unchanged from `x`. In order to remove it, we’d need some mechanism to inform XLA that we are just forwarding a value. In the absence of such a mechanism, we can rewrite our program a bit to explicitly use `x2` instead of `x`.
+ +```py +def f(x): + *sems, x2, y = ppermute_start(x) + z = x2 + 1 + y = ppermute_done((*sems, x2, y)) + return y, z +``` + +Now, XLA doesn’t see a separate consumer of `x` so no more copy is introduced. However, this comes at a major downside in that it forces us to unpack the future coming from `ppermute_start`. It couples the lifetime problem to the copying problem. + +#### Loop aliasing + +Let’s consider a slightly more advanced example. Let’s implement a function that uses a `while_loop` with `ppermute` to send values around a ring. + +```py +def f(x): + def body(i, x): + fut = ppermute_start(x) + y = ppermute_done(fut) + return y + return fori_loop(0, 8, body, x) +``` + +One implementation detail of `fori_loop` is that the inputs and outputs buffers are automatically aliased to each other. Note that we are setting up some additional aliasing in the `ppermute_start` and `ppermute_done` ops. Let’s run our own “buffer assignment” by coloring each of the values in the program to determine how many unique buffers we need. + +First, we’ll unpack the `fut` tuple that has the aliased `x` and `out` buffers. + +```py +def f(x): + def body(i, x): + *sems, x, y = ppermute_start(x) + y = ppermute_done(*sems, x, y) + return y + return fori_loop(0, 8, body, x) +``` + +Let’s now color each of the values according to the unique buffer they are assigned. We have the input/output aliasing coming from `fori_loop`, the `x` aliasing coming from `ppermute_start` and the `y` aliasing coming from `ppermute_done`. + +```py +def f(x): + def body(i, x): + *sems, x, y = ppermute_start(x) + y = ppermute_done((*sems, x, y)) + return y + return fori_loop(0, 8, body, x) +``` + +If you run the alias analysis, you’ll find that all of the buffers have been colored the same\! Intuitively, this is problematic because if we are doing a loop of `ppermute`s, we can’t write into the same buffer we are sending into. We generally need an extra (i.e. a “double”) buffer to receive, and then usually we will switch the send/recv buffers on the next iteration. What XLA will do in practice is that it will observe the buffer re-use and defensively insert a copy. + +```py +def f(x): + def body(i, x): + x = copy(x) + *sems, x, y = ppermute_start(x) + y = ppermute_done((*sems, x, y)) + return y + return fori_loop(0, 8, body, x) +``` + +This copy means `x` and `y` are no longer aliased to each other and the program will be correct. However, do we need this copy? How do we introduce a double buffer to avoid expensive copies each iteration? The answer is unrolling\! + +We’ll manually unroll our code. + +```py +def f(x): + def body(i, x): + *sems, x, x2 = ppermute_start(x) + x2 = ppermute_done((*sems, x, x2)) + + *sems, x2, y = ppermute_start(x2) + y = ppermute_done((*sems, x2, y)) + return y + return fori_loop(0, 4, body, x) +``` + +Now if we were to run the same alias analysis, we’ll find that the buffers all no longer alias to each other and that we won’t need to insert defensive copies to be correct. + +Therefore, the simple solution to removing these copies is to use `fori_loop` with `unroll >= 2`. + +```py +def f(x): + def body(i, x): + fut = ppermute_start(x) + y = ppermute_done(fut) + return y + return fori_loop(0, 8, body, x, unroll=2) +``` + +That’s sufficient to implement this loop without extra copies\! + +#### Passing futures across loop boundaries + +Let’s now look at an even more advanced example. 
We’ll implement the same program as before but stagger the loop, where we begin the `ppermute` in a prologue before the loop, and wait on the `ppermute` at the beginning of the loop body.
+
+```py
+def f(x):
+  fut = ppermute_start(x)
+  def body(i, fut):
+    x = ppermute_done(fut)
+    fut = ppermute_start(x)
+    return fut
+  fut = fori_loop(0, 7, body, fut)
+  return ppermute_done(fut)
+```
+
+In this example, rather than passing a value `x` from one loop iteration to the next, we are passing a future value.
+
+Let’s unpack the future again to see what’s happening.
+
+```py
+def f(x):
+  fut = ppermute_start(x)
+  def body(i, fut):
+    *sems, x, out = fut
+    x = ppermute_done((*sems, x, out))
+    (*sems, x, out) = ppermute_start(x)
+    return (*sems, x, out)
+  (*sems, x, out) = fori_loop(0, 7, body, fut)
+  return ppermute_done((*sems, x, out))
+```
+
+So we’re explicitly threading the semaphores, the input buffer, and the target output buffer as a loop carry. What happens if we run alias analysis now? Well, we’ll run into the same aliasing issue as in the previous section, where `x` and `out` will be aliased to each other. XLA will introduce a copy.
+
+```py
+def f(x):
+  fut = ppermute_start(x)
+  def body(i, fut):
+    *sems, x, out = fut
+    out = copy(out)
+    x = ppermute_done((*sems, x, out))
+    (*sems, x, out) = ppermute_start(x)
+    return (*sems, x, out)
+  (*sems, x, out) = fori_loop(0, 7, body, fut)
+  return ppermute_done((*sems, x, out))
+```
+
+In this case, we inserted a copy on `out`. However, this is a really bad scenario because `out` is being actively copied into\! Even if we insert a copy on `x` instead, we will also run into issues because then `x`’s lifetime will not extend to the `ppermute_done`. This is very, very bad\! We will not only get copies, but we will also get incorrect results\!
+
+The solution, as we observed before, is to avoid the copies by avoiding aliasing all the buffers via unrolling. So, if we do:
+
+```py
+def f(x):
+  fut = ppermute_start(x)
+  def body(i, fut):
+    x = ppermute_done(fut)
+    fut = ppermute_start(x)
+    return fut
+  fut = fori_loop(0, 7, body, fut, unroll=2)
+  return ppermute_done(fut)
+```
+
+our program should now be correct.
+
+### Putting it all together
+
+So we’ve come up with some rules of thumb:
+
+1. If we have operations dependent on the input value to the `ppermute`, unpack the future to use the aliased value instead of the original value.
+2. Use `unroll >= 2` when doing `ppermute`s in a loop body.
+
+Let’s combine everything into one function that does `ppermute`s in a loop and accumulates the result.
+
+```py
+def f(x):
+  out = jnp.zeros_like(x)
+  fut = (*sems, x, out) = ppermute_start(x)
+  out = out + x
+  def body(i, carry):
+    out, fut = carry
+    x = ppermute_done(fut)
+    fut = (*sems, x, out) = ppermute_start(x)
+    out = out + x
+    return out, fut
+  out, fut = fori_loop(0, 7, body, (out, fut), unroll=2)
+  return out, ppermute_done(fut)
+```
+
+Note that in this example, we don’t need `optimization_barrier`s because the loop boundary acts as a scheduling barrier, splitting up the `start`s and `done`s.
+
+That’s it, we are done\! This will be the official API for doing async ops in Pallas. Thank you everyone\! Mission accomplished\!
+
+***OR IS IT?***
+
+## Revenge of the State
+
+While it seems we have worked around copies and incorrectness issues by using some clever tricks, we are still in an awkward position. This API is powerful, but has many footguns and caveats.
There are likely far many more edge cases we will need to deal with that even require deep knowledge of XLA to predict or understand. Should we release an API like this? Or is there an alternative? + +Well, the answer may have been in front of us this whole time. + +Let’s run through this whole exercise one more time, *except*, let’s write the stateful version. This means each of our custom async ops now operate on `Ref`s instead of values. + +```py +def ppermute_start_stateful(x_ref, y_ref) -> tuple[Semaphore, Semaphore]: + ... + +def ppermute_done_stateful(send_sem, recv_sem, x_ref, y_ref) -> None: + ... +``` + +Let’s assume we can implement these in Pallas and see what our new programs will look like. Let’s start with a basic collective permute: + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + fut = ppermute_start_stateful(x_ref, y_ref) + ppermute_done_stateful(*fut, x_ref, y_ref) + return y_ref[...] +``` + +It’s a little bit more verbose than our original value-based version, but it has a few key differences. The first is that we create an “empty” `Ref` to receive the result of the `ppermute`, unlike the value-based version, which creates a value for us. One neat thing is that the lifetime of `x_ref` is clear here: it lives until `ppermute_done_stateful`. We don’t need to “sneak” the `x` value into the op like we did before. + +Another difference becomes more clear when we try adding an op between the `start/done`. + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + fut = ppermute_start_stateful(x_ref, y_ref) + x_ref[...] += 1 + ppermute_done_stateful(*fut, x_ref, y_ref) + return y_ref[...] +``` + +Before, we ran into scheduling ambiguity, where XLA could re-order the add w.r.t. the `ppermute`. With stateful semantics, we actually add in an ordering constraint\! `x_ref[...] += 1` mutates `x_ref` so it can’t be moved wrt to `ppermute_done_stateful`. JAX can inject these scheduling constraints as part of the lowering to HLO. + +The final key difference is evident when we try our loop examples. + +```py +def f(x): + x_ref = make_ref(x) + y_ref = make_ref(zeros_like(x)) + def body(i, _): + fut = ppermute_start_stateful(x_ref, y_ref) + ppermute_done_stateful(*fut, x_ref, y_ref) + # Now switch to y_ref -> x_ref + fut = ppermute_start_stateful(y_ref, x_ref) + ppermute_done_stateful(*fut, y_ref, x_ref) + fori_loop(0, 8 // 2, body, None) + return x_ref[...] +``` + +Because of the requirement that we have a separate buffer ready to receive the `ppermute`, we were forced to write our code in such a way that unrolls it\! There is no way to write the version in XLA that requires copying because that would involve a `ppermute` that sends from a `Ref` into itself, which doesn’t really make sense. + +To handle this without the manual unrolling, we’d create a scratch buffer with a leading `2` dimension that acts as the send/recv target across iterations, switching each one. This is the same pattern we use internally in Pallas kernels when writing manually overlapped kernels. + +The realization here is that being stateful forces us to deal with a lot of the issues that pop up with value semantics earlier on. We define them away\! + +1. Scheduling \- stateful ops that have `Ref`s as inputs force an ordering of our program. Note that this will schedule operations on the same `Ref` wrt to each other. We might also need an `opt_barrier_stateful` to enforce more ordering constraints. +2. 
Lifetimes \- `Ref` lifetimes can be scoped via `run_state` or could be inputs to stateful ops. +3. Defensive copies \- Using `Ref`s forces us to handle buffer assignment “manually” and the lowering can ensure the aliasing works out to avoid any copies. + +Another important fundamental limitation is that we eventually stage out an HLO program where the live buffers and semaphores are represented as array value types. XLA does not provide guarantees about buffer lifetimes or which memory spaces they live in for these intermediate values. *Therefore, it is possible XLA can copy array values even if they are actively being copied into by Pallas kernels.* This is easy to verify in HLO but it is a sharp edge of using custom calls to represent asynchronous operations in HLO. + +## Conclusion + +We’ve gone over some tricky challenges when it comes to async ops in Pallas and JAX. `Ref`s seem like a promising way of representing these ops that circumvents some of the issues that come up with value semantics. However, a downside is that it puts stateful JAX front and center, which we haven’t done yet outside of Pallas. It’s worth thinking whether we should educate users about stateful ops, or provide a more dangerous API. We also don’t know if everything we want to do is expressible via `Ref`s as well. We should also brainstorm alternatives to state to flesh out the design space. For example, what if XLA offered a first-class futures API that respected lifetimes, and it could automatically do things like double buffer loops with futures in them? That might be a viable alternative but the tradeoff would be giving more control to the compiler vs explicit control from the user. diff --git a/docs/pallas/index.rst b/docs/pallas/index.rst index 403a8ce9c620..5969349c962a 100644 --- a/docs/pallas/index.rst +++ b/docs/pallas/index.rst @@ -3,6 +3,9 @@ Pallas: a JAX kernel language ============================= Pallas is an extension to JAX that enables writing custom kernels for GPU and TPU. +It aims to provide fine-grained control over the generated code, combined with +the high-level ergonomics of JAX tracing and the `jax.numpy` API. + This section contains tutorials, guides and examples for using Pallas. See also the :class:`jax.experimental.pallas` module API documentation. @@ -10,6 +13,10 @@ See also the :class:`jax.experimental.pallas` module API documentation. Pallas is experimental and is changing frequently. See the :ref:`pallas-changelog` for the recent changes. + You can expect to encounter errors and unimplemented cases, e.g., when + lowering of high-level JAX concepts that would require emulation, + or simply because Pallas is still under development. + .. toctree:: :caption: Guides :maxdepth: 2 @@ -26,6 +33,13 @@ See also the :class:`jax.experimental.pallas` module API documentation. tpu/index .. toctree:: + :caption: Design Notes + :maxdepth: 1 + + async_note + +.. 
toctree:: + :caption: Other :maxdepth: 1 CHANGELOG diff --git a/docs/pallas/quickstart.ipynb b/docs/pallas/quickstart.ipynb index 5a8608f494c3..0e759a493a61 100644 --- a/docs/pallas/quickstart.ipynb +++ b/docs/pallas/quickstart.ipynb @@ -282,6 +282,35 @@ "On TPUs, programs are executed in a combination of parallel and sequential\n", "(depending on the architecture) so there are slightly different considerations.\n", "\n", + "To call the above kernel on TPU, run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "796f928c", + "metadata": {}, + "outputs": [], + "source": [ + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "def iota(size: int):\n", + " return pl.pallas_call(iota_kernel,\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),\n", + " out_shape=jax.ShapeDtypeStruct((size,), jnp.int32),\n", + " grid=(size,))()\n", + "iota(8)" + ] + }, + { + "cell_type": "markdown", + "id": "68f97b4e", + "metadata": {}, + "source": [ + "TPUs distinguish between vector and scalar memory spaces and in this case the\n", + "output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is\n", + "a scalar. For more details read {ref}`pallas_tpu_pipelining`.\n", + "\n", "You can read more details at {ref}`pallas_grid`." ] }, diff --git a/docs/pallas/quickstart.md b/docs/pallas/quickstart.md index 36cc14bf5c34..a8b13ea38eaf 100644 --- a/docs/pallas/quickstart.md +++ b/docs/pallas/quickstart.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -188,6 +188,23 @@ operations like matrix multiplications really quickly. On TPUs, programs are executed in a combination of parallel and sequential (depending on the architecture) so there are slightly different considerations. +To call the above kernel on TPU, run: + +```{code-cell} ipython3 +from jax.experimental.pallas import tpu as pltpu + +def iota(size: int): + return pl.pallas_call(iota_kernel, + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + out_shape=jax.ShapeDtypeStruct((size,), jnp.int32), + grid=(size,))() +iota(8) +``` + +TPUs distinguish between vector and scalar memory spaces and in this case the +output must be placed in scalar memory (`TPUMemorySpace.SMEM`) since `i` is +a scalar. For more details read {ref}`pallas_tpu_pipelining`. + You can read more details at {ref}`pallas_grid`. +++ diff --git a/docs/pallas/tpu/details.rst b/docs/pallas/tpu/details.rst index ae9505c4eb8b..4a2d4daa637f 100644 --- a/docs/pallas/tpu/details.rst +++ b/docs/pallas/tpu/details.rst @@ -21,7 +21,7 @@ software emulation, and can slow down the computation. If you see unexpected outputs, please compare them against a kernel run with ``interpret=True`` passed in to ``pallas_call``. If the results diverge, - please file a `bug report `_. + please file a `bug report `_. What is a TPU? -------------- @@ -148,10 +148,8 @@ grid axes over cores. This is an opt-in procedure. To allow that, .. 
pallas_call( ..., - compiler_params=dict( - mosaic=dict( - dimension_semantics=["parallel", "parallel", "arbitrary"] - ) + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=["parallel", "parallel", "arbitrary"] ), ) diff --git a/docs/pallas/tpu/distributed.ipynb b/docs/pallas/tpu/distributed.ipynb new file mode 100644 index 000000000000..8552e10d8552 --- /dev/null +++ b/docs/pallas/tpu/distributed.ipynb @@ -0,0 +1,1743 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "zSNjLhGQJMgq" + }, + "source": [ + "# Distributed Computing in Pallas for TPUs\n", + "\n", + "In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`.\n", + "\n", + "Some recommended readings beforehand:\n", + " - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html)\n", + " - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "executionInfo": { + "elapsed": 1978, + "status": "ok", + "timestamp": 1722904801801, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "PyAGnWc9yI8T", + "outputId": "1d8229bd-cab5-495f-93e9-fff2e41db480" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running with 4 TPU v5 lite devices.\n" + ] + } + ], + "source": [ + "import jax\n", + "from jax import lax\n", + "from jax import numpy as jnp\n", + "from jax.experimental import mesh_utils\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental import shard_map\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "P = jax.sharding.PartitionSpec\n", + "\n", + "num_devices = jax.local_device_count()\n", + "assert num_devices > 1, \"Please run this notebook with more than one device.\"\n", + "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n", + "print(f\"Running with {num_devices} {jax.devices()[0].device_kind} devices.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DySMGNByclMi" + }, + "source": [ + "## TPU Topologies\n", + "\n", + "TPUs are typically deployed in pods of multiple devices connected via a high-bandwidth interchip interconnect (ICI) for communication within the pod that is much faster than a typical network connection. For example, the specifications sheet for a [TPU v5p](https://cloud.google.com/tpu/docs/v5p) states an ICI bandwidth of 4.8Tb/s per chip (for reference, TPU v5p also has 21Tb/s of *local* HBM bandwidth). The ICI allows us to implement fast and performant distributed kernels that require high-bandwidth communication within a pod, and use the datacenter network for parallelization over less bandwidth-intensive operations, such as data-parallelism over a batch dimension.\n", + "\n", + "TPUs pods are typically arranged in an ND torus topology. 
The following graphic gives several examples of configurations of different sizes.\n", + "\n", + "![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png)\n", + "\n", + "Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussion about device toplogies — a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod.\n", + "\n", + "![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1Oc_WD1hChfN" + }, + "source": [ + "## Remote Direct Memory Access (RDMA) Model\n", + "\n", + "TPUs communicate via a push-only model known as a remote direct memory access (RDMA). A TPU is allowed to issue copy instruction to push from a local buffer to any buffer on another device within the same pod that executes asynchronously from the main program thread. However, a TPU can only read data that is stored locally. This is in contrast to more traditional multi-core programming where it is possible to both read from and write to values to a shared memory.\n", + "\n", + "### Async Remote Copy Operation\n", + "The `pltpu.make_async_remote_copy` function is used to create a remote DMA descriptor object which parameterizes both a \"send\" operation and a \"receive\" operation. Here's its signature:\n", + "\n", + "```python\n", + " def make_async_remote_copy(\n", + " src_ref: Ref,\n", + " dst_ref: Ref,\n", + " send_sem: Ref[SemaphoreType],\n", + " recv_sem: Ref[SemaphoreType],\n", + " device_id: int | tuple[int, ...],\n", + " device_id_type: DeviceIdType\n", + " ) -> AsyncCopyDescriptor:\n", + "```\n", + "\n", + "- `src_ref` is the local `Ref` (in any memory space) containing the data you wish to send to `dst_ref` on another device.\n", + "- `dst_ref` is the remote `Ref` (in any memory space) at which data will be copied to on the target device.\n", + "- `send_sem` is a DMA semaphore used to block until all data has been sent from `src_ref`.\n", + "- `recv_sem` is a DMA semaphore used to block until the expected number of bytes have been received at `dst_ref`. The sender of the DMA will write to the receiver's `recv_sem`.\n", + "- `device_id` is the device ID of the target device to send to.\n", + "- `device_id_type` specifies the format of `device_id`, which can either be in LOGICAL format (integer device ID), or in MESH format (an ND-tuple index into the logical device mesh). The default mode is MESH.\n", + "\n", + "`make_async_remote_copy` returns a descriptor object on which you use the `.start()` method to initiate the DMA, and the `.wait_send()` to block on `send_sem` and `.wait_recv()` to block on `recv_sem` (or `.wait()` to block on both). If a device is only expected to send data, it is sufficient to only call `.start()` and `.wait_send()`, and likewise if a device is only receiving it is sufficient to only call `.wait_recv()`. If using a SPMD pattern where all devices execute the DMA, each device will generally call both `.start()` and `.wait()`.\n", + "```python\n", + "dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id)\n", + "dma_descriptor.start() # Initiate the DMA (non-blocking).\n", + "# ... 
do other work\n", + "dma_descriptor.wait_send() # Block until all data has been sent.\n", + "dma_descriptor.wait_recv() # Block until all data has been received.\n", + "```\n", + "\n", + "As an example, let's visualize a DMA where we consider 4 devices (indexed 0, 1, 2, 3). We consider a scheme where device 0 copies to device 1, and device 2 & 3 copy to each other. In practice, we can create such an asymmetric communication pattern by using `@pl.when` to branch on the device ID.\n", + "\n", + "(1) Each device creates the DMA descriptor. Devices 0, 2, and 3 call `.start()` to initiate the DMA from `src_ref`. Device 1 is skips the `.start()` and does nothing, e.g. by using `pl.when`.\n", + "\n", + "![rdma_start](../../_static/pallas/distributed/rdma_start.svg)\n", + "\n", + "(2) As `.start()` is non-blocking, each device is free to do other computation while the DMA is in flight. Devices 0, 2, and 3 call `.wait_send()` to wait on `send_sem` which blocks until all data has been sent.\n", + "\n", + "![rdma_send](../../_static/pallas/distributed/rdma_send.svg)\n", + "\n", + "(3) Finally, devices 1, 2, and 3 will call `.wait_recv()` to wait on `recv_sem` until all data has arrived at `dst_ref`.\n", + "\n", + "![rdma_recv](../../_static/pallas/distributed/rdma_recv.svg)\n", + "\n", + "The above communication pattern can be written as follows:\n", + "```python\n", + "def example_kernel(input_ref, output_ref, send_sem, recv_sem):\n", + " device_id = lax.axis_index('x')\n", + " copy_0_to_1 = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=1,\n", + " )\n", + " copy_2_to_3 = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=3,\n", + " )\n", + " copy_3_to_2 = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=2,\n", + " )\n", + " @pl.when(device_id == 0)\n", + " def _():\n", + " copy_0_to_1.start()\n", + " copy_0_to_1.wait_send()\n", + " @pl.when(device_id == 1)\n", + " def _():\n", + " copy_0_to_1.wait_recv()\n", + " @pl.when(device_id == 2)\n", + " def _():\n", + " copy_2_to_3.start()\n", + " copy_2_to_3.wait_send()\n", + " copy_3_to_2.wait_recv()\n", + " @pl.when(device_id == 3)\n", + " def _():\n", + " copy_3_to_2.start()\n", + " copy_3_to_2.wait_send()\n", + " copy_2_to_3.wait_recv()\n", + "```\n", + "\n", + "### DMA Semaphores\n", + "\n", + "`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `tpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`.\n", + "\n", + "Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the sempahore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `dst_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. 
the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).\n", + "\n", + "### Routing\n", + "\n", + "A sender is allowed to send data to any receiver within the same pod, even if they do not share a direct connection (the exception to this rule is for TPU v5e, where devices can only route to a power of 2 offset from themselves). TPUs have an internal routing mechanism which can pass data along to the next device on the path to the destination. However, communicating in this way is not recommended as you have no control over network contention as a kernel writer. The examples we will cover in this tutorial minimize inefficient communication by only transferring data to neighboring devices.\n", + "\n", + "### Failure modes\n", + "\n", + "If using remote DMAs incorrectly, you may encounter several failure modes which can be difficult to debug. The general symptoms of buggy DMA usage are crashes, hanging, or silent data corruption:\n", + "- If semaphores exit the program with an invalid non-zero value, Pallas will crash and exit the program.\n", + "- If semaphores are waited on but an insufficient number of bytes are received (i.e. there is no sender, or if the sent data is less than the size of `dst_ref` on the receiving device), the program may hang indefinitely waiting for bytes that are never sent. In this case the program would need to be restarted.\n", + "- If encountering a race condition, there could be silent data corruption if two simultaneous writes or a simultaneous read and write occur.\n", + "\n", + "Some common causes of the above include:\n", + "- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.\n", + "- If a device is sent a more bytes than it expected to receive, it may also crash due to non-zero semaphore states. If sent less, it may hang indefinitely.\n", + "- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.\n", + "- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or crashing due to non-zero semaphore states." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vpGSN1Sui0Bu" + }, + "source": [ + "### Example: Right Permute (`lax.ppermute`)\n", + "\n", + "Let's dive into a very basic example. We will implement a kernel that performs a right permutation, where each device sends its slice of the data to its right neighbor.\n", + "\n", + "Suppose we had an array with 512 elements, which we shard into slices of size 128 across 4 devices. Each device will pass its slice to the next device, and the output will consist of the same data, but with the slices rotated by 1. This is identical to the `lax.ppermute` operation where the permutation is set to `(n, (n+1) % 4)`.\n", + "\n", + "In order to call the kernel in distributed mode, we wrap the `pallas_call` in a `shard_map` transformation. From there, we can write the kernel the same way as you would write a normal single-device Pallas kernel, except we now have access to remote DMA instructions. JAX collective primitives such as `lax.axis_index` can be used to obtain a `device_id` that can be used to compute which target devices to copy to, by referencing the same named axes names passed into `shard_map`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "executionInfo": { + "elapsed": 1606, + "status": "ok", + "timestamp": 1722904803566, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "YkyIKN2thZ-V", + "outputId": "9b7ed142-d161-4237-fed8-cbce41adc5f0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", + "Pallas Result = [0.775211 0.9858954 0.11763906 0.9955574 ]\n", + "lax.ppermute Result = [0.775211 0.9858954 0.11763906 0.9955574 ]\n", + "Difference |Pallas - lax.ppermute| = 0.0\n" + ] + } + ], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# Create an input array that shards the last dimension across\n", + "# all devices.\n", + "input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = lax.rem(my_id + 1, num_devices)\n", + " remote_copy_op = pltpu.make_async_remote_copy(\n", + " src_ref=input_ref,\n", + " dst_ref=output_ref,\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy_op.start()\n", + " remote_copy_op.wait()\n", + "\n", + "\n", + "out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " in_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " scratch_shapes=(\n", + " # We allocate DMA semaphores in scratch memory.\n", + " [pltpu.SemaphoreType.DMA] * 2\n", + " ),\n", + ")\n", + "right_permute = pl.pallas_call(\n", + " right_permute_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + ")\n", + "# Wrap the kernel within a shard_map to call.\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " right_permute,\n", + " mesh=mesh,\n", + " in_specs=partition,\n", + " out_specs=partition,\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "\n", + "# Compare Pallas result to XLA shard_map result.\n", + "perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices))\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lambda x: lax.ppermute(x, 'x', perm),\n", + " mesh=mesh, in_specs=partition, out_specs=partition)\n", + ")(input_arr)\n", + "\n", + "print('Input = ', input_arr[0, ::128])\n", + "print('Pallas Result = ', pallas_result[0, ::128])\n", + "print('lax.ppermute Result = ', xla_result[0, ::128])\n", + "print(\n", + " 'Difference |Pallas - lax.ppermute| = ',\n", + " jnp.mean(jnp.abs(pallas_result - xla_result)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iyfhdGXuUnq2" + }, + "source": [ + "### Example: All-gather (`lax.all_gather`)\n", + "\n", + "In this next example we will implement the all-gather collective operation, which has a JAX equivalent in `lax.all_gather`. 
In contrast with the right-permute example from above which only involves a pair of source and destination neighbors, an all-gather operation requires communication between all devices and therefore we must think about how data is routed between them. The specifics of how we implement this are dictated by the device topology, for which we assume is a ring.\n", + "\n", + "#### Ring Communication Pattern\n", + "\n", + "We will write our kernel assuming a ring topology. Rings are a natural fit for TPUs as slicing along any dimension of a torus produces a ring. When writing collectives, we often only need to think about 1D slices of our torus at a time because the different dimensions of the torus are reserved for different types of parallelism (data vs. model, for example).\n", + "\n", + "The strategy we will use is to write a looped kernel, where on each iteration a device receives one slice of the sharded array from its left neighbor, and copies the previously received slice to its right neighbor. After `num_devices` iterations, each device will have a copy of the entire array in its local HBM.\n", + "\n", + "![all_gather](../../_static/pallas/distributed/all_gather.svg)\n", + "\n", + "We can re-purpose Pallas's `grid` argument to implement the loop. Rather than iterating over tiles of an array as we have done in previous tutorials, we instead set the grid to `(num_devices,)` to indicate that we want to loop over the number of devices and use `pl.program_id` to obtain the loop iteration inside of the Pallas kernel. The following code snippet demonstrates how to implement this:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "executionInfo": { + "elapsed": 812, + "status": "ok", + "timestamp": 1722904804531, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ojQEZB5mBRqM", + "outputId": "e1648f54-737c-4921-ca3b-b4c639a38d2b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: (32, 128) [0.9858954 0.54248166 0.9547038 0.954962 ]\n", + "Pallas Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166\n", + " 0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962\n", + " 0.9858954 0.54248166 0.9547038 0.954962 ]\n", + "lax.all_gather Result: (16, 8, 128) [0.9858954 0.54248166 0.9547038 0.954962 0.9858954 0.54248166\n", + " 0.9547038 0.954962 0.9858954 0.54248166 0.9547038 0.954962\n", + " 0.9858954 0.54248166 0.9547038 0.954962 ]\n", + "Difference |Pallas - lax.all_gather| = 0.0\n" + ] + } + ], + "source": [ + "partition = P('x', None)\n", + "devices = mesh_utils.create_device_mesh((num_devices, 1))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# Create an input array that shards the first dimension across\n", + "# all devices.\n", + "input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128))\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "def all_gather_kernel(input_ref,\n", + " output_ref,\n", + " local_copy_sem,\n", + " send_sem,\n", + " recv_sems):\n", + " outer_step = pl.program_id(0)\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = lax.rem(my_id + 1, num_devices)\n", + " copy_slot = my_id - outer_step\n", + " copy_slot = lax.rem(copy_slot + num_devices, num_devices)\n", + "\n", + " @pl.when(outer_step == 0)\n", + " def _():\n", + " local_copy_op = pltpu.make_async_copy(\n", + " 
src_ref=input_ref,\n", + " dst_ref=output_ref.at[my_id],\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy_op.start()\n", + " local_copy_op.wait()\n", + "\n", + " # Copy to our right neighbor.\n", + " # Note that we will also be receiving data from our left neighbor,\n", + " # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact\n", + " # that the indices do not need to be symmetric between remote DMAs.\n", + " remote_copy_op = pltpu.make_async_remote_copy(\n", + " src_ref=output_ref.at[copy_slot],\n", + " dst_ref=output_ref.at[copy_slot],\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sems.at[outer_step],\n", + " device_id=(right_neighbor, 0),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy_op.start()\n", + " remote_copy_op.wait()\n", + "\n", + "out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32)\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " # TPUMemorySpace.ANY will (usually) place the tensor in HBM.\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " scratch_shapes=(\n", + " # DMA semaphores are allocated in scratch memory.\n", + " # We allocated one semaphore for a local HBM-VMEM copy,\n", + " # and one for the remote send semaphore.\n", + " [pltpu.SemaphoreType.DMA] * 2\n", + " # We additionally allocate one receive semaphore per device.\n", + " # This is to avoid situations where we have multiple\n", + " # DMAs in flight, as we do not want to share a receive\n", + " # semaphore between the DMAs.\n", + " + [pltpu.SemaphoreType.DMA((num_devices-1,))]\n", + "\n", + " ),\n", + " grid=(num_devices-1,)\n", + " )\n", + "\n", + "all_gather = pl.pallas_call(\n", + " all_gather_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " )\n", + "\n", + "# Wrap the kernel within a shard_map to call.\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " all_gather,\n", + " mesh=mesh,\n", + " in_specs=partition,\n", + " out_specs=partition,\n", + " check_rep=False\n", + " )\n", + ")(input_arr)\n", + "\n", + "# Compare Pallas result to XLA shard_map result.\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lambda x: lax.all_gather(x, 'x'),\n", + " mesh=mesh, in_specs=partition, out_specs=partition\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input: ', input_arr.shape, input_arr[::8, 0])\n", + "print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0])\n", + "print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0])\n", + "print('Difference |Pallas - lax.all_gather| = ',\n", + " jnp.mean(jnp.abs(pallas_result - xla_result)))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KgU7HI2pS4om" + }, + "source": [ + "A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. 
While this race condition is unlikely to happen on such a small kernel, on larger kernels there is more chance for devices to fall out of sync and potentially cause a silent failure." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KgU7HI2pS4om" + }, + "source": [ + "## Advanced Techniques\n", + "\n", + "Now that we have seen how to write several basic kernels using remote DMA operations, we will go over more advanced techniques for synchronization and writing efficient kernels." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8M_kdl0FCtrL" + }, + "source": [ + "### Synchronization: Regular and Barrier Semaphores\n", + "\n", + "The examples we implemented in the basic tutorial do not require special handling of synchronization as all necessary communication writes to disjoint buffers. However, other operations may require more complex communication patterns that need additional synchronization primitives to avoid race conditions. Pallas provides two additional primitives to help with this: regular and barrier semaphores.\n", + "\n", + "#### Regular Semaphores\n", + "\n", + "Regular semaphores are the standard tool used to synchronize across multiple devices. Semaphores are fundamentally counters - they can be incremented by any device after which a device can block until the value of the semaphore reaches a specific value (and then decrement the value).\n", + "\n", + "The three main operations that can be used on regular semaphores are signal, wait, and read:\n", + "```python\n", + "def semaphore_signal(\n", + " sem: Ref[SemaphoreType],\n", + " inc: int,\n", + " device_id: int | tuple[int, ...],\n", + " device_id_type: DeviceIdType\n", + ") -> None:\n", + " ... # Increments the semaphore `sem` on the target device `device_id` by `inc`.\n", + " \n", + "def semaphore_wait(\n", + " semaphore: Ref[SemaphoreType],\n", + " value: int,\n", + ") -> None:\n", + " ... # Blocks until the locally allocated copy of `sem` reaches `value`, then decrement by `value` and proceed.\n", + " \n", + "def semaphore_read(\n", + " sem: Ref[SemaphoreType],\n", + ") -> jax.Array:\n", + " ... # Returns the current value of `sem` as an `int32[]`.\n", + "```\n", + "\n", + "In order to use regular semaphores, they can be allocated in the same way as a DMA semaphore, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.\n", + "\n", + "Semaphores must be zero at the end of a Pallas program to complete succesfully. There are two error cases where this may happen:\n", + " - If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging as non-zero semaphores typically means there is a bug somewhere inside of the program.\n", + " - If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the sempahore to be incremented. In this case the device or program will need to be restarted.\n", + "\n", + "#### Barrier Semaphores\n", + "\n", + "Barrier semaphores are globally-allocated semaphores used to synchronize devices across an entire program and ensure that all devices have entered the Pallas kernel.\n", + "\n", + "If a Pallas kernel is executed within the context of a larger XLA program, we need to ensure that all devices that communicate have entered the kernel. However, DMA and regular semaphores are both locally scoped - they are only understood by other devices that have entered the kernel. 
Barrier semaphores serve as a globally understood semaphore that can be used for synchronization no matter where in the XLA program the device is currently executing.\n", + "\n", + "By default, if you do not specify a barrier semaphore, Pallas will automatically insert a barrier semaphore at the beginning of your program. However, it can be more efficient to write your own. Barrier semaphores are similar to regular semaphores in that they are counters that can be incremented via `semaphore_signal` and can be decremented via `semaphore_wait`. They are created by calling `get_barrier_semaphore()` within a kernel. Typically, we use barriers once at the beginning of a kernel to synchronize with all devices we are communicating with.\n", + "\n", + "```python\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "def example_kernel(...):\n", + " # Use barrier semaphores at the beginning of a kernel.\n", + " # is_start_of_kernel = ...\n", + " # right_neighbor = ...\n", + " # ...\n", + " @pl.when(is_start_of_kernel)\n", + " def _():\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " # Increment the semaphore of your right neighbor.\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " device_id=right_neighbor,\n", + " device_id_type=pltpu.DeviceIdType.LOGICAL,\n", + " )\n", + " # Wait until your left neighbor has incremented your semaphore\n", + " pltpu.semaphore_wait(barrier_sem, 1)\n", + " # ...\n", + "```\n", + "\n", + "When using barrier semaphores, the `collective_id` compiler parameter must be passed to `pallas_call` to specify which barrier semaphore is being used. A TPU has a small, fixed number of barrier semaphores available (typically on the order of 20-30) and therefore they should be used sparingly. In order to ensure correctness, only kernels that share the same communication pattern should use the same `collective_id`. For example, if two kernels synchronize only with neighbors on the same mesh axis, they are allowed to share the same `collective_id`. However, if two kernels synchronize along different axes, they must have different `collective_id`s. Failure to do so may result in race conditions that are difficult to debug.\n", + "\n", + "```python\n", + "kernel = pl.pallas_call(\n", + " example_kernel,\n", + " ...,\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zy20AxN5TSLA" + }, + "source": [ + "### Double-buffering\n", + "\n", + "In order to avoid reading from a local `Ref` that is also being written into by another device and creating a race condition, a useful technique is the \"double-buffered\" strategy where we allocate a two `Ref`s for each destination value. On each iteration, one `Ref` will be designated as a \"working\" slot, and the other will be designated as a \"receiving\" slot. The device is free to use the working slot for computation, but will only copy data into its neighbor's receiving slot. The working and receiving slots alternate every iteration, so that once a copy is finished, the old receiving slot becomes the new working slot, and vice versa. Using this scheme properly, data is never read from and written to the same buffer.\n", + "\n", + "The following code skeleton demonstrates how double-buffering can be used. We keep a running iteration counter in the variable `iteration`, and the `working_slot` and `receiving_slot` alternate between 0 and 1 every iteration. 
`dst_ref` is allocated as a double-buffer and has the size `[2, ...]`. On each iteration, we read from the working slot using `dst_ref.at[working_slot, ...]` and use the value to perform computation. Simultaneously, we copy to our neighbor's `dst_ref.at[receiving_slot]` to avoid overwriting their `working_slot` value. By structuring our communication in this fashion, it is possible to overlap the communication latency of the remote DMA with local computation while minimizing the risk of race conditions.\n", + "```python\n", + "def kernel(...):\n", + " # ...\n", + " iteration = pl.program_id(0)\n", + " working_slot = lax.rem(iteration, 2)\n", + " receiving_slot = 1 - working_slot\n", + " # ...\n", + "\n", + " local_copy_op = pltpu.make_async_copy(\n", + " src_ref=dst_ref.at[working_slot, ...],\n", + " dst_ref=local_scratch_ref,\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy_op.start()\n", + " remote_copy_op = pltpu.make_async_remote_copy(\n", + " src_ref=src_ref,\n", + " dst_ref=dst_ref.at[receiving_slot, ...],\n", + " send_sem=send_sem,\n", + " recv_sem=recv_sem,\n", + " device_id=target_device,\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy_op.start()\n", + " \n", + " local_copy_op.wait()\n", + " # ... do work on local_scratch while waiting for remote_copy_op to finish.\n", + " remote_copy_op.wait()\n", + "\n", + "```\n", + "\n", + "In terms of synchronization, the double-buffered construction works if all devices are executing on the same iteration. If a sender manages to get one iteration ahead of its receiver, its `working_slot` and `receiving_slot` indices will be flipped compared to the receiver, meaning that it could be writing into the `working_slot` at the same time the receiver is reading from it. In order to avoid this, it may be necessary to use a semaphore to synchronize the sender with the receiver, or add additional buffering slots (\"triple\", \"quadruple\", or N-buffered) to allow additional run-ahead at the cost of more memory. In our previous `all_gather` example, note that the kernel contained a receiving buffer with N slots, which avoids race conditions altogether. In our next kernel, we will instead go through an example which uses a double-buffer with explicit synchronization." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Or0Itv72No5d" + }, + "source": [ + "### Example: All-Reduce Sum (`lax.psum`)\n", + "\n", + "We will now implement an all-reduce sum kernel using double-buffering and semaphores for synchronization. For those familiar with collective operations in JAX, the equivalent operation is `lax.psum`. All-reduce is a standard collective operation where the objective is to reduce along an axis of an array, but the array is sharded across multiple devices.\n", + "\n", + "![reduce_sum_1](../../_static/pallas/distributed/reduce_sum_1.svg)\n", + "\n", + "In the above example, we have the array [5, 2, 1, 3] sharded across 4 devices. An all-reduce sum operation would sum all values and replicate the result on each device, leading to the result [11, 11, 11, 11] sharded across all 4 devices.\n", + "\n", + "The naive implementation of all-reduce would be to gather all required values onto each device, and then reduce. However, we can improve the performance of this implementation by interleaving communication with computation. An interleaved, single-direction all-reduce can be visualized as follows.
On each iteration, we receive an input value from our left neighbor, and concurrently pass the input along to our next neighbor while incrementing our local accumulator with it. After N-1 iterations, each device will have a copy of the full sum in its memory.\n", + "\n", + "![reduce_sum_2](../../_static/pallas/distributed/reduce_sum_2.svg)\n", + "\n", + "#### Putting it all together\n", + "\n", + "The following kernel demonstrates how to combine these principles into a functional kernel.\n", + "\n", + "The prologue (executed when `outer_step==0`) first initiates a barrier with both neighbors to ensure that they have also entered the kernel. It also initializes all `Ref`s and performs the first remote copy to the right neighbor's \"working\" slot.\n", + "\n", + "The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic. Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).\n", + "\n", + "A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. In order to avoid this, each device blocks on a `REGULAR` semaphore before copying into its right neighbor's `dst_ref`, until the neighbor has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a kernel as small as this example, but it can be triggered explicitly, for example by using a `pltpu.delay` instruction to artificially hang a device.\n", + "\n", + "Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections."
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "executionInfo": { + "elapsed": 254, + "status": "ok", + "timestamp": 1722904804952, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "XrY5bMlvBroQ", + "outputId": "77497000-4496-462e-cc3c-73fb640cc14c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input = [0.9858954 0.11763906 0.9955574 0.775211 ]\n", + "Pallas result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", + "lax.psum result = [2.8743029 2.8743029 2.8743029 2.8743029]\n", + "Difference |Pallas - lax.psum| = 1.4959369e-08\n" + ] + } + ], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "def all_reduce_kernel(\n", + " x_ref,\n", + " o_ref,\n", + " hbm_scratch,\n", + " copy_sem,\n", + " remote_recv_sem,\n", + " remote_send_sem,\n", + " capacity_sem,\n", + " receive_scratch,\n", + "):\n", + " outer_step = pl.program_id(0)\n", + " working_slot = lax.rem(outer_step, 2)\n", + " receiving_slot = 1 - working_slot\n", + "\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = lax.rem(my_id + 1, num_devices)\n", + " left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)\n", + "\n", + " @pl.when(outer_step == 0)\n", + " def _():\n", + " # Barrier with both neighbors at the start, since we will be\n", + " # communicating with both.\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + "\n", + " # Initialize o_ref, acc_scratch, and hbm_scratch.\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " receive_scratch[...] 
= jnp.zeros_like(receive_scratch)\n", + " initial_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref,\n", + " dst_ref=hbm_scratch.at[working_slot],\n", + " send_sem=remote_send_sem,\n", + " recv_sem=remote_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " initial_copy.start()\n", + " initial_copy.wait()\n", + "\n", + " # Signal to our left neighbor that we are ready to receive.\n", + " # Without this signal, our left neighbor can be >=1 iteration ahead,\n", + " # meaning it could write into our working slot.\n", + " pltpu.semaphore_signal(\n", + " capacity_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " # Copy the partial result our left neighbor sent to us into VMEM for\n", + " # computation.\n", + " local_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot],\n", + " dst_ref=receive_scratch,\n", + " sem=copy_sem,\n", + " )\n", + " local_copy.start()\n", + "\n", + " # Block until our right neighbor is ready to receive.\n", + " pltpu.semaphore_wait(capacity_sem, 1)\n", + " # Pass the value to our right neighbor.\n", + " remote_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[working_slot],\n", + " dst_ref=hbm_scratch.at[receiving_slot],\n", + " send_sem=remote_send_sem,\n", + " recv_sem=remote_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " remote_copy.start()\n", + " # Finish local copy and accumulate while remote_copy is happening.\n", + " local_copy.wait()\n", + " o_ref[...] += receive_scratch[...]\n", + " # Block until remote copy finishes.\n", + " remote_copy.wait()\n", + "\n", + "\n", + "out_shape = (\n", + " jax.ShapeDtypeStruct((8, 128), jnp.float32),\n", + " # We allocate the double-buffer as a Pallas output so that it is\n", + " # resident in HBM.\n", + " jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch\n", + ")\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " # Our input lives in VMEM\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " ],\n", + " out_specs=[\n", + " # Our output lives in VMEM\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " # Our double-buffer lives in HBM\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " grid=(num_devices,),\n", + " scratch_shapes=(\n", + " [pltpu.SemaphoreType.DMA] * 3\n", + " + [pltpu.SemaphoreType.REGULAR] # capacity_sem\n", + " + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch\n", + " ),\n", + ")\n", + "\n", + "kernel = pl.pallas_call(\n", + " all_reduce_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + ")\n", + "\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " kernel,\n", + " mesh=mesh,\n", + " in_specs=partition,\n", + " out_specs=partition,\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "pallas_result = jax.block_until_ready(pallas_result)[0]\n", + "\n", + "\n", + "def lax_sum(x):\n", + " return lax.psum(x, 'x')\n", + "\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x')\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input = ', input_arr[0, ::128])\n", + "print('Pallas result = ', pallas_result[0, ::128])\n", + 
"print('lax.psum result = ', xla_result[0, ::128])\n", + "difference = jnp.mean(jnp.abs(pallas_result - xla_result))\n", + "print('Difference |Pallas - lax.psum| = ', difference)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "d8bsZAzQreC_" + }, + "source": [ + "### Run-ahead and Race Conditions\n", + "\n", + "As a general rule of thumb, to maximize performance we want to allow a device to run-ahead of other devices without synchronization as much as possible without sacrificing correctness of the program. While we could enforce a barrier across all devices at the beginning of each iteration, this bottlenecks the performance of the program to the slowest device on each loop. By relaxing synchronization and allowing a moderate amount of run-ahead, we can better accommodate variance in latency between iterations and devices because a device that is slow on one iteration could catch up on the next iteration.\n", + "\n", + "In the all-reduce kernel we wrote previously, we allow devices to run ahead but by less than one iteration compared to its neighbors (however, non-neighboring devices could be more than 1 iteration apart). To see why the semaphore synchronization is necessary, consider the case when one device (say device 2) hangs and falls behind the other devices. An RDMA has no \"handshake\" — only the receiver is blocked while waiting for the data to arrive. Therefore, each device can run up to one iteration ahead before it becomes blocked waiting for the next RDMA to arrive. If we have N devices, this means that the final device can be up to N iterations ahead of the first device.\n", + "\n", + "![race_condition](../../_static/pallas/distributed/race_condition.svg)\n", + "\n", + "Without adding synchronization in the other direction (forcing senders to block), device 1 could potentially run up to `N` iterations (`N = num_devices`) ahead of device 2, sending multiple writes and overwriting values in the process. To solve this in the `all_reduce` kernel we wrote previously we implemented a \"handshake\" protocol where the receiver signals back to the sender that it is ready to receive, and only then does the sender begin issuing the next RDMA." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UD8lNrqsUeXy" + }, + "source": [ + "### Bi-directional Communication\n", + "\n", + "In our previous kernels, we communicated in a single direction around a ring from left-to-right. However, as ICI connections are bi-directional, we are effectively wasting half of the total bandwidth by not sending values in the opposite direction from right-to-left. In this next kernel we will demonstrate an example which communicates in both directions to maximize ICI bandwidth." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4KjakLhbBk73" + }, + "source": [ + "### Example: Bi-directional Reduce-Scatter (`lax.psum_scatter`)\n", + "\n", + "A reduce-scatter operation is the combination of an all-reduce followed by a scatter. Or alternatively, an all-reduce is the combination of a reduce-scatter followed by all-gather.\n", + "\n", + "The following graphic depicts the semantics of this operation. We assume that each device starts with a collection of partial sums (denoted by a letter + number, such as `A0`). 
The goal is to reduce along one axis (numbers), while sharding along the other axis (letters).\n", + "\n", + "![reduce_scatter_1](../../_static/pallas/distributed/reduce_scatter_1.svg)\n", + "\n", + "In order to implement a bi-directional communication strategy, we slice each input block in half, and designate a direction for each half. The top half of each block will be passed from right-to-left, and the bottom half will be passed from left-to-right. A second deviation from the communication patterns of our previous all-reduce and all-gather kernels is that we will also pass around accumulators or partial sums and keep the inputs local to each device. This is in contrast to the previous examples where we passed around inputs but kept the accumulator local to the device. Passing around the accumulator is a more natural fit for this problem as in contrast to all-reduce, most of the data in the inputs are not part of the output that will be stored locally on the device. (e.g. `B0`, `C0`, and `D0` in the above graphic will not be stored on the device holding `A` at the end).\n", + "\n", + "The following diagram illustrates this communication pattern, where the colored boxes represent accumulators (not inputs!). Initially, the accumulator is simply the value that was contained in the input. At each iteration of the algorithm, we will receive a partial sum from our neighbors in each direction. We then compute the correct slice of our input to accumulate into the partial buffer, then pass the new partial sum along to our next neighbor. After N iterations, the accumulator will have passed through each device, meaning that it will hold the full sum in the end.\n", + "\n", + "![reduce_scatter_2](../../_static/pallas/distributed/reduce_scatter_2.svg)\n", + "\n", + "In terms of construction of the kernel, we introduce an additional `phase` dimension to the Pallas grid, which denotes which accumulator (left or right) we are currently computing on. We let `phase=0` denote the accumulator moving to the left, and `phase=1` denote the accumulator moving to the right. We then pipeline the two phases, such that while computing the result for one phase we are transferring our previously computed values in the opposite direction in preparation for the next phase. For example, when we are on `phase=0` (left), we first begin a DMA to transfer results we computed in the previous iteration to our right neighbor (right-DMA). Then, we accumulate into the left-buffer and save the result to HBM. We then wait for the right-DMA to complete so that it is ready for `phase=1` (right)." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "executionInfo": { + "elapsed": 544, + "status": "ok", + "timestamp": 1722904805699, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "nRauUAxNHg28" + }, + "outputs": [], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# We need a block size of (16, 128) to ensure that a half-slice is at least\n", + "# of size (8, 128), which is the size of a VREG. 
This makes tiling easier\n", + "# for the compiler.\n", + "block_size = (16, 128)\n", + "input_arr = jax.random.uniform(\n", + " jax.random.key(0),\n", + " shape=(block_size[0] * num_devices, block_size[1] * num_devices),\n", + ")\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "LEFT = 0\n", + "RIGHT = 1\n", + "\n", + "\n", + "def mod(x, n):\n", + " return lax.rem(x + n, n)\n", + "\n", + "\n", + "def signal(left_or_right, semaphore):\n", + " my_id = lax.axis_index('x')\n", + " if left_or_right == LEFT:\n", + " neighbor = mod(my_id - 1, num_devices)\n", + " else:\n", + " neighbor = mod(my_id + 1, num_devices)\n", + " pltpu.semaphore_signal(\n", + " semaphore,\n", + " inc=1,\n", + " device_id=(0, neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + "\n", + "def reduce_scatter_kernel(\n", + " x_ref,\n", + " o_ref,\n", + " hbm_scratch,\n", + " local_copy_sem,\n", + " left_recv_sem,\n", + " left_send_sem,\n", + " right_recv_sem,\n", + " right_send_sem,\n", + " left_capacity_sem,\n", + " right_capacity_sem,\n", + " accum_scratch,\n", + "):\n", + " outer_step = pl.program_id(0)\n", + " phase = pl.program_id(1)\n", + " is_start = jnp.logical_and(outer_step == 0, phase == 0)\n", + " last_iteration = outer_step == pl.num_programs(0) - 1\n", + "\n", + " working_slot = lax.rem(outer_step, 2)\n", + " receiving_slot = 1 - working_slot\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = mod(my_id + 1, num_devices)\n", + " left_neighbor = mod(my_id - 1, num_devices)\n", + "\n", + " left_copy_device = mod(my_id + outer_step + 1, num_devices)\n", + " right_copy_device = mod(my_id - outer_step - 1, num_devices)\n", + " # Slices can be specified using pl.ds(start, size)\n", + " left_copy_slice = pl.ds(0, block_size[0] // 2)\n", + " right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2)\n", + " current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2)\n", + "\n", + " initial_left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " initial_right_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " right_copy = pltpu.make_async_remote_copy(\n", + " # Note: Right copy is flipped with regards to slots since we are copying\n", + " # to the next outer_step iteration.\n", + " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " # --- Prologue ---\n", + " @pl.when(is_start)\n", + " def _():\n", + " # Barrier with 
both neighbors at the start, since we will be\n", + " # communicating with both.\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + "\n", + " # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies.\n", + " o_ref[...] = jnp.zeros_like(o_ref[...])\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n", + "\n", + " initial_left_copy.start()\n", + " initial_left_copy.wait()\n", + " initial_right_copy.start()\n", + "\n", + " # We tell our left neighbor that it is allowed to send to the right.\n", + " # (and vice versa for right neighbor)\n", + " signal(LEFT, right_capacity_sem)\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " # --- Body ---\n", + " # At the beginning of our kernel body, we start a DMA which copies\n", + " # the result we computed in the previous phase to our neighbor.\n", + " # This allows us to overlap the communication of sending our previous phase\n", + " # with the computation for the current phase.\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " # We block here until our right neighbor tells use we can send to\n", + " # the right.\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + " right_copy.start()\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " # We block here until our left neighbor tells use we can send to\n", + " # the left.\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + " left_copy.start()\n", + "\n", + " local_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " dst_ref=accum_scratch,\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy.start()\n", + " local_copy.wait()\n", + "\n", + " @pl.when(~last_iteration)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]\n", + "\n", + " local_copy = pltpu.make_async_copy(\n", + " src_ref=accum_scratch,\n", + " dst_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " sem=local_copy_sem,\n", + " )\n", + " local_copy.start()\n", + " local_copy.wait()\n", + "\n", + " @pl.when(is_start)\n", + " def _():\n", + " initial_right_copy.wait()\n", + "\n", + " # At the end of our kernel body, we wait on the DMA of the previous phase\n", + " # to make sure the results are ready for the next phase.\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " right_copy.wait()\n", + " signal(LEFT, right_capacity_sem)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " left_copy.wait()\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " # --- Epilogue ---\n", + " # Store result on last iteration.\n", + " @pl.when(last_iteration)\n", + " def _():\n", + " # Clean up semaphores so that they exit with a value of 0.\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " o_ref[left_copy_slice, ...] 
= accum_scratch[...]\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " o_ref[right_copy_slice, ...] = accum_scratch[...]\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + "\n", + "\n", + "out_shape = (\n", + " jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output\n", + " # Shape: [working/recv, block[0], block[1]]\n", + " jax.ShapeDtypeStruct(\n", + " (2, block_size[0], block_size[1]), jnp.float32\n", + " ), # hbm_scratch\n", + ")\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " ],\n", + " out_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " grid=(num_devices, 2),\n", + " scratch_shapes=(\n", + " [pltpu.SemaphoreType.DMA] * 5\n", + " + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores\n", + " + [\n", + " pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32)\n", + " ] # accum_scratch\n", + " ),\n", + ")\n", + "\n", + "\n", + "def pallas_reduce_scatter(input_arr):\n", + " input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1])\n", + " return pl.pallas_call(\n", + " reduce_scatter_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " )(input_arr)[0]\n", + "\n", + "\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " pallas_reduce_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "\n", + "pallas_result = jax.block_until_ready(pallas_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "executionInfo": { + "elapsed": 596, + "status": "ok", + "timestamp": 1722904806442, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "E-NMh-_teoi4", + "outputId": "24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: (64, 512) [0.78051674 0.3524047 0.59993696 0.9714314 0.24692321 0.01347649\n", + " 0.01857424 0.24841607 0.86097646 0.8261659 0.9753758 0.6902338\n", + " 0.4431417 0.963323 0.3158517 0.535548 ]\n", + "Pallas Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303\n", + " 1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733\n", + " 2.4662397 1.9542246]\n", + "lax.psum_scatter Result: (64, 128) [1.3593563 1.6274805 1.0979297 3.082869 1.4194957 1.4163033 1.2401303\n", + " 1.1892898 2.6545286 2.221559 2.7995253 2.08431 2.2509837 3.0726733\n", + " 2.4662397 1.9542246]\n", + "Difference |Pallas - lax.psum_scatter|: 2.3841858e-07\n" + ] + } + ], + "source": [ + "# Compare our result to XLA.\n", + "def lax_reduce_sum_scatter(x):\n", + " x = x.reshape(num_devices, block_size[0], block_size[1])\n", + " return lax.psum_scatter(x, 'x')\n", + "\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lax_reduce_sum_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input:', input_arr.shape, input_arr[::4, 0])\n", + "print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])\n", + "print('lax.psum_scatter Result:', 
xla_result.shape, xla_result[::4, 0])\n", + "print(\n", + " 'Difference |Pallas - lax.psum_scatter|:',\n", + " jnp.max(jnp.abs(pallas_result - xla_result)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ThKas40r40Ji" + }, + "source": [ + "### Nested Remote and Local DMA Pipelines\n", + "\n", + "A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB.\n", + "\n", + "To address this problem, we need to be able to write an \"inner kernel\" that handles local HBM-VMEM pipelining inside of the \"outer kernel\" that handles pipelining larger HBM-HBM transfers between devices. Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs:\n", + "\n", + "```python\n", + "def emit_pipeline(\n", + " kernel: Callable,\n", + " grid: tuple[int],\n", + " in_specs: PyTree[BlockSpec] = None,\n", + " out_specs: PyTree[BlockSpec] = None,\n", + " should_accumulate_out: bool = False,\n", + " dimension_semantics: tuple[GridDimensionSemantics] = None,\n", + ") -> Callable:\n", + " ... # Returns a custom pipeline given an inner kernel and BlockSpecs.\n", + "```\n", + "\n", + "Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern:\n", + "\n", + "```python\n", + "\n", + "def outer_kernel(...):\n", + " # ... do work to pipeline remote HBM-HBM transfers (outer kernel)\n", + "\n", + " def inner_kernel(...):\n", + " # ... do work (inner kernel)\n", + " pltpu.emit_pipeline(\n", + " inner_kernel,\n", + " grid=inner_grid,\n", + " in_specs=...,\n", + " out_specs=...,\n", + " )(inner_kernel_args)\n", + " # ... do more work (outer kernel)\n", + "\n", + "pl.pallas_call(\n", + " outer_kernel,\n", + " grid=outer_grid,\n", + " in_specs=...\n", + " out_specs=...\n", + " scratch=inner_kernel_allocs\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DzFeQjYaasX5" + }, + "source": [ + "### Example: Reduce-Scatter with large HBM blocks\n", + "\n", + "In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. 
This example is purely for demonstration purposes on how to use the pipeline emitter.\n", + "\n", + "We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size.\n", + "\n", + "In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM:\n", + "\n", + "```python\n", + "local_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " dst_ref=accum_scratch,\n", + " sem=local_copy_sem,\n", + ")\n", + "local_copy.start()\n", + "local_copy.wait()\n", + "@pl.when(~last_iteration)\n", + "def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[left_copy_device, left_copy_slice]\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_scratch[...] += x_ref[right_copy_device, right_copy_slice]\n", + "local_copy = pltpu.make_async_copy(\n", + " src_ref=accum_scratch,\n", + " dst_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " sem=local_copy_sem,\n", + ")\n", + "local_copy.start()\n", + "local_copy.wait()\n", + "```\n", + "\n", + "Our new kernel replaces it with the following `emit_pipeline` call:\n", + "\n", + "```python\n", + "def inner_kernel(input_ref, accum_ref):\n", + " accum_ref[...] = input_ref[...]\n", + "accum_pipeline = pltpu.emit_pipeline(inner_kernel,\n", + " in_specs=[inner_block_spec],\n", + " out_specs=inner_block_spec,\n", + " should_accumulate_out=True,\n", + " grid=inner_grid)\n", + "@pl.when(~last_iteration)\n", + "def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_pipeline(x_ref.at[left_copy_device, left_copy_slice],\n", + " hbm_scratch.at[working_slot, left_copy_slice],\n", + " )\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_pipeline(x_ref.at[right_copy_device, right_copy_slice],\n", + " hbm_scratch.at[working_slot, right_copy_slice],\n", + " )\n", + "```\n", + "\n", + "The full kernel is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "executionInfo": { + "elapsed": 1341, + "status": "ok", + "timestamp": 1722904807930, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "27jni-pSartL" + }, + "outputs": [], + "source": [ + "partition = P(None, 'x')\n", + "devices = mesh_utils.create_device_mesh((1, num_devices))\n", + "mesh = jax.sharding.Mesh(devices, partition)\n", + "sharding = jax.sharding.NamedSharding(mesh, partition)\n", + "\n", + "# We pick a large outer kernel block size that we do not want to place\n", + "# in VMEM. 
For pedagogical purposes we use (4096, 4096), although in\n", + "# principle this can be much larger.\n", + "outer_block_size = (4096, 4096)\n", + "# We pick a smaller VMEM block size for the inner kernel.\n", + "inner_block_size = (128, 128)\n", + "input_arr = jax.random.uniform(\n", + " jax.random.key(0),\n", + " shape=(\n", + " outer_block_size[0] * num_devices,\n", + " outer_block_size[1] * num_devices,\n", + " ),\n", + ")\n", + "input_arr = jax.device_put(input_arr, sharding)\n", + "\n", + "\n", + "inner_grid = (\n", + " outer_block_size[0] // inner_block_size[0] // 2,\n", + " outer_block_size[1] // inner_block_size[1],\n", + ")\n", + "inner_block_spec = pl.BlockSpec(\n", + " index_map=lambda i, j: (i, j),\n", + " block_shape=inner_block_size,\n", + " memory_space=pltpu.TPUMemorySpace.ANY,\n", + ")\n", + "\n", + "\n", + "def reduce_scatter_kernel(\n", + " x_ref,\n", + " o_ref,\n", + " hbm_scratch,\n", + " left_recv_sem,\n", + " left_send_sem,\n", + " copy_sem,\n", + " right_recv_sem,\n", + " right_send_sem,\n", + " left_capacity_sem,\n", + " right_capacity_sem,\n", + "):\n", + " outer_step = pl.program_id(0)\n", + " phase = pl.program_id(1)\n", + " is_start = jnp.logical_and(outer_step == 0, phase == 0)\n", + " last_iteration = outer_step == pl.num_programs(0) - 1\n", + "\n", + " working_slot = lax.rem(outer_step, 2)\n", + " receiving_slot = 1 - working_slot\n", + " my_id = lax.axis_index('x')\n", + " right_neighbor = mod(my_id + 1, num_devices)\n", + " left_neighbor = mod(my_id - 1, num_devices)\n", + "\n", + " left_copy_device = mod(my_id + outer_step + 1, num_devices)\n", + " right_copy_device = mod(my_id - outer_step - 1, num_devices)\n", + " left_copy_slice = pl.ds(0, outer_block_size[0] // 2)\n", + " right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2)\n", + " current_phase_slice = pl.ds(\n", + " phase * (outer_block_size[0] // 2), outer_block_size[0] // 2\n", + " )\n", + "\n", + " initial_left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " initial_right_copy = pltpu.make_async_remote_copy(\n", + " src_ref=x_ref.at[my_id, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " left_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[working_slot, left_copy_slice],\n", + " dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice],\n", + " send_sem=left_send_sem,\n", + " recv_sem=left_recv_sem,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " right_copy = pltpu.make_async_remote_copy(\n", + " src_ref=hbm_scratch.at[receiving_slot, right_copy_slice],\n", + " dst_ref=hbm_scratch.at[working_slot, right_copy_slice],\n", + " send_sem=right_send_sem,\n", + " recv_sem=right_recv_sem,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + "\n", + " # --- Prologue ---\n", + " @pl.when(is_start)\n", + " def _():\n", + " # Barrier with both neighbors at the start, since we will be\n", + " # communicating with both.\n", + " barrier_sem = pltpu.get_barrier_semaphore()\n", + " 
pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, left_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_signal(\n", + " barrier_sem,\n", + " inc=1,\n", + " device_id=(0, right_neighbor),\n", + " device_id_type=pltpu.DeviceIdType.MESH,\n", + " )\n", + " pltpu.semaphore_wait(barrier_sem, 2)\n", + "\n", + " initial_left_copy.start()\n", + " initial_left_copy.wait()\n", + " initial_right_copy.start()\n", + "\n", + " # We tell our left neighbor that it is allowed to send to the right.\n", + " # (and vice versa for right neighbor)\n", + " signal(LEFT, right_capacity_sem)\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " # We block here until our right neighbor tells use we can send to\n", + " # the right.\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + " right_copy.start()\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " # We block here until our left neighbor tells use we can send to\n", + " # the left.\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + " left_copy.start()\n", + "\n", + " # --- Body ---\n", + " def inner_kernel(input_ref, accum_ref):\n", + " # We do not explicitly use += because we set should_accumulate_out=True.\n", + " accum_ref[...] = input_ref[...]\n", + "\n", + " accum_pipeline = pltpu.emit_pipeline(\n", + " inner_kernel,\n", + " in_specs=[inner_block_spec],\n", + " out_specs=inner_block_spec,\n", + " should_accumulate_out=True,\n", + " grid=inner_grid,\n", + " )\n", + "\n", + " @pl.when(~last_iteration)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " accum_pipeline(\n", + " x_ref.at[left_copy_device, left_copy_slice],\n", + " hbm_scratch.at[working_slot, left_copy_slice],\n", + " )\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " accum_pipeline(\n", + " x_ref.at[right_copy_device, right_copy_slice],\n", + " hbm_scratch.at[working_slot, right_copy_slice],\n", + " )\n", + "\n", + " # --- Epilogue ---\n", + " @pl.when(is_start)\n", + " def _():\n", + " initial_right_copy.wait()\n", + "\n", + " @pl.when(~is_start)\n", + " def _():\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " right_copy.wait()\n", + " signal(LEFT, right_capacity_sem)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " left_copy.wait()\n", + " signal(RIGHT, left_capacity_sem)\n", + "\n", + " # Store result on last iteration.\n", + " @pl.when(last_iteration)\n", + " def _():\n", + " output_copy = pltpu.make_async_copy(\n", + " src_ref=hbm_scratch.at[working_slot, current_phase_slice],\n", + " dst_ref=o_ref.at[current_phase_slice],\n", + " sem=copy_sem,\n", + " )\n", + " output_copy.start()\n", + " output_copy.wait()\n", + "\n", + " # Clean up semaphores so that they exit with a value of 0.\n", + " @pl.when(phase == LEFT)\n", + " def _():\n", + " pltpu.semaphore_wait(right_capacity_sem, 1)\n", + "\n", + " @pl.when(phase == RIGHT)\n", + " def _():\n", + " pltpu.semaphore_wait(left_capacity_sem, 1)\n", + "\n", + "\n", + "out_shape = (\n", + " jax.ShapeDtypeStruct(\n", + " (outer_block_size[0], outer_block_size[1]), jnp.float32\n", + " ),\n", + " # Shape: [working/recv, block[0], block[1]]\n", + " jax.ShapeDtypeStruct(\n", + " (2, outer_block_size[0], outer_block_size[1]), jnp.float32\n", + " ), # hbm_scratch\n", + ")\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=0,\n", + " in_specs=[\n", + " 
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " out_specs=[\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + " ],\n", + " grid=(num_devices, 2),\n", + " scratch_shapes=(\n", + " [pltpu.SemaphoreType.DMA] * 5\n", + " + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores\n", + " ),\n", + ")\n", + "\n", + "\n", + "def pallas_reduce_scatter(input_arr):\n", + " input_arr = input_arr.reshape(\n", + " num_devices, outer_block_size[0], outer_block_size[1]\n", + " )\n", + " return pl.pallas_call(\n", + " reduce_scatter_kernel,\n", + " out_shape=out_shape,\n", + " grid_spec=grid_spec,\n", + " compiler_params=pltpu.TPUCompilerParams(collective_id=0),\n", + " )(input_arr)[0]\n", + "\n", + "\n", + "pallas_result = jax.jit(\n", + " shard_map.shard_map(\n", + " pallas_reduce_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " check_rep=False,\n", + " )\n", + ")(input_arr)\n", + "\n", + "pallas_result = jax.block_until_ready(pallas_result)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "executionInfo": { + "elapsed": 768, + "status": "ok", + "timestamp": 1722904808851, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "cTEyiMDyx9Y0", + "outputId": "1de26695-3713-430e-9ab4-4ea646691680" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Input: (16384, 16384) [0.74162567 0.0242182 0.27751946 ... 0.05213022 0.36088037 0.04494429]\n", + "Pallas Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]\n", + "lax.psum_scatter Result: (16384, 4096) [2.0648427 1.674587 1.9148926 ... 1.3371865 1.3296283 1.2887063]\n", + "Difference |Pallas - lax.psum_scatter|: 2.3841858e-07\n" + ] + } + ], + "source": [ + "# Now we compare our result to XLA.\n", + "def lax_reduce_sum_scatter(x):\n", + " x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1])\n", + " return lax.psum_scatter(x, 'x')\n", + "\n", + "\n", + "xla_result = jax.jit(\n", + " shard_map.shard_map(\n", + " lax_reduce_sum_scatter,\n", + " mesh=mesh,\n", + " in_specs=P(None, 'x'),\n", + " out_specs=P('x', None),\n", + " )\n", + ")(input_arr)\n", + "\n", + "print('Input:', input_arr.shape, input_arr[::4, 0])\n", + "print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0])\n", + "print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0])\n", + "print(\n", + " 'Difference |Pallas - lax.psum_scatter|:',\n", + " jnp.max(jnp.abs(pallas_result - xla_result)),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zz5AFbriliyv" + }, + "source": [ + "## Final Notes\n", + "\n", + "### Megacore\n", + "\n", + "Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `\"parallel\"`. 
Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core.\n",
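+ "\n", + "The following is a rough sketch of this pattern (our own illustration with placeholder shapes and names, not a kernel used elsewhere in this tutorial). For brevity it issues a plain HBM-to-HBM copy rather than a remote DMA, but the same `@pl.when` guard applies; it also assumes that `dimension_semantics` is accepted through `pltpu.TPUCompilerParams`:\n", + "\n", + "```python\n", + "num_cores = jax.devices()[0].num_cores\n", + "\n", + "def one_core_copy_kernel(in_hbm_ref, out_hbm_ref, copy_sem):\n", + "  core_index = pl.program_id(0)  # Grid axis 0 ranges over cores.\n", + "\n", + "  @pl.when(core_index == 0)\n", + "  def _():\n", + "    # Only core 0 issues the DMA; the other core does no DMA work.\n", + "    copy = pltpu.make_async_copy(\n", + "        src_ref=in_hbm_ref, dst_ref=out_hbm_ref, sem=copy_sem,\n", + "    )\n", + "    copy.start()\n", + "    copy.wait()\n", + "\n", + "one_core_copy = pl.pallas_call(\n", + "    one_core_copy_kernel,\n", + "    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),\n", + "    grid_spec=pltpu.PrefetchScalarGridSpec(\n", + "        num_scalar_prefetch=0,\n", + "        in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)],\n", + "        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),\n", + "        grid=(num_cores,),\n", + "        scratch_shapes=[pltpu.SemaphoreType.DMA],\n", + "    ),\n", + "    compiler_params=pltpu.TPUCompilerParams(\n", + "        dimension_semantics=(\"parallel\",),\n", + "    ),\n", + ")\n", + "```\n",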
+ "\n", + "### Interaction with XLA\n", + "\n", + "In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. This is not guaranteed to happen when Pallas kernels are involved, so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations.\n", + "\n", + "### Next Steps\n", + "\n", + "Excellent follow-up exercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead." + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/pallas/tpu/distributed.md b/docs/pallas/tpu/distributed.md new file mode 100644 index 000000000000..fc3f929866bd --- /dev/null +++ b/docs/pallas/tpu/distributed.md @@ -0,0 +1,1527 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + ++++ {"id": "zSNjLhGQJMgq"} + +# Distributed Computing in Pallas for TPUs + +In this tutorial, we will cover the basics of distributed computing in Pallas on TPUs. We will learn about TPU topologies, communication using the remote DMA primitive, and calling a distributed kernel from JAX using `shard_map`. We will also cover some more advanced kernel writing techniques, such as double-buffering, bi-directional bandwidth optimization, and nested pipelining. As educational examples, we will learn how to implement various collective primitives from JAX, such as `lax.ppermute`, `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. + +Some recommended readings beforehand: + - [Pallas Pipelining on TPU](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html) + - [Collectives with `shard_map`](https://jax.readthedocs.io/en/latest/notebooks/shard_map.html#collectives-tutorial) + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 1978 + status: ok + timestamp: 1722904801801 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: PyAGnWc9yI8T +outputId: 1d8229bd-cab5-495f-93e9-fff2e41db480 +--- +import jax +from jax import lax +from jax import numpy as jnp +from jax.experimental import mesh_utils +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu + +P = jax.sharding.PartitionSpec + +num_devices = jax.local_device_count() +assert num_devices > 1, "Please run this notebook with more than one device." +assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices." +print(f"Running with {num_devices} {jax.devices()[0].device_kind} devices.") +``` + ++++ {"id": "DySMGNByclMi"} + +## TPU Topologies + +TPUs are typically deployed in pods of multiple devices connected via a high-bandwidth interchip interconnect (ICI) for communication within the pod, which is much faster than a typical network connection. For example, the specifications sheet for a [TPU v5p](https://cloud.google.com/tpu/docs/v5p) states an ICI bandwidth of 4.8Tb/s per chip (for reference, TPU v5p also has 21Tb/s of *local* HBM bandwidth). The ICI allows us to implement fast and performant distributed kernels that require high-bandwidth communication within a pod, and use the datacenter network for parallelization over less bandwidth-intensive operations, such as data-parallelism over a batch dimension. + +TPU pods are typically arranged in an ND torus topology. The following graphic gives several examples of configurations of different sizes. + +![tpu_topologies](https://cloud.google.com/static/tpu/docs/images/v4-topologies.png) + +Flattened as a graph, the torus can be visualized as follows. Each edge (orange or black) is a bidirectional connection between two devices. You will commonly hear about rings in conjunction with discussions of device topologies: a key feature of a torus is that when taking a slice along an axis of the pod, such as the nodes `[(0,1), (1, 1), (2, 1), (3, 1)]` or `[(0, 1), (1, 1)]`, we have a ring of devices. This is a feature we can use to simplify communication patterns within the pod. + +![tpu_torus](https://cloud.google.com/static/tpu/docs/images/untwisted-tori.png) + ++++ {"id": "1Oc_WD1hChfN"} + +## Remote Direct Memory Access (RDMA) Model + +TPUs communicate via a push-only model known as remote direct memory access (RDMA). A TPU is allowed to issue a copy instruction that pushes data from a local buffer to any buffer on another device within the same pod; the copy executes asynchronously from the main program thread. However, a TPU can only read data that is stored locally. This is in contrast to more traditional multi-core programming where it is possible to both read from and write to a shared memory. + +### Async Remote Copy Operation +The `pltpu.make_async_remote_copy` function is used to create a remote DMA descriptor object which parameterizes both a "send" operation and a "receive" operation.
Here's its signature: + +```python + def make_async_remote_copy( + src_ref: Ref, + dst_ref: Ref, + send_sem: Ref[SemaphoreType], + recv_sem: Ref[SemaphoreType], + device_id: int | tuple[int, ...], + device_id_type: DeviceIdType + ) -> AsyncCopyDescriptor: +``` + +- `src_ref` is the local `Ref` (in any memory space) containing the data you wish to send to `dst_ref` on another device. +- `dst_ref` is the remote `Ref` (in any memory space) to which the data will be copied on the target device. +- `send_sem` is a DMA semaphore used to block until all data has been sent from `src_ref`. +- `recv_sem` is a DMA semaphore used to block until the expected number of bytes have been received at `dst_ref`. The sender of the DMA will write to the receiver's `recv_sem`. +- `device_id` is the device ID of the target device to send to. +- `device_id_type` specifies the format of `device_id`, which can either be in LOGICAL format (integer device ID), or in MESH format (an ND-tuple index into the logical device mesh). The default mode is MESH. + +`make_async_remote_copy` returns a descriptor object on which you use the `.start()` method to initiate the DMA, `.wait_send()` to block on `send_sem`, and `.wait_recv()` to block on `recv_sem` (or `.wait()` to block on both). If a device is only expected to send data, it is sufficient to only call `.start()` and `.wait_send()`, and likewise, if a device is only receiving, it is sufficient to only call `.wait_recv()`. If using an SPMD pattern where all devices execute the DMA, each device will generally call both `.start()` and `.wait()`. +```python +dma_descriptor = make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id) +dma_descriptor.start() # Initiate the DMA (non-blocking). +# ... do other work +dma_descriptor.wait_send() # Block until all data has been sent. +dma_descriptor.wait_recv() # Block until all data has been received. +``` + +As an example, let's visualize a DMA where we consider 4 devices (indexed 0, 1, 2, 3). We consider a scheme where device 0 copies to device 1, and devices 2 and 3 copy to each other. In practice, we can create such an asymmetric communication pattern by using `@pl.when` to branch on the device ID. + +(1) Each device creates the DMA descriptor. Devices 0, 2, and 3 call `.start()` to initiate the DMA from `src_ref`. Device 1 skips the `.start()` and does nothing, e.g. by using `pl.when`. + +![rdma_start](../../_static/pallas/distributed/rdma_start.svg) + +(2) As `.start()` is non-blocking, each device is free to do other computation while the DMA is in flight. Devices 0, 2, and 3 call `.wait_send()` to wait on `send_sem`, which blocks until all data has been sent. + +![rdma_send](../../_static/pallas/distributed/rdma_send.svg) + +(3) Finally, devices 1, 2, and 3 will call `.wait_recv()` to wait on `recv_sem` until all data has arrived at `dst_ref`.
+ +![rdma_recv](../../_static/pallas/distributed/rdma_recv.svg) + +The above communication pattern can be written as follows: +```python +def example_kernel(input_ref, output_ref, send_sem, recv_sem): + device_id = lax.axis_index('x') + copy_0_to_1 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=1, + ) + copy_2_to_3 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=3, + ) + copy_3_to_2 = pltpu.make_async_remote_copy( + src_ref=input_ref, + dst_ref=output_ref, + send_sem=send_sem, + recv_sem=recv_sem, + device_id=2, + ) + @pl.when(device_id == 0) + def _(): + copy_0_to_1.start() + copy_0_to_1.wait_send() + @pl.when(device_id == 1) + def _(): + copy_0_to_1.wait_recv() + @pl.when(device_id == 2) + def _(): + copy_2_to_3.start() + copy_2_to_3.wait_send() + copy_3_to_2.wait_recv() + @pl.when(device_id == 3) + def _(): + copy_3_to_2.start() + copy_3_to_2.wait_send() + copy_2_to_3.wait_recv() +``` + +### DMA Semaphores + +`send_sem` and `recv_sem` are instances of a special type of semaphore reserved exclusively for use with DMAs. They must be allocated with the `pltpu.SemaphoreType.DMA` type when specifying input specs to `pallas_call`. + +Internally, DMA semaphores can be thought of as integer-valued progress trackers. On DMA start, the local device will begin to increment the value of `send_sem` and the receiver's `recv_sem` asynchronously. Waiting on a semaphore will block until the value of the semaphore reaches the total bytes of data sent/received; when the value is reached, waiting threads are released and the semaphore's value is decremented by the same amount. This means that either all data has been sent (for `send_sem`) or all data has been received (for `recv_sem`). The value of the semaphore can be read with `pl.semaphore_read`, but note that the underlying semantics of the value could change between hardware generations (e.g. the value may not represent exactly the number of bytes sent, although this is a useful mental model to have when reasoning about the behavior of the semaphore).
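+ +In the examples in this document, these semaphores are allocated through the `scratch_shapes` argument of `pltpu.PrefetchScalarGridSpec`; as a preview, a minimal sketch of that allocation (the right-permute kernel below uses exactly this pattern) looks like: + +```python +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + # Two DMA semaphores, passed to the kernel as its trailing arguments + # (e.g. a send semaphore and a receive semaphore). + scratch_shapes=[pltpu.SemaphoreType.DMA] * 2, +) +```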
+- In the case of a race condition, there could be silent data corruption if two simultaneous writes, or a simultaneous read and write, target the same buffer.
+
+Some common causes of the above include:
+- If a device calls `.wait_recv()` but no other device sends to it, the kernel may hang.
+- If a device is sent more bytes than it expected to receive, it may crash due to non-zero semaphore states. If it is sent fewer, it may hang indefinitely.
+- If DMAs are started but the semaphores are not waited on, the program may crash due to non-zero semaphore states.
+- If two devices copy to the same destination, you may encounter non-deterministic results due to a race condition, or a crash due to non-zero semaphore states.
+
++++ {"id": "vpGSN1Sui0Bu"}
+
+### Example: Right Permute (`lax.ppermute`)
+
+Let's dive into a very basic example. We will implement a kernel that performs a right permutation, where each device sends its slice of the data to its right neighbor.
+
+Suppose we have an array with 512 elements, which we shard into slices of size 128 across 4 devices. Each device will pass its slice to the next device, and the output will consist of the same data, but with the slices rotated by 1. This is identical to the `lax.ppermute` operation where the permutation is set to `(n, (n+1) % 4)`.
+
+In order to call the kernel in distributed mode, we wrap the `pallas_call` in a `shard_map` transformation. From there, we can write the kernel the same way we would write a normal single-device Pallas kernel, except that we now have access to remote DMA instructions. JAX collective primitives such as `lax.axis_index` can be used to obtain a `device_id` for computing which target devices to copy to, by referencing the same named axes passed into `shard_map`.
+
+```{code-cell} ipython3
+---
+executionInfo:
+  elapsed: 1606
+  status: ok
+  timestamp: 1722904803566
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: YkyIKN2thZ-V
+outputId: 9b7ed142-d161-4237-fed8-cbce41adc5f0
+---
+partition = P(None, 'x')
+devices = mesh_utils.create_device_mesh((1, num_devices))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+# Create an input array that shards the last dimension across
+# all devices.
+input_arr = jax.random.uniform(jax.random.key(0), (8, 128 * num_devices))
+input_arr = jax.device_put(input_arr, sharding)
+
+
+def right_permute_kernel(input_ref, output_ref, send_sem, recv_sem):
+  my_id = lax.axis_index('x')
+  right_neighbor = lax.rem(my_id + 1, num_devices)
+  remote_copy_op = pltpu.make_async_remote_copy(
+      src_ref=input_ref,
+      dst_ref=output_ref,
+      send_sem=send_sem,
+      recv_sem=recv_sem,
+      device_id=(0, right_neighbor),
+      device_id_type=pltpu.DeviceIdType.MESH,
+  )
+  remote_copy_op.start()
+  remote_copy_op.wait()
+
+
+out_shape = jax.ShapeDtypeStruct((8, 128), jnp.float32)
+grid_spec = pltpu.PrefetchScalarGridSpec(
+    num_scalar_prefetch=0,
+    # TPUMemorySpace.ANY will (usually) place the tensor in HBM.
+    in_specs=[
+        pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    ],
+    out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
+    scratch_shapes=(
+        # We allocate DMA semaphores in scratch memory.
+        [pltpu.SemaphoreType.DMA] * 2
+    ),
+)
+right_permute = pl.pallas_call(
+    right_permute_kernel,
+    out_shape=out_shape,
+    grid_spec=grid_spec,
+)
+# Wrap the kernel within a shard_map to call.
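+# (`check_rep=False` disables shard_map's replication checking, which cannot
+# analyze an opaque custom Pallas kernel.)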
+pallas_result = jax.jit( + shard_map.shard_map( + right_permute, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False, + ) +)(input_arr) + +# Compare Pallas result to XLA shard_map result. +perm = tuple((src, (src + 1) % num_devices) for src in range(num_devices)) + +xla_result = jax.jit( + shard_map.shard_map( + lambda x: lax.ppermute(x, 'x', perm), + mesh=mesh, in_specs=partition, out_specs=partition) +)(input_arr) + +print('Input = ', input_arr[0, ::128]) +print('Pallas Result = ', pallas_result[0, ::128]) +print('lax.ppermute Result = ', xla_result[0, ::128]) +print( + 'Difference |Pallas - lax.ppermute| = ', + jnp.mean(jnp.abs(pallas_result - xla_result)), +) +``` + ++++ {"id": "iyfhdGXuUnq2"} + +### Example: All-gather (`lax.all_gather`) + +In this next example we will implement the all-gather collective operation, which has a JAX equivalent in `lax.all_gather`. In contrast with the right-permute example from above which only involves a pair of source and destination neighbors, an all-gather operation requires communication between all devices and therefore we must think about how data is routed between them. The specifics of how we implement this are dictated by the device topology, for which we assume is a ring. + +#### Ring Communication Pattern + +We will write our kernel assuming a ring topology. Rings are a natural fit for TPUs as slicing along any dimension of a torus produces a ring. When writing collectives, we often only need to think about 1D slices of our torus at a time because the different dimensions of the torus are reserved for different types of parallelism (data vs. model, for example). + +The strategy we will use is to write a looped kernel, where on each iteration a device receives one slice of the sharded array from its left neighbor, and copies the previously received slice to its right neighbor. After `num_devices` iterations, each device will have a copy of the entire array in its local HBM. + +![all_gather](../../_static/pallas/distributed/all_gather.svg) + +We can re-purpose Pallas's `grid` argument to implement the loop. Rather than iterating over tiles of an array as we have done in previous tutorials, we instead set the grid to `(num_devices,)` to indicate that we want to loop over the number of devices and use `pl.program_id` to obtain the loop iteration inside of the Pallas kernel. The following code snippet demonstrates how to implement this: + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 812 + status: ok + timestamp: 1722904804531 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: ojQEZB5mBRqM +outputId: e1648f54-737c-4921-ca3b-b4c639a38d2b +--- +partition = P('x', None) +devices = mesh_utils.create_device_mesh((num_devices, 1)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# Create an input array that shards the first dimension across +# all devices. 
+input_arr = jax.random.uniform(jax.random.key(0), (8 * num_devices, 128)) +input_arr = jax.device_put(input_arr, sharding) + + +def all_gather_kernel(input_ref, + output_ref, + local_copy_sem, + send_sem, + recv_sems): + outer_step = pl.program_id(0) + my_id = lax.axis_index('x') + right_neighbor = lax.rem(my_id + 1, num_devices) + copy_slot = my_id - outer_step + copy_slot = lax.rem(copy_slot + num_devices, num_devices) + + @pl.when(outer_step == 0) + def _(): + local_copy_op = pltpu.make_async_copy( + src_ref=input_ref, + dst_ref=output_ref.at[my_id], + sem=local_copy_sem, + ) + local_copy_op.start() + local_copy_op.wait() + + # Copy to our right neighbor. + # Note that we will also be receiving data from our left neighbor, + # but at `copy_slot-1` rather than `copy_slot`! This makes use of the fact + # that the indices do not need to be symmetric between remote DMAs. + remote_copy_op = pltpu.make_async_remote_copy( + src_ref=output_ref.at[copy_slot], + dst_ref=output_ref.at[copy_slot], + send_sem=send_sem, + recv_sem=recv_sems.at[outer_step], + device_id=(right_neighbor, 0), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy_op.start() + remote_copy_op.wait() + +out_shape = jax.ShapeDtypeStruct((num_devices, 8, 128), jnp.float32) +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + # TPUMemorySpace.ANY will (usually) place the tensor in HBM. + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=( + # DMA semaphores are allocated in scratch memory. + # We allocated one semaphore for a local HBM-VMEM copy, + # and one for the remote send semaphore. + [pltpu.SemaphoreType.DMA] * 2 + # We additionally allocate one receive semaphore per device. + # This is to avoid situations where we have multiple + # DMAs in flight, as we do not want to share a receive + # semaphore between the DMAs. + + [pltpu.SemaphoreType.DMA((num_devices-1,))] + + ), + grid=(num_devices-1,) + ) + +all_gather = pl.pallas_call( + all_gather_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + ) + +# Wrap the kernel within a shard_map to call. +pallas_result = jax.jit( + shard_map.shard_map( + all_gather, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False + ) +)(input_arr) + +# Compare Pallas result to XLA shard_map result. +xla_result = jax.jit( + shard_map.shard_map( + lambda x: lax.all_gather(x, 'x'), + mesh=mesh, in_specs=partition, out_specs=partition + ) +)(input_arr) + +print('Input: ', input_arr.shape, input_arr[::8, 0]) +print('Pallas Result: ', pallas_result.shape, pallas_result[:, 0, 0]) +print('lax.all_gather Result: ', xla_result.shape, xla_result[:, 0, 0]) +print('Difference |Pallas - lax.all_gather| = ', + jnp.mean(jnp.abs(pallas_result - xla_result))) +``` + ++++ {"id": "KgU7HI2pS4om"} + +A detail worth mentioning here is the use of multiple receive semaphores. Because we only block on the receiving device, it is still possible for a sender to have sent multiple DMAs in flight before the receiver has finished processing the first one (see the next section and reduce-sum example which discusses race conditions in more detail). In this situation we may hit a situation where the same semaphore is being used for multiple DMAs occurring simultaneously. To avoid this, we allocate `num_devices-1` semaphores so there is no risk of re-use. 
While this race condition is unlikely to happen on such a small kernel, on larger kernels there is a greater chance for devices to fall out of sync, potentially causing a silent failure.
+
++++ {"id": "KgU7HI2pS4om"}
+
+## Advanced Techniques
+
+Now that we have seen how to write several basic kernels using remote DMA operations, we will go over more advanced techniques for synchronization and writing efficient kernels.
+
++++ {"id": "8M_kdl0FCtrL"}
+
+### Synchronization: Regular and Barrier Semaphores
+
+The examples we implemented in the basic tutorial do not require special handling of synchronization as all necessary communication writes to disjoint buffers. However, other operations may require more complex communication patterns that need additional synchronization primitives to avoid race conditions. Pallas provides two additional primitives to help with this: regular and barrier semaphores.
+
+#### Regular Semaphores
+
+Regular semaphores are the standard tool used to synchronize across multiple devices. Semaphores are fundamentally counters: they can be incremented by any device, and a device can block until the value of a semaphore reaches a specific value (and then decrement the value by that amount).
+
+The three main operations that can be used on regular semaphores are signal, wait, and read:
+```python
+def semaphore_signal(
+    sem: Ref[SemaphoreType],
+    inc: int,
+    device_id: int | tuple[int, ...],
+    device_id_type: DeviceIdType
+) -> None:
+  ...  # Increments the semaphore `sem` on the target device `device_id` by `inc`.
+
+def semaphore_wait(
+    sem: Ref[SemaphoreType],
+    value: int,
+) -> None:
+  ...  # Blocks until the locally allocated copy of `sem` reaches `value`, then decrements it by `value` and proceeds.
+
+def semaphore_read(
+    sem: Ref[SemaphoreType],
+) -> jax.Array:
+  ...  # Returns the current value of `sem` as an `int32[]`.
+```
+
+Regular semaphores are allocated in the same way as DMA semaphores, but by specifying `pltpu.SemaphoreType.REGULAR` rather than `pltpu.SemaphoreType.DMA`.
+
+All semaphores must be zero at the end of a Pallas program for it to complete successfully. There are two error cases in which this can go wrong:
+- If a semaphore is over-signaled, the program will end with non-zero (>0) semaphores. In this case, the program will crash upon completion. This is useful for debugging, as non-zero semaphores typically mean there is a bug somewhere inside of the program.
+- If a semaphore is over-waited, the program will hang on the blocking `semaphore_wait` call while it waits for the semaphore to be incremented. In this case the device or program will need to be restarted.
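+
+The later kernels in this tutorial combine `semaphore_signal` and `semaphore_wait` into a simple "ready-to-receive" handshake between ring neighbors. As a preview, here is a minimal, hypothetical sketch of that pattern (it is not part of the original examples; it assumes a `(1, num_devices)` device mesh with axis name `'x'` and a `REGULAR` semaphore `ready_sem` allocated via `scratch_shapes`, and it is a kernel-body fragment rather than a complete program):
+
+```python
+def handshake_sketch(ready_sem):
+  # Illustrative fragment: a one-slot "credit" handshake between ring
+  # neighbors, mirroring the capacity semaphore used later in this tutorial.
+  my_id = lax.axis_index('x')
+  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)
+
+  # Tell our left neighbor (the device that sends to us) that we are ready
+  # to receive one buffer, by incrementing its copy of `ready_sem`.
+  pltpu.semaphore_signal(
+      ready_sem,
+      inc=1,
+      device_id=(0, left_neighbor),
+      device_id_type=pltpu.DeviceIdType.MESH,
+  )
+  # Block until our right neighbor (the device we send to) has signaled our
+  # local `ready_sem`, then decrement it back to zero and proceed to send.
+  pltpu.semaphore_wait(ready_sem, 1)
+```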
+
+#### Barrier Semaphores
+
+Barrier semaphores are globally-allocated semaphores used to synchronize devices across an entire program and ensure that all devices have entered the Pallas kernel.
+
+If a Pallas kernel is executed within the context of a larger XLA program, we need to ensure that all devices that communicate have entered the kernel. However, DMA and regular semaphores are both locally scoped: they are only understood by other devices that have entered the kernel. Barrier semaphores serve as a globally understood semaphore that can be used for synchronization no matter where in the XLA program the device is currently executing.
+
+By default, if you do not specify a barrier semaphore, Pallas will automatically insert a barrier semaphore at the beginning of your program. However, it can be more efficient to write your own. Barrier semaphores are similar to regular semaphores in that they are counters that can be incremented via `semaphore_signal` and decremented via `semaphore_wait`. They are created by calling `get_barrier_semaphore()` within a kernel. Typically, we use a barrier once at the beginning of a kernel to synchronize with all devices we are communicating with.
+
+```python
+from jax.experimental.pallas import tpu as pltpu
+
+def example_kernel(...):
+  # Use barrier semaphores at the beginning of a kernel.
+  # is_start_of_kernel = ...
+  # right_neighbor = ...
+  # ...
+  @pl.when(is_start_of_kernel)
+  def _():
+    barrier_sem = pltpu.get_barrier_semaphore()
+    # Increment the semaphore of your right neighbor.
+    pltpu.semaphore_signal(
+        barrier_sem,
+        device_id=right_neighbor,
+        device_id_type=pltpu.DeviceIdType.LOGICAL,
+    )
+    # Wait until your left neighbor has incremented your semaphore.
+    pltpu.semaphore_wait(barrier_sem, 1)
+    # ...
+```
+
+When using barrier semaphores, the `collective_id` compiler parameter must be passed to `pallas_call` to specify which barrier semaphore is being used. A TPU has a small, fixed number of barrier semaphores available (typically on the order of 20-30), and therefore they should be used sparingly. In order to ensure correctness, only kernels that share the same communication pattern should use the same `collective_id`. For example, if two kernels synchronize only with neighbors on the same mesh axis, they are allowed to share the same `collective_id`. However, if two kernels synchronize along different axes, they must have different `collective_id`s. Failure to do so may result in race conditions that are difficult to debug.
+
+```python
+kernel = pl.pallas_call(
+    example_kernel,
+    ...,
+    compiler_params=pltpu.TPUCompilerParams(collective_id=0),
+)
+```
+
++++ {"id": "zy20AxN5TSLA"}
+
+### Double-buffering
+
+To avoid a race condition where we read from a local `Ref` that another device is simultaneously writing into, a useful technique is the "double-buffered" strategy, where we allocate two `Ref`s for each destination value. On each iteration, one `Ref` will be designated as a "working" slot, and the other will be designated as a "receiving" slot. The device is free to use the working slot for computation, but will only copy data into its neighbor's receiving slot. The working and receiving slots alternate every iteration, so that once a copy is finished, the old receiving slot becomes the new working slot, and vice versa. When this scheme is used properly, data is never simultaneously read from and written to the same buffer.
+
+The following code skeleton demonstrates how double-buffering can be used. We keep a running iteration counter in the variable `iteration`, and the `working_slot` and `receiving_slot` alternate between 0 and 1 every iteration. `dst_ref` is allocated as a double-buffer and has the size `[2, ...]`. On each iteration, we read from the working slot using `dst_ref.at[working_slot, ...]` and use the value to perform computation. Simultaneously, we copy to our neighbor's `dst_ref.at[receiving_slot]` to avoid overwriting their `working_slot` value. By structuring our communication in this fashion, it is possible to overlap the communication latency of the remote DMA with local computation while minimizing the risk of race conditions.
+```python
+def kernel(...):
+  # ...
+  iteration = pl.program_id(0)
+  working_slot = lax.rem(iteration, 2)
+  receiving_slot = 1 - working_slot
+  # ...
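+  # `working_slot` alternates 0, 1, 0, 1, ... across iterations, and
+  # `receiving_slot` is always the other buffer, so on any given iteration
+  # local reads and incoming remote writes never target the same slot.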
+
+  local_copy_op = pltpu.make_async_copy(
+      src_ref=dst_ref.at[working_slot, ...],
+      dst_ref=local_scratch_ref,
+      sem=local_copy_sem,
+  )
+  local_copy_op.start()
+  remote_copy_op = pltpu.make_async_remote_copy(
+      src_ref=src_ref,
+      dst_ref=dst_ref.at[receiving_slot, ...],
+      send_sem=send_sem,
+      recv_sem=recv_sem,
+      device_id=target_device,
+      device_id_type=pltpu.DeviceIdType.MESH,
+  )
+  remote_copy_op.start()
+
+  local_copy_op.wait()
+  # ... do work on local_scratch while waiting for remote_copy_op to finish.
+  remote_copy_op.wait()
+
+```
+
+In terms of synchronization, the double-buffered construction works if all devices are executing on the same iteration. If a sender manages to get one iteration ahead of its receiver, its `working_slot` and `receiving_slot` indices will be flipped compared to the receiver, meaning that it could be writing into the `working_slot` at the same time the receiver is reading from it. In order to avoid this, it may be necessary to use a semaphore to synchronize the sender with the receiver, or add additional buffering slots ("triple", "quadruple", or N-buffered) to allow additional run-ahead at the cost of more memory. In our previous `all_gather` example, note that the kernel contained a receiving buffer with N slots, which avoids race conditions altogether. In our next kernel, we will instead go through an example which uses a double-buffer with explicit synchronization.
+
++++ {"id": "Or0Itv72No5d"}
+
+### Example: All-Reduce Sum (`lax.psum`)
+
+We will now implement an all-reduce sum kernel using double-buffering and semaphores for synchronization. For those familiar with collective operations in JAX, the equivalent operation is `lax.psum`. All-reduce is a standard collective operation where the objective is to reduce along an axis of an array, but the array is sharded across multiple devices.
+
+![reduce_sum_1](../../_static/pallas/distributed/reduce_sum_1.svg)
+
+In the above example, we have the array [5, 2, 1, 3] sharded across 4 devices. An all-reduce sum operation would sum all values and replicate the result on each device, leading to the result [11, 11, 11, 11] sharded across all 4 devices.
+
+The naive implementation of all-reduce would be to gather all required values onto each device, and then reduce. However, we can improve the performance of this implementation by interleaving communication with computation. An interleaved, single-direction all-reduce can be visualized as follows. On each iteration, we receive an input value from our left neighbor, and concurrently pass it along to our right neighbor while accumulating it into our local accumulator. After N-1 iterations, each device will have a copy of the full sum in its memory.
+
+![reduce_sum_2](../../_static/pallas/distributed/reduce_sum_2.svg)
+
+#### Putting it all together
+
+The following code demonstrates how to combine these principles into a functional kernel.
+
+The prologue (executed when `outer_step==0`) first initiates a barrier with both neighbors to ensure that they have also entered the kernel. It also handles initialization for all `Ref`s and handles the first remote copy to the right neighbor's "working" slot.
+
+The main body assumes that a value has already been copied into our local working slot, either from the previous iteration or from the prologue. A complicating factor is that our destination buffers live in HBM, but we need to load values to VMEM before we perform arithmetic.
Therefore, we simultaneously copy the working slot value into our VMEM (`receive_scratch`) and pass the value on to our right neighbor's receiving slot. Once the value has been copied into our VMEM, we can accumulate it into our result (contained in `o_ref`).
+
+A subtle race condition can occur if one device runs one loop ahead of its right neighbor. In this case, it could copy into the receiver's `working_slot` at the same time the receiver is reading from it. To avoid this, before copying into its right neighbor's `dst_ref`, each device blocks on a `REGULAR` semaphore until the neighbor has signaled that it is done reading from its `working_slot`. This race condition is rarely triggered for a small kernel such as this example, but it can be triggered explicitly, for example by using a `pltpu.delay` instruction to artificially hang a device.
+
+Note that this is not an optimal or fully general kernel, as the block sizes must entirely fit in VMEM and we could better interleave communication and accumulation. We will discuss these optimizations in later sections.
+
+```{code-cell} ipython3
+---
+executionInfo:
+  elapsed: 254
+  status: ok
+  timestamp: 1722904804952
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: XrY5bMlvBroQ
+outputId: 77497000-4496-462e-cc3c-73fb640cc14c
+---
+partition = P(None, 'x')
+devices = mesh_utils.create_device_mesh((1, num_devices))
+mesh = jax.sharding.Mesh(devices, partition)
+sharding = jax.sharding.NamedSharding(mesh, partition)
+
+input_arr = jax.random.uniform(jax.random.key(0), shape=(8, 128 * num_devices))
+input_arr = jax.device_put(input_arr, sharding)
+
+
+def all_reduce_kernel(
+    x_ref,
+    o_ref,
+    hbm_scratch,
+    copy_sem,
+    remote_recv_sem,
+    remote_send_sem,
+    capacity_sem,
+    receive_scratch,
+):
+  outer_step = pl.program_id(0)
+  working_slot = lax.rem(outer_step, 2)
+  receiving_slot = 1 - working_slot
+
+  my_id = lax.axis_index('x')
+  right_neighbor = lax.rem(my_id + 1, num_devices)
+  left_neighbor = lax.rem(my_id - 1 + num_devices, num_devices)
+
+  @pl.when(outer_step == 0)
+  def _():
+    # Barrier with both neighbors at the start, since we will be
+    # communicating with both.
+    barrier_sem = pltpu.get_barrier_semaphore()
+    pltpu.semaphore_signal(
+        barrier_sem,
+        inc=1,
+        device_id=(0, left_neighbor),
+        device_id_type=pltpu.DeviceIdType.MESH,
+    )
+    pltpu.semaphore_signal(
+        barrier_sem,
+        inc=1,
+        device_id=(0, right_neighbor),
+        device_id_type=pltpu.DeviceIdType.MESH,
+    )
+    pltpu.semaphore_wait(barrier_sem, 2)
+
+    # Initialize o_ref, receive_scratch, and hbm_scratch.
+    o_ref[...] = jnp.zeros_like(o_ref)
+    receive_scratch[...] = jnp.zeros_like(receive_scratch)
+    initial_copy = pltpu.make_async_remote_copy(
+        src_ref=x_ref,
+        dst_ref=hbm_scratch.at[working_slot],
+        send_sem=remote_send_sem,
+        recv_sem=remote_recv_sem,
+        device_id=(0, right_neighbor),
+        device_id_type=pltpu.DeviceIdType.MESH,
+    )
+    initial_copy.start()
+    initial_copy.wait()
+
+  # Signal to our left neighbor that we are ready to receive.
+  # Without this signal, our left neighbor can be >=1 iteration ahead,
+  # meaning it could write into our working slot.
+  pltpu.semaphore_signal(
+      capacity_sem,
+      inc=1,
+      device_id=(0, left_neighbor),
+      device_id_type=pltpu.DeviceIdType.MESH,
+  )
+
+  # Copy the partial result our left neighbor sent to us into VMEM for
+  # computation.
+ local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot], + dst_ref=receive_scratch, + sem=copy_sem, + ) + local_copy.start() + + # Block until our right neighbor is ready to receive. + pltpu.semaphore_wait(capacity_sem, 1) + # Pass the value to our right neighbor. + remote_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot], + dst_ref=hbm_scratch.at[receiving_slot], + send_sem=remote_send_sem, + recv_sem=remote_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + remote_copy.start() + # Finish local copy and accumulate while remote_copy is happening. + local_copy.wait() + o_ref[...] += receive_scratch[...] + # Block until remote copy finishes. + remote_copy.wait() + + +out_shape = ( + jax.ShapeDtypeStruct((8, 128), jnp.float32), + # We allocate the double-buffer as a Pallas output so that it is + # resident in HBM. + jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), # hbm_scratch +) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + # Our input lives in VMEM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=[ + # Our output lives in VMEM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + # Our double-buffer lives in HBM + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices,), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 3 + + [pltpu.SemaphoreType.REGULAR] # capacity_sem + + [pltpu.VMEM((8, 128), jnp.float32)] # receive_scratch + ), +) + +kernel = pl.pallas_call( + all_reduce_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), +) + +pallas_result = jax.jit( + shard_map.shard_map( + kernel, + mesh=mesh, + in_specs=partition, + out_specs=partition, + check_rep=False, + ) +)(input_arr) +pallas_result = jax.block_until_ready(pallas_result)[0] + + +def lax_sum(x): + return lax.psum(x, 'x') + + +xla_result = jax.jit( + shard_map.shard_map( + lax_sum, mesh=mesh, in_specs=P(None, 'x'), out_specs=P(None, 'x') + ) +)(input_arr) + +print('Input = ', input_arr[0, ::128]) +print('Pallas result = ', pallas_result[0, ::128]) +print('lax.psum result = ', xla_result[0, ::128]) +difference = jnp.mean(jnp.abs(pallas_result - xla_result)) +print('Difference |Pallas - lax.psum| = ', difference) +``` + ++++ {"id": "d8bsZAzQreC_"} + +### Run-ahead and Race Conditions + +As a general rule of thumb, to maximize performance we want to allow a device to run-ahead of other devices without synchronization as much as possible without sacrificing correctness of the program. While we could enforce a barrier across all devices at the beginning of each iteration, this bottlenecks the performance of the program to the slowest device on each loop. By relaxing synchronization and allowing a moderate amount of run-ahead, we can better accommodate variance in latency between iterations and devices because a device that is slow on one iteration could catch up on the next iteration. + +In the all-reduce kernel we wrote previously, we allow devices to run ahead but by less than one iteration compared to its neighbors (however, non-neighboring devices could be more than 1 iteration apart). To see why the semaphore synchronization is necessary, consider the case when one device (say device 2) hangs and falls behind the other devices. An RDMA has no "handshake" — only the receiver is blocked while waiting for the data to arrive. 
Therefore, each device can run up to one iteration ahead before it becomes blocked waiting for the next RDMA to arrive. If we have N devices, this means that the final device can be up to N iterations ahead of the first device. + +![race_condition](../../_static/pallas/distributed/race_condition.svg) + +Without adding synchronization in the other direction (forcing senders to block), device 1 could potentially run up to `N` iterations (`N = num_devices`) ahead of device 2, sending multiple writes and overwriting values in the process. To solve this in the `all_reduce` kernel we wrote previously we implemented a "handshake" protocol where the receiver signals back to the sender that it is ready to receive, and only then does the sender begin issuing the next RDMA. + ++++ {"id": "UD8lNrqsUeXy"} + +### Bi-directional Communication + +In our previous kernels, we communicated in a single direction around a ring from left-to-right. However, as ICI connections are bi-directional, we are effectively wasting half of the total bandwidth by not sending values in the opposite direction from right-to-left. In this next kernel we will demonstrate an example which communicates in both directions to maximize ICI bandwidth. + ++++ {"id": "4KjakLhbBk73"} + +### Example: Bi-directional Reduce-Scatter (`lax.psum_scatter`) + +A reduce-scatter operation is the combination of an all-reduce followed by a scatter. Or alternatively, an all-reduce is the combination of a reduce-scatter followed by all-gather. + +The following graphic depicts the semantics of this operation. We assume that each device starts with a collection of partial sums (denoted by a letter + number, such as `A0`). The goal is to reduce along one axis (numbers), while sharding along the other axis (letters). + +![reduce_scatter_1](../../_static/pallas/distributed/reduce_scatter_1.svg) + +In order to implement a bi-directional communication strategy, we slice each input block in half, and designate a direction for each half. The top half of each block will be passed from right-to-left, and the bottom half will be passed from left-to-right. A second deviation from the communication patterns of our previous all-reduce and all-gather kernels is that we will also pass around accumulators or partial sums and keep the inputs local to each device. This is in contrast to the previous examples where we passed around inputs but kept the accumulator local to the device. Passing around the accumulator is a more natural fit for this problem as in contrast to all-reduce, most of the data in the inputs are not part of the output that will be stored locally on the device. (e.g. `B0`, `C0`, and `D0` in the above graphic will not be stored on the device holding `A` at the end). + +The following diagram illustrates this communication pattern, where the colored boxes represent accumulators (not inputs!). Initially, the accumulator is simply the value that was contained in the input. At each iteration of the algorithm, we will receive a partial sum from our neighbors in each direction. We then compute the correct slice of our input to accumulate into the partial buffer, then pass the new partial sum along to our next neighbor. After N iterations, the accumulator will have passed through each device, meaning that it will hold the full sum in the end. 
+ +![reduce_scatter_2](../../_static/pallas/distributed/reduce_scatter_2.svg) + +In terms of construction of the kernel, we introduce an additional `phase` dimension to the Pallas grid, which denotes which accumulator (left or right) we are currently computing on. We let `phase=0` denote the accumulator moving to the left, and `phase=1` denote the accumulator moving to the right. We then pipeline the two phases, such that while computing the result for one phase we are transferring our previously computed values in the opposite direction in preparation for the next phase. For example, when we are on `phase=0` (left), we first begin a DMA to transfer results we computed in the previous iteration to our right neighbor (right-DMA). Then, we accumulate into the left-buffer and save the result to HBM. We then wait for the right-DMA to complete so that it is ready for `phase=1` (right). + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 544 + status: ok + timestamp: 1722904805699 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: nRauUAxNHg28 +--- +partition = P(None, 'x') +devices = mesh_utils.create_device_mesh((1, num_devices)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# We need a block size of (16, 128) to ensure that a half-slice is at least +# of size (8, 128), which is the size of a VREG. This makes tiling easier +# for the compiler. +block_size = (16, 128) +input_arr = jax.random.uniform( + jax.random.key(0), + shape=(block_size[0] * num_devices, block_size[1] * num_devices), +) +input_arr = jax.device_put(input_arr, sharding) + +LEFT = 0 +RIGHT = 1 + + +def mod(x, n): + return lax.rem(x + n, n) + + +def signal(left_or_right, semaphore): + my_id = lax.axis_index('x') + if left_or_right == LEFT: + neighbor = mod(my_id - 1, num_devices) + else: + neighbor = mod(my_id + 1, num_devices) + pltpu.semaphore_signal( + semaphore, + inc=1, + device_id=(0, neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + +def reduce_scatter_kernel( + x_ref, + o_ref, + hbm_scratch, + local_copy_sem, + left_recv_sem, + left_send_sem, + right_recv_sem, + right_send_sem, + left_capacity_sem, + right_capacity_sem, + accum_scratch, +): + outer_step = pl.program_id(0) + phase = pl.program_id(1) + is_start = jnp.logical_and(outer_step == 0, phase == 0) + last_iteration = outer_step == pl.num_programs(0) - 1 + + working_slot = lax.rem(outer_step, 2) + receiving_slot = 1 - working_slot + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + left_copy_device = mod(my_id + outer_step + 1, num_devices) + right_copy_device = mod(my_id - outer_step - 1, num_devices) + # Slices can be specified using pl.ds(start, size) + left_copy_slice = pl.ds(0, block_size[0] // 2) + right_copy_slice = pl.ds(block_size[0] // 2, block_size[0] // 2) + current_phase_slice = pl.ds(phase * (block_size[0] // 2), block_size[0] // 2) + + initial_left_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, left_copy_slice], + dst_ref=hbm_scratch.at[working_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + initial_right_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + 
device_id_type=pltpu.DeviceIdType.MESH, + ) + + left_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot, left_copy_slice], + dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + right_copy = pltpu.make_async_remote_copy( + # Note: Right copy is flipped with regards to slots since we are copying + # to the next outer_step iteration. + src_ref=hbm_scratch.at[receiving_slot, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # --- Prologue --- + @pl.when(is_start) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + # Initialize o_ref, acc_scratch, and hbm_scratch with initial copies. + o_ref[...] = jnp.zeros_like(o_ref[...]) + accum_scratch[...] = jnp.zeros_like(accum_scratch[...]) + + initial_left_copy.start() + initial_left_copy.wait() + initial_right_copy.start() + + # We tell our left neighbor that it is allowed to send to the right. + # (and vice versa for right neighbor) + signal(LEFT, right_capacity_sem) + signal(RIGHT, left_capacity_sem) + + # --- Body --- + # At the beginning of our kernel body, we start a DMA which copies + # the result we computed in the previous phase to our neighbor. + # This allows us to overlap the communication of sending our previous phase + # with the computation for the current phase. + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + # We block here until our right neighbor tells use we can send to + # the right. + pltpu.semaphore_wait(right_capacity_sem, 1) + right_copy.start() + + @pl.when(phase == RIGHT) + def _(): + # We block here until our left neighbor tells use we can send to + # the left. + pltpu.semaphore_wait(left_capacity_sem, 1) + left_copy.start() + + local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=accum_scratch, + sem=local_copy_sem, + ) + local_copy.start() + local_copy.wait() + + @pl.when(~last_iteration) + def _(): + @pl.when(phase == LEFT) + def _(): + accum_scratch[...] += x_ref[left_copy_device, left_copy_slice] + + @pl.when(phase == RIGHT) + def _(): + accum_scratch[...] += x_ref[right_copy_device, right_copy_slice] + + local_copy = pltpu.make_async_copy( + src_ref=accum_scratch, + dst_ref=hbm_scratch.at[working_slot, current_phase_slice], + sem=local_copy_sem, + ) + local_copy.start() + local_copy.wait() + + @pl.when(is_start) + def _(): + initial_right_copy.wait() + + # At the end of our kernel body, we wait on the DMA of the previous phase + # to make sure the results are ready for the next phase. + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + right_copy.wait() + signal(LEFT, right_capacity_sem) + + @pl.when(phase == RIGHT) + def _(): + left_copy.wait() + signal(RIGHT, left_capacity_sem) + + # --- Epilogue --- + # Store result on last iteration. 
+ @pl.when(last_iteration) + def _(): + # Clean up semaphores so that they exit with a value of 0. + @pl.when(phase == LEFT) + def _(): + o_ref[left_copy_slice, ...] = accum_scratch[...] + pltpu.semaphore_wait(right_capacity_sem, 1) + + @pl.when(phase == RIGHT) + def _(): + o_ref[right_copy_slice, ...] = accum_scratch[...] + pltpu.semaphore_wait(left_capacity_sem, 1) + + +out_shape = ( + jax.ShapeDtypeStruct((block_size[0], block_size[1]), jnp.float32), # output + # Shape: [working/recv, block[0], block[1]] + jax.ShapeDtypeStruct( + (2, block_size[0], block_size[1]), jnp.float32 + ), # hbm_scratch +) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices, 2), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 5 + + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores + + [ + pltpu.VMEM((block_size[0] // 2, block_size[1]), jnp.float32) + ] # accum_scratch + ), +) + + +def pallas_reduce_scatter(input_arr): + input_arr = input_arr.reshape(num_devices, block_size[0], block_size[1]) + return pl.pallas_call( + reduce_scatter_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + )(input_arr)[0] + + +pallas_result = jax.jit( + shard_map.shard_map( + pallas_reduce_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + check_rep=False, + ) +)(input_arr) + +pallas_result = jax.block_until_ready(pallas_result) +``` + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 596 + status: ok + timestamp: 1722904806442 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: E-NMh-_teoi4 +outputId: 24beb42f-1bdd-4c34-e8d2-681dd7f2e9c0 +--- +# Compare our result to XLA. +def lax_reduce_sum_scatter(x): + x = x.reshape(num_devices, block_size[0], block_size[1]) + return lax.psum_scatter(x, 'x') + + +xla_result = jax.jit( + shard_map.shard_map( + lax_reduce_sum_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + ) +)(input_arr) + +print('Input:', input_arr.shape, input_arr[::4, 0]) +print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0]) +print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0]) +print( + 'Difference |Pallas - lax.psum_scatter|:', + jnp.max(jnp.abs(pallas_result - xla_result)), +) +``` + ++++ {"id": "ThKas40r40Ji"} + +### Nested Remote and Local DMA Pipelines + +A limitation of the previous all-reduce and reduce-scatter kernels that we wrote is that the blocks we copy via remote DMA must be small enough to fit in our working VMEM that we use for accumulation. For some kernels it may be advantageous to use larger block sizes to better utilize the TPU. For example, a matrix multiplication requires on the order of $O(N^3)$ compute operations, but only $O(N^2)$ memory transfers. Therefore, we want each block of work transferred between devices to be large enough such that the operation becomes compute bound and we can hide the communication cost using pipelining. For reference, the VMEM of a TPU (for generations v4/v5) is typically on the order of 10-100MB, whereas HBM ranges from 10-100GB. + +To address this problem, we need to be able to write an "inner kernel" that handles local HBM-VMEM pipelining inside of the "outer kernel" that handles pipelining larger HBM-HBM transfers between devices. 
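+
+To make the "compute bound" argument above concrete, here is a rough back-of-the-envelope sketch (the formulas are the usual matmul FLOP and byte counts; the block sizes and the conclusion are illustrative assumptions, not measurements for any particular TPU generation):
+
+```python
+# For an N x N x N matmul block in float32:
+#   compute        ~ 2 * N^3 floating point operations,
+#   memory traffic ~ 3 * N^2 * 4 bytes (read A, read B, write C).
+def arithmetic_intensity(n: int, bytes_per_elem: int = 4) -> float:
+  flops = 2 * n**3
+  bytes_moved = 3 * n**2 * bytes_per_elem
+  return flops / bytes_moved  # FLOPs per byte moved.
+
+for n in (128, 1024, 8192):
+  print(f"N={n:5d}: ~{arithmetic_intensity(n):.0f} FLOPs per byte")
+# Larger blocks perform proportionally more compute per byte transferred,
+# which is what makes it possible to hide communication behind computation.
+```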
Pallas offers an API for constructing nested pipelines using the `emit_pipeline` function. The basic call signature for `emit_pipeline` follows that of a standard `pallas_call` by specifying a `grid` and `BlockSpec`s for the inputs and outputs: + +```python +def emit_pipeline( + kernel: Callable, + grid: tuple[int], + in_specs: PyTree[BlockSpec] = None, + out_specs: PyTree[BlockSpec] = None, + should_accumulate_out: bool = False, + dimension_semantics: tuple[GridDimensionSemantics] = None, +) -> Callable: + ... # Returns a custom pipeline given an inner kernel and BlockSpecs. +``` + +Indeed, one can view `pallas_call` itself as simply a wrapper around `emit_pipeline`. Because our outer kernel only involves remote HBM-HBM transfers, we are not using any of the built-in pipelining that `pallas_call` provides for HBM-VMEM transfers. The following code skeleton demonstrates what a typical program structure would look like using this pattern: + +```python + +def outer_kernel(...): + # ... do work to pipeline remote HBM-HBM transfers (outer kernel) + + def inner_kernel(...): + # ... do work (inner kernel) + pltpu.emit_pipeline( + inner_kernel, + grid=inner_grid, + in_specs=..., + out_specs=..., + )(inner_kernel_args) + # ... do more work (outer kernel) + +pl.pallas_call( + outer_kernel, + grid=outer_grid, + in_specs=... + out_specs=... + scratch=inner_kernel_allocs +) +``` + ++++ {"id": "DzFeQjYaasX5"} + +### Example: Reduce-Scatter with large HBM blocks + +In this next example we will modify our previous reduce-scatter example to utilize a nested inner pipeline. Note that the communication and computation costs of `reduce_scatter` both scale linearly with the size of the input, so we do not necessarily expect to see the operation become compute-bound with larger block sizes. This example is purely for demonstration purposes on how to use the pipeline emitter. + +We will increase the block sizes of the outer kernel such that they would be undesirable to place inside of VMEM, and allocate all inputs and outputs in HBM (`memory_space=TPUMemorySpace.Any`). The only major change from our previous kernel is the body of the kernel where accumulation is done. Rather than manually copying from HBM to VMEM, accumulating, and copying back to HBM, we use `emit_pipeline` to handle the memory transfers for us. Accumulation is done in an inner kernel with a much smaller, VMEM-friendly block size. + +In our previous kernel we had the following kernel body to copy data from HBM to the VMEM accumulator, increment, and then copy the results back to HBM: + +```python +local_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=accum_scratch, + sem=local_copy_sem, +) +local_copy.start() +local_copy.wait() +@pl.when(~last_iteration) +def _(): + @pl.when(phase == LEFT) + def _(): + accum_scratch[...] += x_ref[left_copy_device, left_copy_slice] + @pl.when(phase == RIGHT) + def _(): + accum_scratch[...] += x_ref[right_copy_device, right_copy_slice] +local_copy = pltpu.make_async_copy( + src_ref=accum_scratch, + dst_ref=hbm_scratch.at[working_slot, current_phase_slice], + sem=local_copy_sem, +) +local_copy.start() +local_copy.wait() +``` + +Our new kernel replaces it with the following `emit_pipeline` call: + +```python +def inner_kernel(input_ref, accum_ref): + accum_ref[...] = input_ref[...] 
+accum_pipeline = pltpu.emit_pipeline(inner_kernel, + in_specs=[inner_block_spec], + out_specs=inner_block_spec, + should_accumulate_out=True, + grid=inner_grid) +@pl.when(~last_iteration) +def _(): + @pl.when(phase == LEFT) + def _(): + accum_pipeline(x_ref.at[left_copy_device, left_copy_slice], + hbm_scratch.at[working_slot, left_copy_slice], + ) + @pl.when(phase == RIGHT) + def _(): + accum_pipeline(x_ref.at[right_copy_device, right_copy_slice], + hbm_scratch.at[working_slot, right_copy_slice], + ) +``` + +The full kernel is as follows: + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 1341 + status: ok + timestamp: 1722904807930 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: 27jni-pSartL +--- +partition = P(None, 'x') +devices = mesh_utils.create_device_mesh((1, num_devices)) +mesh = jax.sharding.Mesh(devices, partition) +sharding = jax.sharding.NamedSharding(mesh, partition) + +# We pick a large outer kernel block size that we do not want to place +# in VMEM. For pedagogical purposes we use (4096, 4096), although in +# principle this can be much larger. +outer_block_size = (4096, 4096) +# We pick a smaller VMEM block size for the inner kernel. +inner_block_size = (128, 128) +input_arr = jax.random.uniform( + jax.random.key(0), + shape=( + outer_block_size[0] * num_devices, + outer_block_size[1] * num_devices, + ), +) +input_arr = jax.device_put(input_arr, sharding) + + +inner_grid = ( + outer_block_size[0] // inner_block_size[0] // 2, + outer_block_size[1] // inner_block_size[1], +) +inner_block_spec = pl.BlockSpec( + index_map=lambda i, j: (i, j), + block_shape=inner_block_size, + memory_space=pltpu.TPUMemorySpace.ANY, +) + + +def reduce_scatter_kernel( + x_ref, + o_ref, + hbm_scratch, + left_recv_sem, + left_send_sem, + copy_sem, + right_recv_sem, + right_send_sem, + left_capacity_sem, + right_capacity_sem, +): + outer_step = pl.program_id(0) + phase = pl.program_id(1) + is_start = jnp.logical_and(outer_step == 0, phase == 0) + last_iteration = outer_step == pl.num_programs(0) - 1 + + working_slot = lax.rem(outer_step, 2) + receiving_slot = 1 - working_slot + my_id = lax.axis_index('x') + right_neighbor = mod(my_id + 1, num_devices) + left_neighbor = mod(my_id - 1, num_devices) + + left_copy_device = mod(my_id + outer_step + 1, num_devices) + right_copy_device = mod(my_id - outer_step - 1, num_devices) + left_copy_slice = pl.ds(0, outer_block_size[0] // 2) + right_copy_slice = pl.ds(outer_block_size[0] // 2, outer_block_size[0] // 2) + current_phase_slice = pl.ds( + phase * (outer_block_size[0] // 2), outer_block_size[0] // 2 + ) + + initial_left_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, left_copy_slice], + dst_ref=hbm_scratch.at[working_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + initial_right_copy = pltpu.make_async_remote_copy( + src_ref=x_ref.at[my_id, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + left_copy = pltpu.make_async_remote_copy( + src_ref=hbm_scratch.at[working_slot, left_copy_slice], + dst_ref=hbm_scratch.at[receiving_slot, left_copy_slice], + send_sem=left_send_sem, + recv_sem=left_recv_sem, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + right_copy = pltpu.make_async_remote_copy( + 
src_ref=hbm_scratch.at[receiving_slot, right_copy_slice], + dst_ref=hbm_scratch.at[working_slot, right_copy_slice], + send_sem=right_send_sem, + recv_sem=right_recv_sem, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + + # --- Prologue --- + @pl.when(is_start) + def _(): + # Barrier with both neighbors at the start, since we will be + # communicating with both. + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, left_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_signal( + barrier_sem, + inc=1, + device_id=(0, right_neighbor), + device_id_type=pltpu.DeviceIdType.MESH, + ) + pltpu.semaphore_wait(barrier_sem, 2) + + initial_left_copy.start() + initial_left_copy.wait() + initial_right_copy.start() + + # We tell our left neighbor that it is allowed to send to the right. + # (and vice versa for right neighbor) + signal(LEFT, right_capacity_sem) + signal(RIGHT, left_capacity_sem) + + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + # We block here until our right neighbor tells use we can send to + # the right. + pltpu.semaphore_wait(right_capacity_sem, 1) + right_copy.start() + + @pl.when(phase == RIGHT) + def _(): + # We block here until our left neighbor tells use we can send to + # the left. + pltpu.semaphore_wait(left_capacity_sem, 1) + left_copy.start() + + # --- Body --- + def inner_kernel(input_ref, accum_ref): + # We do not explicitly use += because we set should_accumulate_out=True. + accum_ref[...] = input_ref[...] + + accum_pipeline = pltpu.emit_pipeline( + inner_kernel, + in_specs=[inner_block_spec], + out_specs=inner_block_spec, + should_accumulate_out=True, + grid=inner_grid, + ) + + @pl.when(~last_iteration) + def _(): + @pl.when(phase == LEFT) + def _(): + accum_pipeline( + x_ref.at[left_copy_device, left_copy_slice], + hbm_scratch.at[working_slot, left_copy_slice], + ) + + @pl.when(phase == RIGHT) + def _(): + accum_pipeline( + x_ref.at[right_copy_device, right_copy_slice], + hbm_scratch.at[working_slot, right_copy_slice], + ) + + # --- Epilogue --- + @pl.when(is_start) + def _(): + initial_right_copy.wait() + + @pl.when(~is_start) + def _(): + @pl.when(phase == LEFT) + def _(): + right_copy.wait() + signal(LEFT, right_capacity_sem) + + @pl.when(phase == RIGHT) + def _(): + left_copy.wait() + signal(RIGHT, left_capacity_sem) + + # Store result on last iteration. + @pl.when(last_iteration) + def _(): + output_copy = pltpu.make_async_copy( + src_ref=hbm_scratch.at[working_slot, current_phase_slice], + dst_ref=o_ref.at[current_phase_slice], + sem=copy_sem, + ) + output_copy.start() + output_copy.wait() + + # Clean up semaphores so that they exit with a value of 0. 
+ @pl.when(phase == LEFT) + def _(): + pltpu.semaphore_wait(right_capacity_sem, 1) + + @pl.when(phase == RIGHT) + def _(): + pltpu.semaphore_wait(left_capacity_sem, 1) + + +out_shape = ( + jax.ShapeDtypeStruct( + (outer_block_size[0], outer_block_size[1]), jnp.float32 + ), + # Shape: [working/recv, block[0], block[1]] + jax.ShapeDtypeStruct( + (2, outer_block_size[0], outer_block_size[1]), jnp.float32 + ), # hbm_scratch +) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + out_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + ], + grid=(num_devices, 2), + scratch_shapes=( + [pltpu.SemaphoreType.DMA] * 5 + + [pltpu.SemaphoreType.REGULAR] * 2 # Capacity semaphores + ), +) + + +def pallas_reduce_scatter(input_arr): + input_arr = input_arr.reshape( + num_devices, outer_block_size[0], outer_block_size[1] + ) + return pl.pallas_call( + reduce_scatter_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + )(input_arr)[0] + + +pallas_result = jax.jit( + shard_map.shard_map( + pallas_reduce_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + check_rep=False, + ) +)(input_arr) + +pallas_result = jax.block_until_ready(pallas_result) +``` + +```{code-cell} ipython3 +--- +executionInfo: + elapsed: 768 + status: ok + timestamp: 1722904808851 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: cTEyiMDyx9Y0 +outputId: 1de26695-3713-430e-9ab4-4ea646691680 +--- +# Now we compare our result to XLA. +def lax_reduce_sum_scatter(x): + x = x.reshape(num_devices, outer_block_size[0], outer_block_size[1]) + return lax.psum_scatter(x, 'x') + + +xla_result = jax.jit( + shard_map.shard_map( + lax_reduce_sum_scatter, + mesh=mesh, + in_specs=P(None, 'x'), + out_specs=P('x', None), + ) +)(input_arr) + +print('Input:', input_arr.shape, input_arr[::4, 0]) +print('Pallas Result:', pallas_result.shape, pallas_result[::4, 0]) +print('lax.psum_scatter Result:', xla_result.shape, xla_result[::4, 0]) +print( + 'Difference |Pallas - lax.psum_scatter|:', + jnp.max(jnp.abs(pallas_result - xla_result)), +) +``` + ++++ {"id": "zz5AFbriliyv"} + +## Final Notes + +### Megacore + +Certain TPUs contain multiple cores in a [Megacore](https://jax.readthedocs.io/en/latest/pallas/tpu/pipelining.html#tpus-in-megacore-configuration) configuration. In this configuration, our general recommendation is to only initiate DMAs from a single core, and only perform HBM-HBM transfers. To do this, set one of the grid axes to the number of cores (can be obtained via `jax.devices()[0].num_cores`) and the dimension_semantics to `"parallel"`. Then, you can use `core_index = pl.program_id(axis)` to obtain the core index along that axis, and use `@pl.when(core_index==i)` to execute code specific to that core. + +### Interaction with XLA + +In this tutorial we covered several kernel examples which replicate the functionality of collective operations in JAX such as `lax.all_gather`, `lax.psum`, and `lax.psum_scatter`. An important caveat to note is that a Pallas kernel is somewhat opaque to the XLA compiler and may cause it to miss some optimizations it would normally perform. For example, XLA can asynchronously dispatch collective operations in order to interleave communication and computation without writing a custom kernel. 
This is not guaranteed to happen when Pallas kernels are involved so it is important to profile your program to see if this is an issue. Another example is the fact that the `emit_pipeline` function we used in this tutorial to generate nested pipelines is not visible to the XLA compiler, and therefore cannot be fused with neighboring operations. + +### Next Steps + +Excellent follow-up excercises for the reader could include implementing a distributed matrix multiplication, implementing `lax.all_to_all`, and relaxing synchronization to allow for additional run-ahead. diff --git a/docs/pallas/tpu/index.rst b/docs/pallas/tpu/index.rst index 5680481f3947..20abad5f610e 100644 --- a/docs/pallas/tpu/index.rst +++ b/docs/pallas/tpu/index.rst @@ -9,3 +9,6 @@ TPU specific documentation. details pipelining matmul + sparse + distributed + diff --git a/docs/pallas/tpu/matmul.ipynb b/docs/pallas/tpu/matmul.ipynb index 0bd16095cb7e..51ce2ed6868f 100644 --- a/docs/pallas/tpu/matmul.ipynb +++ b/docs/pallas/tpu/matmul.ipynb @@ -210,8 +210,8 @@ " pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))],\n", " out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)),\n", " grid=(m // bm, n // bn, k // bk),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -466,8 +466,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -741,8 +741,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, @@ -929,8 +929,8 @@ " grid=(m // bm, n // bn, k // bk),\n", " ),\n", " out_shape=jax.ShapeDtypeStruct((m, n), x.dtype),\n", - " compiler_params=dict(mosaic=dict(\n", - " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\"))),\n", + " compiler_params=pltpu.TPUCompilerParams(\n", + " dimension_semantics=(\"parallel\", \"parallel\", \"arbitrary\")),\n", " )(x, y)" ] }, diff --git a/docs/pallas/tpu/matmul.md b/docs/pallas/tpu/matmul.md index a00880ebaf37..b8e6acbd45f9 100644 --- a/docs/pallas/tpu/matmul.md +++ b/docs/pallas/tpu/matmul.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 (ipykernel) language: python @@ -167,8 +167,8 @@ def matmul( pl.BlockSpec((bk, bn), lambda i, j, k: (k, j))], out_specs=pl.BlockSpec((bm, bn), lambda i, j, k: (i, j)), grid=(m // bm, n // bn, k // bk), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -321,8 +321,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", 
"parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -489,8 +489,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` @@ -613,8 +613,8 @@ def matmul( grid=(m // bm, n // bn, k // bk), ), out_shape=jax.ShapeDtypeStruct((m, n), x.dtype), - compiler_params=dict(mosaic=dict( - dimension_semantics=("parallel", "parallel", "arbitrary"))), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), )(x, y) ``` diff --git a/docs/pallas/tpu/pipelining.ipynb b/docs/pallas/tpu/pipelining.ipynb index 275a72f3837b..b5f2c652b5a5 100644 --- a/docs/pallas/tpu/pipelining.ipynb +++ b/docs/pallas/tpu/pipelining.ipynb @@ -1,5 +1,13 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "7704d3bb", + "metadata": {}, + "source": [ + "(pallas_tpu_pipelining)=" + ] + }, { "cell_type": "markdown", "metadata": { @@ -33,6 +41,7 @@ "\n", "import jax\n", "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", "import jax.numpy as jnp\n", "import numpy as np" ] @@ -696,7 +705,7 @@ " in_specs=[block_spec, block_spec],\n", " out_specs=block_spec,\n", " grid=(2,),\n", - " compiler_params=dict(mosaic=dict(dimension_semantics=(\"parallel\",)))\n", + " compiler_params=pltpu.TPUCompilerParams(dimension_semantics=(\"parallel\",))\n", " )(x, y)\n", "\n", "x, y = jnp.ones((512, 512)), jnp.ones((512, 512))\n", diff --git a/docs/pallas/tpu/pipelining.md b/docs/pallas/tpu/pipelining.md index d753b404db1a..19150b3832fa 100644 --- a/docs/pallas/tpu/pipelining.md +++ b/docs/pallas/tpu/pipelining.md @@ -5,12 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 --- +(pallas_tpu_pipelining)= + +++ {"id": "teoJ_fUwlu0l"} # Pipelining @@ -29,6 +31,7 @@ pipelines in Pallas that overlap memory I/O with compute. import jax from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np ``` @@ -465,7 +468,7 @@ def add_matrices_pipelined_megacore(x: jax.Array, y: jax.Array) -> jax.Array: in_specs=[block_spec, block_spec], out_specs=block_spec, grid=(2,), - compiler_params=dict(mosaic=dict(dimension_semantics=("parallel",))) + compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel",)) )(x, y) x, y = jnp.ones((512, 512)), jnp.ones((512, 512)) diff --git a/docs/pallas/tpu/sparse.ipynb b/docs/pallas/tpu/sparse.ipynb new file mode 100644 index 000000000000..a80ba4ebedbb --- /dev/null +++ b/docs/pallas/tpu/sparse.ipynb @@ -0,0 +1,724 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "ZHuzXqQ-9JUQ" + }, + "source": [ + "# Scalar Prefetch and Block-Sparse Computation\n", + "\n", + "In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. 
In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 56, + "status": "ok", + "timestamp": 1726001133029, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "ibeIs_6QFMAM", + "outputId": "d72edb91-4529-4650-c9e9-b96788608635" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on TPU v5 lite\n" + ] + } + ], + "source": [ + "import functools\n", + "import timeit\n", + "import numpy as np\n", + "import jax\n", + "from jax import numpy as jnp\n", + "from jax import lax\n", + "from jax.experimental import checkify\n", + "from jax.experimental import pallas as pl\n", + "from jax.experimental.pallas import tpu as pltpu\n", + "\n", + "assert \"TPU\" in jax.devices()[0].device_kind, \"Please run this notebook with TPU devices.\"\n", + "print(\"Running on\", jax.devices()[0].device_kind)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "FIDGpPTEIcOa" + }, + "source": [ + "## Dynamic Block Indexing with Scalar Prefetch\n", + "\n", + "We will be exploiting the \"scalar prefetch\" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass in a small amount of data into SMEM (\"scalar memory\") that is loaded before the start of the pipeline (\"prefetch\"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing the you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.\n", + "\n", + "To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:\n", + "\n", + "```python\n", + "class PrefetchScalarGridSpec:\n", + " def __init__(self,\n", + " num_scalar_prefetch: int,\n", + " grid: tuple[int, ...],\n", + " in_specs: PyTree[BlockSpec],\n", + " out_specs: PyTree[BlockSpec],\n", + " scratch_shapes: tuple[MemorySpace, ...]):\n", + " ...\n", + "```\n", + "\n", + "The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel are always fixed and described below:\n", + "\n", + "- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:\n", + "```python\n", + "def index_map(*grid_indices, *prefetch_refs):\n", + " ...\n", + "```\n", + "\n", + "- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. 
Additionally, the scratch refs come after the output `Ref`s.\n", + "```python\n", + "def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):\n", + " ...\n", + "```\n", + "\n", + "- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.\n", + "```python\n", + "kernel = pl.pallas_call(...)\n", + "result = kernel(*prefetch_args, *input_args)\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pA8RmHEA2HN3" + }, + "source": [ + "## Example: Block Dynamic Slice with Scalar Prefetch\n", + "\n", + "Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of larger array based on user-specified indices:\n", + "\n", + "1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`\n", + "\n", + "2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.\n", + "\n", + "3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.\n", + "\n", + "Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 143, + "status": "ok", + "timestamp": 1726003877561, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "FWeTBlEYlCGD", + "outputId": "4b04a441-c97c-4d0d-d167-c60d4d31fd2e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error |result - lax.dynamic_slice| = 0\n" + ] + } + ], + "source": [ + "def dynamic_slice_kernel(indices, x_ref, o_ref):\n", + " del indices\n", + " o_ref[...] 
= x_ref[...]\n", + "\n", + "@checkify.checkify\n", + "@functools.partial(jax.jit, static_argnums=(2,))\n", + "def block_dynamic_slice(x, starts, sizes):\n", + " grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=1,\n", + " grid=(1, 1),\n", + " in_specs=[pl.BlockSpec(\n", + " sizes,\n", + " lambda i, j, block_idx: (block_idx[0], block_idx[1]))],\n", + " out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)),\n", + " )\n", + "\n", + " kernel = pl.pallas_call(\n", + " dynamic_slice_kernel,\n", + " grid_spec=grid_spec,\n", + " out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype),\n", + " )\n", + " # Checkify inserts a runtime assert that starts are divisible by block size.\n", + " checkify.check(starts[0] % sizes[0] == 0, \"Starts must be divisible by size.\")\n", + " checkify.check(starts[1] % sizes[1] == 0, \"Starts must be divisible by size.\")\n", + " block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]])\n", + " return kernel(block_idx, x)\n", + "\n", + "shape = (512, 512)\n", + "x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape)\n", + "err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128))\n", + "err.throw()\n", + "ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128))\n", + "diff = jnp.max(jnp.abs(result - ref))\n", + "print(\"Error |result - lax.dynamic_slice| =\", diff)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "K2dod4lkoifa" + }, + "source": [ + "## Sparse Kernels: Representing Sparse Data\n", + "\n", + "Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix.\n", + "\n", + "The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements.\n", + "\n", + "![block_coo](../../_static/pallas/sparse/block_coo.svg)\n", + "\n", + "We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1gLiSvgIYUEx" + }, + "outputs": [], + "source": [ + "def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32):\n", + " \"\"\"Returns a sampled matrix and its block-sparse representation.\n", + "\n", + " Args:\n", + " key: RNG Key.\n", + " M: Major array dimension.\n", + " N: Minor array dimension.\n", + " blk_M: Block size along M dimension.\n", + " blk_N: Block size along N dimension.\n", + " p: Probability that a block will be non-zero.\n", + " dtype: dtype of the sampled matrix.\n", + "\n", + " Returns:\n", + " dense_mat: A (M, N) dense sampled array.\n", + " block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing\n", + " the non-zero blocks of the matrix.\n", + " indices_i: A (num_blocks,) array of block indices for the first axis.\n", + " indices_j: A (num_blocks,) array of block indices for the second axis.\n", + " \"\"\"\n", + " mask_key, blocks_key = jax.random.split(key)\n", + " num_blocks = (M // blk_M, N // blk_N)\n", + " # We first sample a block mask, denoting which blocks are nonzero.\n", + " block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)\n", + " num_blocks = jnp.sum(block_mask)\n", + " indices = jnp.where(block_mask)\n", + " # For each non-zero block, we sample a block of random values.\n", + " block_data = jax.random.uniform(blocks_key,\n", + " shape=(num_blocks, blk_M, blk_N),\n", + " dtype=dtype)\n", + " # For checking purposes, create the dense version of the sparse matrix.\n", + " dense_mat = jnp.zeros((M, N), dtype=dtype)\n", + " for blk in range(num_blocks):\n", + " idx_i = indices[0][blk]\n", + " idx_j = indices[1][blk]\n", + " slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)\n", + " slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)\n", + " dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])\n", + " return dense_mat, block_data, indices[0], indices[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eFyoZSTOH9Fk" + }, + "source": [ + "## Example: Sparse @ Dense Matrix Multiplication\n", + "\n", + "In our first example, we will multiple a sparse LHS matrix with a dense RHS matrix to produce a dense output.\n", + "\n", + "We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and lookup the corresponding block on in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:\n", + "\n", + "![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)\n", + "\n", + "It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we loading the same output block on consecutive iterations and keep the block in VMEM. 
When we change output block Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 673, + "status": "ok", + "timestamp": 1725919879291, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "WfyV2WWhjsyA", + "outputId": "fa4d4fff-bc6b-4dc9-ac14-63276ca14131" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean |result - ref|: 0\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "blk_M = blk_N = blk_K = 512\n", + "\n", + "\n", + "def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.\n", + " x_ref, y_ref, _, o_ref, # Kernel inputs.\n", + " accum_scratch,\n", + " ):\n", + " \"\"\"A DSD (Dense = Sparse @ Dense) matmul kernel.\"\"\"\n", + " del idxs_k_ref\n", + " blk_idx = pl.program_id(0)\n", + " is_start = blk_idx == 0\n", + " changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])\n", + " @pl.when(is_start | changed_blocks)\n", + " def _():\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch)\n", + " accum_scratch[...] += jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32)\n", + "\n", + " next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)])\n", + " is_end = blk_idx == (num_blocks - 1)\n", + " @pl.when(is_end | next_block_change)\n", + " def _():\n", + " o_ref[...] = accum_scratch[...].astype(o_ref.dtype)\n", + "\n", + "\n", + "def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + " del j, blk_idxs_i, blk_idxs_k\n", + " return (blk_idx, 0, 0)\n", + "def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + " del blk_idxs_i\n", + " return (blk_idxs_k[blk_idx], j)\n", + "def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k):\n", + " del blk_idxs_k\n", + " return (blk_idxs_i[blk_idx], j)\n", + "\n", + "(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat(\n", + " jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16)\n", + "num_blocks = X_blocks.shape[0]\n", + "Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n", + "zeros = jnp.zeros((M, N), dtype=jnp.bfloat16)\n", + "out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16)\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=2,\n", + " # Note that while num_blocks is static here, Pallas does support\n", + " # dynamic grid sizes.\n", + " grid=(num_blocks, N // blk_N),\n", + " in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map),\n", + " pl.BlockSpec((blk_K, blk_N), y_map),\n", + " # Placeholder for a zeros-array used by input_output_aliases.\n", + " pl.BlockSpec((blk_M, blk_N), o_map),\n", + " ],\n", + " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n", + " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n", + ")\n", + "kernel = pl.pallas_call(\n", + " dsd_kernel,\n", + " grid_spec=grid_spec,\n", + " out_shape=out_shape,\n", + " # We use input-output aliases to zero-out o_ref for blocks that we never\n", + " # visit. 
By passing in an array of zeros we avoid having o_ref start with\n", + " # uninitialized values.\n", + " input_output_aliases={4: 0}, # Map zeros to o_ref.\n", + ")\n", + "args = (indices_i, indices_k, X_blocks, Y, zeros)\n", + "result = kernel(*args)\n", + "\n", + "ref = X_dense @ Y\n", + "diff = jnp.abs(ref - result)\n", + "print('mean |result - ref|:', jnp.mean(diff))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2KDgPKF2tUjq" + }, + "source": [ + "We can do a quick benchmark to compare the performance of our sparse kernel compared to a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly ~6x speed increase compared to the theoretical 10x from the sparsity factor.\n", + "\n", + "There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM:\n", + "- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half.\n", + "- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enables data to be more sparse, so this is a parameter that should be selected carefully." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 6576, + "status": "ok", + "timestamp": 1725919886762, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "CkzjqnekpZbx", + "outputId": "1ae9031e-705a-4d05-f8b9-d09623918300" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse Kernel: 8.136 ms (avg over 100 trials)\n", + "Reference: 46.953 ms (avg over 100 trials)\n" + ] + } + ], + "source": [ + "# Benchmark Sparse Pallas kernel vs reference JAX implementation\n", + "\n", + "def benchmark(f, ntrials: int = 100):\n", + " def run(*args, **kwargs):\n", + " # Compile function first\n", + " jax.block_until_ready(f(*args, **kwargs))\n", + " # Time function\n", + " result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),\n", + " number=ntrials)\n", + " time = result / ntrials\n", + " return time\n", + " return run\n", + "\n", + "\n", + "n_trials = 100\n", + "\n", + "pallas_impl = lambda *args: kernel(*args)\n", + "time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)\n", + "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n", + "\n", + "ref_impl = jax.jit(lambda x, y: x @ y)\n", + "time = benchmark(ref_impl, n_trials)(X_dense, Y)\n", + "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Q1KKd5vTCwnB" + }, + "source": [ + "## Sparse Access Patterns on Dense Data\n", + "\n", + "In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).\n", + "\n", + "A second useful programming pattern emerges when the underlying is data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. 
This type of programming pattern is commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).\n", + "\n", + "The main performance consideration with dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contains indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.\n", + "\n", + "![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)\n", + "\n", + "*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*\n", + "\n", + "Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.\n", + "\n", + "```python\n", + "def mask_index_map(prefetch_map, i, j, ...):\n", + " next_nonzero_block = prefetch_map[i, j]\n", + " return (next_nonzero_block, 0, 0)\n", + "```\n", + "\n", + "We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will provide an example of using these prefetch maps." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ii7rzL5YIA8-" + }, + "source": [ + "## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ecjiqWfA2RlV" + }, + "source": [ + "In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs.\n", + "\n", + "As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes:\n", + "- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). 
If the `block_mask` has a value of 0 we can skip computing the block in the kernel.\n", + "- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block.\n", + "- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask.\n", + "- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask.\n", + "- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "19zGcliL2SJy" + }, + "outputs": [], + "source": [ + "def sparsify_mask(mask: jax.Array,\n", + " block_shape: tuple[int, int]):\n", + " \"\"\"Preprocesses a mask into a sparse reprentation.\n", + "\n", + " Args:\n", + " mask: A boolean array of shape [M, N]\n", + " block_shape: The size of a single block.\n", + "\n", + " Returns:\n", + " block_mask: A block_shape array of booleans indicating whether a block\n", + " is all-zeros (0) or contains non-zero elements (1).\n", + " prefetch_mask: A block_shape array of integers indicating the index of the\n", + " next non-zero block.\n", + " mask_data: A (num_blocks, block_shape) array containing\n", + " the data for non-zero blocks of the mask.\n", + " \"\"\"\n", + " M, N = mask.shape\n", + " bm, bn = block_shape\n", + "\n", + " block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype)\n", + " mask_types_finder = []\n", + " mask_data = []\n", + " mask_type_idxs = []\n", + "\n", + " next_mask_type_idx = 0\n", + " prefetch_mask = jnp.zeros_like(block_mask)\n", + " next_i = (M // bm) - 1\n", + " next_j = (N // bn) - 1\n", + " prefetch_i = jnp.zeros_like(block_mask)\n", + " prefetch_j = jnp.zeros_like(block_mask)\n", + " for i in range(M // bm, -1, -1):\n", + " for j in range(N // bn, -1, -1):\n", + " mask_block = mask[i * bm :(i + 1) * bm,\n", + " j * bn :(j + 1) * bn]\n", + " is_nonzero = jnp.any(mask_block)\n", + " if is_nonzero:\n", + " try:\n", + " type_index = mask_types_finder.index(str(mask_block))\n", + " except ValueError:\n", + " type_index = len(mask_types_finder)\n", + " mask_types_finder.append(str(mask_block))\n", + " mask_data.append(mask_block)\n", + " next_mask_type_idx = type_index\n", + " next_i = i\n", + " next_j = j\n", + " else:\n", + " type_index = -1\n", + " mask_type_idxs.append(type_index)\n", + " block_mask = block_mask.at[i, j].set(is_nonzero)\n", + " prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx)\n", + " prefetch_i = prefetch_i.at[i, j].set(next_i)\n", + " prefetch_j = prefetch_j.at[i, j].set(next_j)\n", + " return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "w4b7ckKq67Xw" + }, + "source": [ + "In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with a 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 5374, + "status": "ok", + "timestamp": 1725919713252, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "4YQ9OmbTCSjT", + "outputId": "2d752609-34f2-4059-e8ba-4d80afe8cb26" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "mean |result - ref|: 1.0252e-05\n" + ] + } + ], + "source": [ + "M = N = K = 16384\n", + "blk_M = blk_N = 512\n", + "blk_K = 1024\n", + "\n", + "def sparse_mask_matmul(\n", + " block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs.\n", + " x_ref, y_ref, mask_ref, o_ref, # Kernel inputs.\n", + " accum_scratch\n", + " ):\n", + " del prefetch_mask, prefetch_i, prefetch_j\n", + " i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2)\n", + " should_compute = block_mask_ref[i, j] != 0\n", + " @pl.when(k == 0)\n", + " def _():\n", + " o_ref[...] = jnp.zeros_like(o_ref)\n", + " accum_scratch[...] = jnp.zeros_like(accum_scratch[...])\n", + "\n", + " # We only compute the output for blocks with non-zero masks.\n", + " # Otherwise we skip the computation entirely.\n", + " @pl.when(should_compute)\n", + " def _():\n", + " result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32)\n", + " accum_scratch[...] += result\n", + " @pl.when(k == pl.num_programs(2) - 1)\n", + " def _():\n", + " o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype)\n", + "\n", + "X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16)\n", + "Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16)\n", + "mask = jnp.ones((M, N), dtype=jnp.int32)\n", + "mask = jnp.tril(mask)\n", + "block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N))\n", + "\n", + "def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n", + " del prefetch_mask, prefetch_j\n", + " # Zero-out the k index if the mask is zero, to avoid constantly fetching\n", + " # new blocks in the inner loop for blocks we are skipping.\n", + " k_fetch = (block_mask[i, j] != 0) * k\n", + " return (prefetch_i[i, j], k_fetch)\n", + "\n", + "def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):\n", + " del prefetch_mask, prefetch_i\n", + " k_fetch = (block_mask[i, j] != 0) * k\n", + " return (k_fetch, prefetch_j[i, j])\n", + "\n", + "def mask_map(i, j, k, block_mask, prefetch_mask, *_):\n", + " del k, block_mask\n", + " return (prefetch_mask[i, j], 0, 0)\n", + "\n", + "def o_map(i, j, k, *_):\n", + " del k\n", + " return (i, j)\n", + "\n", + "grid_spec = pltpu.PrefetchScalarGridSpec(\n", + " num_scalar_prefetch=4,\n", + " grid=(M // blk_M, N // blk_N, K // blk_K),\n", + " in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),\n", + " pl.BlockSpec((blk_K, blk_N), y_map),\n", + " pl.BlockSpec((1, blk_M, blk_N), mask_map)],\n", + " out_specs=pl.BlockSpec((blk_M, blk_N), o_map),\n", + " scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]\n", + ")\n", + "kernel = pl.pallas_call(\n", + " sparse_mask_matmul,\n", + " grid_spec=grid_spec,\n", + " out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),\n", + ")\n", + "args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n", + "result = kernel(*args)\n", + "\n", + "ref = mask * (X @ Y)\n", + "diff = jnp.abs(ref - result)\n", + "print('mean |result - ref|:', jnp.mean(diff))" + ] + }, + { + 
"cell_type": "markdown", + "metadata": { + "id": "uutNGgjZGGhB" + }, + "source": [ + "Now let's compare performance versus a naive dense implementation. On TPU v5e, we achieve around a ~1.8x speed increase with the sparse kernel, compared to a theoretical best-case of 2x from using a lower triangular mask and only visiting half of the possible outputs.\n", + "\n", + "We would generally expect performance to get closer to the theoretical peak as our inputs get larger, since a few of the main reasons why we don't exactly reach theoretical performance are:\n", + "- We skip slightly less than half of computation since the blocks along the diagonal are mixed 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.\n", + "- The pipeline bubble also becomes accounts for a less percentage of the overall runtime as inputs become larger." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "executionInfo": { + "elapsed": 8877, + "status": "ok", + "timestamp": 1725917397452, + "user": { + "displayName": "Justin Fu", + "userId": "17543197034567316452" + }, + "user_tz": 420 + }, + "id": "MAT9JjGNvsx8", + "outputId": "a32d56fb-a71b-4007-c6a5-e5270dcaa6cf" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sparse Kernel: 28.648 ms (avg over 100 trials)\n", + "Reference: 49.988 ms (avg over 100 trials)\n" + ] + } + ], + "source": [ + "n_trials = 100\n", + "\n", + "pallas_impl = lambda *args: kernel(*args)\n", + "time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)\n", + "print(\"Sparse Kernel: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))\n", + "\n", + "ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))\n", + "time = benchmark(ref_impl, n_trials)(mask, X, Y)\n", + "print(\"Reference: %.3f ms (avg over %d trials)\" % (time * 1000, n_trials))" + ] + } + ], + "metadata": { + "jupytext": { + "formats": "ipynb,md:myst", + "main_language": "python" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/pallas/tpu/sparse.md b/docs/pallas/tpu/sparse.md new file mode 100644 index 000000000000..2ac25edb5064 --- /dev/null +++ b/docs/pallas/tpu/sparse.md @@ -0,0 +1,567 @@ +--- +jupytext: + formats: ipynb,md:myst + main_language: python + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.4 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "ZHuzXqQ-9JUQ"} + +# Scalar Prefetch and Block-Sparse Computation + +In this tutorial, we will cover the basics of block-sparse computing in Pallas. Sparse computation is a major reason to write custom Pallas kernels over simply using JAX/XLA, since it is generally difficult to express programs that perform a dynamic amount of computation in XLA due to static array shapes. In this tutorial we will learn how to use the scalar prefetch feature of Pallas in order to write block-sparse kernels that can dynamically skip over computation and blocks of memory. 
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 56
+  status: ok
+  timestamp: 1726001133029
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: ibeIs_6QFMAM
+outputId: d72edb91-4529-4650-c9e9-b96788608635
+---
+import functools
+import timeit
+import numpy as np
+import jax
+from jax import numpy as jnp
+from jax import lax
+from jax.experimental import checkify
+from jax.experimental import pallas as pl
+from jax.experimental.pallas import tpu as pltpu
+
+assert "TPU" in jax.devices()[0].device_kind, "Please run this notebook with TPU devices."
+print("Running on", jax.devices()[0].device_kind)
+```
+
++++ {"id": "FIDGpPTEIcOa"}
+
+## Dynamic Block Indexing with Scalar Prefetch
+
+We will be exploiting the "scalar prefetch" feature of Pallas to enable us to write sparse kernels. Scalar prefetch allows you to pass a small amount of data into SMEM ("scalar memory") that is loaded before the start of the pipeline ("prefetch"). Because this data is loaded before the pipeline, it is available for use in the `index_map` for each BlockSpec, allowing you to perform data-dependent indexing calculations. The main goal of this tutorial is to go over common programming patterns that utilize this feature.
+
+To use scalar prefetch, use `pltpu.PrefetchScalarGridSpec` in place of the standard `pl.GridSpec`:
+
+```python
+class PrefetchScalarGridSpec:
+  def __init__(self,
+               num_scalar_prefetch: int,
+               grid: tuple[int, ...],
+               in_specs: PyTree[BlockSpec],
+               out_specs: PyTree[BlockSpec],
+               scratch_shapes: tuple[MemorySpace, ...]):
+    ...
+```
+
+The `num_scalar_prefetch` parameter indicates the number of scalar prefetch values. When this is set to a non-zero value, it changes the call signature of the kernel and index maps to expect additional prefetch values. The prefetch `Ref`s passed in to the `index_map` and kernel are all allocated in SMEM and are not partitioned into blocks as they do not have a BlockSpec defined. Moreover, the order of arguments to both `index_map` and kernel is always fixed and described below:
+
+- Each `BlockSpec`'s `index_map` now expects the prefetch `Ref`s to come after the grid indices:
+```python
+def index_map(*grid_indices, *prefetch_refs):
+  ...
+```
+
+- The user-defined kernel expects prefetch `Ref`s to come before the input `Ref`s. Additionally, the scratch refs come after the output `Ref`s.
+```python
+def kernel(*prefetch_refs, *input_refs, *output_refs, *scratch_refs):
+  ...
+```
+
+- When calling a new kernel using `pallas_call`, the function returned by `pallas_call` also expects the scalar prefetch arguments to come before the inputs, e.g.
+```python
+kernel = pl.pallas_call(...)
+result = kernel(*prefetch_args, *input_args)
+```
+
++++ {"id": "pA8RmHEA2HN3"}
+
+## Example: Block Dynamic Slice with Scalar Prefetch
+
+Let's begin with a basic example that demonstrates how to use the scalar prefetch feature. We will implement a block-aligned dynamic slice kernel which simply extracts a block out of a larger array based on user-specified indices:
+
+1. Outside of the kernel, we compute the block index to extract as: `block_idx = (start[0] // size[0], start[1] // size[1])`
+
+2. We pass `block_idx` as a scalar prefetch argument into `pallas_call`.
+
+3. In our index map, we use the block index to select the corresponding block by returning `(block_idx[0], block_idx[1])`.
+ +Of course, this kernel is limited in that our slice sizes must fit inside of a kernel block (limited by VMEM size) and we can only start on size-aligned indices. A more advanced kernel would decouple the kernel block size with the slice size and allow non-aligned start indices. + +```{code-cell} +--- +executionInfo: + elapsed: 143 + status: ok + timestamp: 1726003877561 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: FWeTBlEYlCGD +outputId: 4b04a441-c97c-4d0d-d167-c60d4d31fd2e +--- +def dynamic_slice_kernel(indices, x_ref, o_ref): + del indices + o_ref[...] = x_ref[...] + +@checkify.checkify +@functools.partial(jax.jit, static_argnums=(2,)) +def block_dynamic_slice(x, starts, sizes): + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + grid=(1, 1), + in_specs=[pl.BlockSpec( + sizes, + lambda i, j, block_idx: (block_idx[0], block_idx[1]))], + out_specs=pl.BlockSpec(sizes, lambda *_: (0, 0)), + ) + + kernel = pl.pallas_call( + dynamic_slice_kernel, + grid_spec=grid_spec, + out_shape=jax.ShapeDtypeStruct(shape=sizes, dtype=x.dtype), + ) + # Checkify inserts a runtime assert that starts are divisible by block size. + checkify.check(starts[0] % sizes[0] == 0, "Starts must be divisible by size.") + checkify.check(starts[1] % sizes[1] == 0, "Starts must be divisible by size.") + block_idx = jnp.array([starts[0] // sizes[0], starts[1] // sizes[1]]) + return kernel(block_idx, x) + +shape = (512, 512) +x = jnp.reshape(jnp.arange(np.prod(shape), dtype=jnp.int32), shape) +err, result = block_dynamic_slice(x, starts=(128, 256), sizes=(128, 128)) +err.throw() +ref = lax.dynamic_slice(x, start_indices=(128, 256), slice_sizes=(128, 128)) +diff = jnp.max(jnp.abs(result - ref)) +print("Error |result - lax.dynamic_slice| =", diff) +``` + ++++ {"id": "K2dod4lkoifa"} + +## Sparse Kernels: Representing Sparse Data + +Before we dive into implementing sparse kernels, let's first review how sparse matrices are represented. While there are several popular formats for storing sparse matrices, we will be following a blocked variant of the coordinate-list format (COO) in which we will store a matrix as a list of `(block_index, block_data)` pairs. All blocks that are not explicitly stored in the list are assumed to be zero, meaning we can save a significant amount of memory if there are many zero blocks in the matrix. + +The following figure demonstrates how we convert a 4x4 dense matrix (left) into a block-COO format (right) with a block size of 2x2. Note that in the sparse format, we can avoid explicitly storing the upper-right block which consists of all zero elements. + +![block_coo](../../_static/pallas/sparse/block_coo.svg) + +We will use the following helper function to sample a block-sparse matrix. It returns a dense matrix used for checking our results, as well as a list of block data and indices for each axis. + +```{code-cell} +:id: 1gLiSvgIYUEx + +def generate_block_sparse_mat(key, M, N, blk_M, blk_N, p=0.2, dtype=jnp.float32): + """Returns a sampled matrix and its block-sparse representation. + + Args: + key: RNG Key. + M: Major array dimension. + N: Minor array dimension. + blk_M: Block size along M dimension. + blk_N: Block size along N dimension. + p: Probability that a block will be non-zero. + dtype: dtype of the sampled matrix. + + Returns: + dense_mat: A (M, N) dense sampled array. + block_data: A (num_blocks, blk_M, blk_N) array of data blocks representing + the non-zero blocks of the matrix. 
+    indices_i: A (num_blocks,) array of block indices for the first axis.
+    indices_j: A (num_blocks,) array of block indices for the second axis.
+  """
+  mask_key, blocks_key = jax.random.split(key)
+  num_blocks = (M // blk_M, N // blk_N)
+  # We first sample a block mask, denoting which blocks are nonzero.
+  block_mask = jax.random.bernoulli(mask_key, p=p, shape=num_blocks)
+  num_blocks = jnp.sum(block_mask)
+  indices = jnp.where(block_mask)
+  # For each non-zero block, we sample a block of random values.
+  block_data = jax.random.uniform(blocks_key,
+                                  shape=(num_blocks, blk_M, blk_N),
+                                  dtype=dtype)
+  # For checking purposes, create the dense version of the sparse matrix.
+  dense_mat = jnp.zeros((M, N), dtype=dtype)
+  for blk in range(num_blocks):
+    idx_i = indices[0][blk]
+    idx_j = indices[1][blk]
+    slice_i = slice(idx_i * blk_M, (idx_i + 1) * blk_M)
+    slice_j = slice(idx_j * blk_N, (idx_j + 1) * blk_N)
+    dense_mat = dense_mat.at[slice_i, slice_j].set(block_data[blk])
+  return dense_mat, block_data, indices[0], indices[1]
+```
+
++++ {"id": "eFyoZSTOH9Fk"}
+
+## Example: Sparse @ Dense Matrix Multiplication
+
+In our first example, we will multiply a sparse LHS matrix with a dense RHS matrix to produce a dense output.
+
+We will structure our kernel grid with 2 loops - the outer loop over the columns of the RHS/output, and the inner loop over the sparse blocks of the LHS. During each inner loop iteration, we load one block from the LHS and look up the corresponding block in the RHS using the block index of the contracting dimension (K). We multiply the two blocks together and accumulate into the correct output block. One outer loop iteration will compute a result for an entire column as depicted by the following diagram:
+
+![sparse_matmul](../../_static/pallas/sparse/sparse_matmul.svg)
+
+It is important that we group the block indices by row (e.g. `[0, 0, 1, 2, 3, 3]`) before we pass them into the kernel for two reasons. First, in our kernel we need to know when to initially zero-out the accumulator in the output ref, and it is easy to do so if the row indices are grouped. Second, the pipelining logic for Pallas does not allow us to re-visit blocks in the output `Ref` on non-consecutive iterations, and therefore we need to do all accumulation into an output block in consecutive kernel iterations. This is because the pipeline emitter will realize that we are loading the same output block on consecutive iterations and keep the block in VMEM. When we change output blocks, Pallas will finally store the output into HBM and assume we never touch it again. Failure to access output blocks consecutively will result in incorrect values even though the kernel is otherwise logically correct.
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 673
+  status: ok
+  timestamp: 1725919879291
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: WfyV2WWhjsyA
+outputId: fa4d4fff-bc6b-4dc9-ac14-63276ca14131
+---
+M = N = K = 16384
+blk_M = blk_N = blk_K = 512
+
+
+def dsd_kernel(idxs_i_ref, idxs_k_ref, # Scalar prefetch inputs.
+               x_ref, y_ref, _, o_ref, # Kernel inputs.
+               accum_scratch,
+               ):
+  """A DSD (Dense = Sparse @ Dense) matmul kernel."""
+  del idxs_k_ref
+  blk_idx = pl.program_id(0)
+  is_start = blk_idx == 0
+  changed_blocks = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.maximum(blk_idx-1, 0)])
+  @pl.when(is_start | changed_blocks)
+  def _():
+    accum_scratch[...] = jnp.zeros_like(accum_scratch)
+  accum_scratch[...]
+= jnp.dot(x_ref[0, :, :], y_ref[...], preferred_element_type=jnp.float32) + + next_block_change = (idxs_i_ref[blk_idx] != idxs_i_ref[jnp.minimum(blk_idx+1, num_blocks)]) + is_end = blk_idx == (num_blocks - 1) + @pl.when(is_end | next_block_change) + def _(): + o_ref[...] = accum_scratch[...].astype(o_ref.dtype) + + +def x_map(blk_idx, j, blk_idxs_i, blk_idxs_k): + del j, blk_idxs_i, blk_idxs_k + return (blk_idx, 0, 0) +def y_map(blk_idx, j, blk_idxs_i, blk_idxs_k): + del blk_idxs_i + return (blk_idxs_k[blk_idx], j) +def o_map(blk_idx, j, blk_idxs_i, blk_idxs_k): + del blk_idxs_k + return (blk_idxs_i[blk_idx], j) + +(X_dense, X_blocks, indices_i, indices_k) = generate_block_sparse_mat( + jax.random.key(0), M, K, blk_M, blk_K, p=0.1, dtype=jnp.bfloat16) +num_blocks = X_blocks.shape[0] +Y = jax.random.uniform(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16) +zeros = jnp.zeros((M, N), dtype=jnp.bfloat16) +out_shape = jax.ShapeDtypeStruct((M, N), dtype=jnp.bfloat16) + +grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + # Note that while num_blocks is static here, Pallas does support + # dynamic grid sizes. + grid=(num_blocks, N // blk_N), + in_specs=[pl.BlockSpec((1, blk_M, blk_K), x_map), + pl.BlockSpec((blk_K, blk_N), y_map), + # Placeholder for a zeros-array used by input_output_aliases. + pl.BlockSpec((blk_M, blk_N), o_map), + ], + out_specs=pl.BlockSpec((blk_M, blk_N), o_map), + scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)] +) +kernel = pl.pallas_call( + dsd_kernel, + grid_spec=grid_spec, + out_shape=out_shape, + # We use input-output aliases to zero-out o_ref for blocks that we never + # visit. By passing in an array of zeros we avoid having o_ref start with + # uninitialized values. + input_output_aliases={4: 0}, # Map zeros to o_ref. +) +args = (indices_i, indices_k, X_blocks, Y, zeros) +result = kernel(*args) + +ref = X_dense @ Y +diff = jnp.abs(ref - result) +print('mean |result - ref|:', jnp.mean(diff)) +``` + ++++ {"id": "2KDgPKF2tUjq"} + +We can do a quick benchmark to compare the performance of our sparse kernel compared to a dense matmul in JAX. On a TPU v5e chip, this kernel achieves a roughly ~6x speed increase compared to the theoretical 10x from the sparsity factor. + +There are a few main tips for performance here, mainly centered around reducing the communication overhead between HBM/VMEM: +- Using `dtype=jnp.bfloat16` is critical for performance since it reduces memory bandwidth by half. +- Using larger block sizes also helps, since matrix multiply is an $O(N^3)$ compute and $O(N^2)$ memory operation. As $N$ grows larger, the kernel becomes compute-bound. However, a counter-argument to this in practice is that smaller block sizes also enables data to be more sparse, so this is a parameter that should be selected carefully. 
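+
+To make the second point concrete, here is a rough back-of-the-envelope sketch (an illustrative calculation with assumed bfloat16 operands, separate from the benchmark below) of how the arithmetic intensity of a single `(blk, blk) @ (blk, blk)` block grows with the block size:
+
+```python
+# Illustrative only: FLOPs per byte of HBM traffic for one square block matmul,
+# assuming bfloat16 operands (2 bytes per element) and that each block of
+# x, y, and the output crosses HBM<->VMEM exactly once.
+def block_arithmetic_intensity(blk: int, bytes_per_elem: int = 2) -> float:
+  flops = 2 * blk**3                            # one multiply and one add per MAC
+  bytes_moved = 3 * blk * blk * bytes_per_elem  # read x block, read y block, write out block
+  return flops / bytes_moved
+
+for blk in (128, 256, 512, 1024):
+  print(blk, block_arithmetic_intensity(blk))   # intensity grows linearly with blk
+```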
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 6576
+  status: ok
+  timestamp: 1725919886762
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: CkzjqnekpZbx
+outputId: 1ae9031e-705a-4d05-f8b9-d09623918300
+---
+# Benchmark Sparse Pallas kernel vs reference JAX implementation
+
+def benchmark(f, ntrials: int = 100):
+  def run(*args, **kwargs):
+    # Compile function first
+    jax.block_until_ready(f(*args, **kwargs))
+    # Time function
+    result = timeit.timeit(lambda: jax.block_until_ready(f(*args, **kwargs)),
+                           number=ntrials)
+    time = result / ntrials
+    return time
+  return run
+
+
+n_trials = 100
+
+pallas_impl = lambda *args: kernel(*args)
+time = benchmark(pallas_impl, n_trials)(indices_i, indices_k, X_blocks, Y, zeros)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+
+ref_impl = jax.jit(lambda x, y: x @ y)
+time = benchmark(ref_impl, n_trials)(X_dense, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+```
+
++++ {"id": "Q1KKd5vTCwnB"}
+
+## Sparse Access Patterns on Dense Data
+
+In our previous example we considered the case when the data itself is sparse. This manifested itself in the kernel structure as a dimension in the kernel grid that was dynamic and looped over the number of nonzero blocks (`num_blocks`).
+
+A second useful programming pattern emerges when the underlying data is dense, but we wish to perform sparse computation over it. Our kernel grid in this case will be dense, but we wish to skip over some blocks in the grid as indicated by a block-sparse mask. This type of programming pattern commonly arises when using masks in many machine learning applications, such as causal or local masks in self-attention. In these cases, we can entirely skip over computation in blocks where the mask is zeroed-out. Examples of this programming pattern can be found in the Splash Attention and Grouped Matrix Multiplication kernels located in `jax/experimental/pallas/ops/tpu`, or in PyTorch's [FlexAttention](https://pytorch.org/blog/flexattention/).
+
+The main performance consideration when dealing with a sparse access pattern on dense data is the interaction with pipelining. On any given kernel iteration, the Pallas pipeline emitter will attempt to prefetch the next block of data by calling the `index_map` for each `BlockSpec` on the next iteration of the grid. However, if our computation is sparse we may be skipping the computation for the next block in the grid, so we need some method to tell the pipeline to instead begin fetching the *next block that we are not skipping*. In order to do this, we need to construct *prefetch maps* which contain indices to the next non-skipped block of data for each kernel input. The following diagram illustrates how a prefetch map could be constructed for a block-sparse mask that is stored in a COO-like format.
+
+![prefetch_map](../../_static/pallas/sparse/prefetch_map.svg)
+
+*Left: A sparse access pattern, where the color blue denotes blocks with non-zero masks that we need to compute. Right: The prefetch map, where each element of the array contains the index of the next non-zero block data.*
+
+Once the prefetch map has been constructed, we can pass the map as a scalar prefetch argument and query it in the `index_map` function of the BlockSpec.
+ +```python +def mask_index_map(prefetch_map, i, j, ...): + next_nonzero_block = prefetch_map[i, j] + return (next_nonzero_block, 0, 0) +``` + +We can construct similar index maps for the other inputs to the kernel. For dense inputs you will most likely need to construct prefetch maps which point to the next non-zero block index in the grid. Our next example will provide an example of using these prefetch maps. + ++++ {"id": "ii7rzL5YIA8-"} + +## Example: Dense @ Dense Matrix Multiplication with a Block-Sparse Output Mask + ++++ {"id": "ecjiqWfA2RlV"} + +In our next example we will cover dense matrix multiplication fused with a sparse output mask using a prefetch map to improve pipelining performance. We will use the mask to selectively skip computing output blocks that are zeroed-out, therefore saving on computation costs. + +As we will be working with a sparse mask, we will begin by implementing a function that converts an `N x M` mask stored in dense format into a block-sparse format. We additionally need to compute prefetch maps to help the pipeline emitter know which block to fetch next. In total, our `sparsify_mask` function computes: +- A `block_mask` of shape `(num_N_blocks, num_M_blocks)` indicating if a block is all-zeros (value `0`) or contains non-zero elements (value `1`). If the `block_mask` has a value of 0 we can skip computing the block in the kernel. +- A `prefetch_mask` array of shape `(num_N_blocks, num_M_blocks)` consisting of indices into `mask_data` for the next non-zero block. +- A `prefetch_i` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `i` index of the mask. +- A `prefetch_j` array of shape `(num_N_blocks, num_M_blocks)` consisting of the next non-masked `j` index of the mask. +- A `mask_data` array of shape `(num_blocks, blk_N, blk_M)` containing data for non-zero blocks of the mask. + +```{code-cell} +:id: 19zGcliL2SJy + +def sparsify_mask(mask: jax.Array, + block_shape: tuple[int, int]): + """Preprocesses a mask into a sparse reprentation. + + Args: + mask: A boolean array of shape [M, N] + block_shape: The size of a single block. + + Returns: + block_mask: A block_shape array of booleans indicating whether a block + is all-zeros (0) or contains non-zero elements (1). + prefetch_mask: A block_shape array of integers indicating the index of the + next non-zero block. + mask_data: A (num_blocks, block_shape) array containing + the data for non-zero blocks of the mask. 
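+    prefetch_i: A block_shape array of integers indicating the next
+      non-masked row (i) index of the mask.
+    prefetch_j: A block_shape array of integers indicating the next
+      non-masked column (j) index of the mask.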
+ """ + M, N = mask.shape + bm, bn = block_shape + + block_mask = jnp.zeros((M // bm, N // bn), dtype=mask.dtype) + mask_types_finder = [] + mask_data = [] + mask_type_idxs = [] + + next_mask_type_idx = 0 + prefetch_mask = jnp.zeros_like(block_mask) + next_i = (M // bm) - 1 + next_j = (N // bn) - 1 + prefetch_i = jnp.zeros_like(block_mask) + prefetch_j = jnp.zeros_like(block_mask) + for i in range(M // bm, -1, -1): + for j in range(N // bn, -1, -1): + mask_block = mask[i * bm :(i + 1) * bm, + j * bn :(j + 1) * bn] + is_nonzero = jnp.any(mask_block) + if is_nonzero: + try: + type_index = mask_types_finder.index(str(mask_block)) + except ValueError: + type_index = len(mask_types_finder) + mask_types_finder.append(str(mask_block)) + mask_data.append(mask_block) + next_mask_type_idx = type_index + next_i = i + next_j = j + else: + type_index = -1 + mask_type_idxs.append(type_index) + block_mask = block_mask.at[i, j].set(is_nonzero) + prefetch_mask = prefetch_mask.at[i, j].set(next_mask_type_idx) + prefetch_i = prefetch_i.at[i, j].set(next_i) + prefetch_j = prefetch_j.at[i, j].set(next_j) + return block_mask, prefetch_mask, prefetch_i, prefetch_j, jnp.stack(mask_data) +``` + ++++ {"id": "w4b7ckKq67Xw"} + +In terms of the structure of the kernel, we use the same grid pattern as the standard matrix multiplication kernel we covered in previous tutorials with a 3 loops over the `N`, `M`, and `K` dimensions. Within the kernel itself, we first check the `block_mask` to see if the mask for the current output block was all zeros. If the mask is all zeros, we can skip computation and move onto the next block; otherwise we need to compute the matrix multiplication and then mask the result. + +```{code-cell} +--- +executionInfo: + elapsed: 5374 + status: ok + timestamp: 1725919713252 + user: + displayName: Justin Fu + userId: '17543197034567316452' + user_tz: 420 +id: 4YQ9OmbTCSjT +outputId: 2d752609-34f2-4059-e8ba-4d80afe8cb26 +--- +M = N = K = 16384 +blk_M = blk_N = 512 +blk_K = 1024 + +def sparse_mask_matmul( + block_mask_ref, prefetch_mask, prefetch_i, prefetch_j, # Scalar prefetch inputs. + x_ref, y_ref, mask_ref, o_ref, # Kernel inputs. + accum_scratch + ): + del prefetch_mask, prefetch_i, prefetch_j + i, j, k = pl.program_id(0), pl.program_id(1), pl.program_id(2) + should_compute = block_mask_ref[i, j] != 0 + @pl.when(k == 0) + def _(): + o_ref[...] = jnp.zeros_like(o_ref) + accum_scratch[...] = jnp.zeros_like(accum_scratch[...]) + + # We only compute the output for blocks with non-zero masks. + # Otherwise we skip the computation entirely. + @pl.when(should_compute) + def _(): + result = jnp.dot(x_ref[...], y_ref[...], preferred_element_type=jnp.float32) + accum_scratch[...] += result + @pl.when(k == pl.num_programs(2) - 1) + def _(): + o_ref[...] = (mask_ref[0, ...] * accum_scratch[...]).astype(o_ref.dtype) + +X = jax.random.normal(jax.random.key(0), shape=(M, K), dtype=jnp.bfloat16) +Y = jax.random.normal(jax.random.key(1), shape=(K, N), dtype=jnp.bfloat16) +mask = jnp.ones((M, N), dtype=jnp.int32) +mask = jnp.tril(mask) +block_mask, prefetch_mask, prefetch_i, prefetch_j, sparse_mask_data = sparsify_mask(mask, (blk_M, blk_N)) + +def x_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j): + del prefetch_mask, prefetch_j + # Zero-out the k index if the mask is zero, to avoid constantly fetching + # new blocks in the inner loop for blocks we are skipping. 
+  k_fetch = (block_mask[i, j] != 0) * k
+  return (prefetch_i[i, j], k_fetch)
+
+def y_map(i, j, k, block_mask, prefetch_mask, prefetch_i, prefetch_j):
+  del prefetch_mask, prefetch_i
+  k_fetch = (block_mask[i, j] != 0) * k
+  return (k_fetch, prefetch_j[i, j])
+
+def mask_map(i, j, k, block_mask, prefetch_mask, *_):
+  del k, block_mask
+  return (prefetch_mask[i, j], 0, 0)
+
+def o_map(i, j, k, *_):
+  del k
+  return (i, j)
+
+grid_spec = pltpu.PrefetchScalarGridSpec(
+    num_scalar_prefetch=4,
+    grid=(M // blk_M, N // blk_N, K // blk_K),
+    in_specs=[pl.BlockSpec((blk_M, blk_K), x_map),
+              pl.BlockSpec((blk_K, blk_N), y_map),
+              pl.BlockSpec((1, blk_M, blk_N), mask_map)],
+    out_specs=pl.BlockSpec((blk_M, blk_N), o_map),
+    scratch_shapes=[pltpu.VMEM((blk_M, blk_N), dtype=jnp.float32)]
+)
+kernel = pl.pallas_call(
+  sparse_mask_matmul,
+  grid_spec=grid_spec,
+  out_shape=jax.ShapeDtypeStruct((M, N), jnp.bfloat16),
+)
+args = (block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+result = kernel(*args)
+
+ref = mask * (X @ Y)
+diff = jnp.abs(ref - result)
+print('mean |result - ref|:', jnp.mean(diff))
+```
+
++++ {"id": "uutNGgjZGGhB"}
+
+Now let's compare performance against a naive dense implementation. On TPU v5e, we achieve around a 1.8x speedup with the sparse kernel, compared to a theoretical best case of 2x from using a lower-triangular mask and only visiting half of the possible outputs.
+
+We would generally expect performance to get closer to the theoretical peak as our inputs get larger. The main reasons we don't exactly reach it are:
+- We skip slightly less than half of the computation, since the blocks along the diagonal contain a mix of 0s and 1s, and for mixed blocks we need to compute the entire block. With larger inputs, our overhead for mixed blocks becomes smaller relative to the overall computation.
+- The pipeline bubble also accounts for a smaller fraction of the overall runtime as inputs become larger.
+
+```{code-cell}
+---
+executionInfo:
+  elapsed: 8877
+  status: ok
+  timestamp: 1725917397452
+  user:
+    displayName: Justin Fu
+    userId: '17543197034567316452'
+  user_tz: 420
+id: MAT9JjGNvsx8
+outputId: a32d56fb-a71b-4007-c6a5-e5270dcaa6cf
+---
+n_trials = 100
+
+pallas_impl = lambda *args: kernel(*args)
+time = benchmark(pallas_impl, n_trials)(block_mask, prefetch_mask, prefetch_i, prefetch_j, X, Y, sparse_mask_data)
+print("Sparse Kernel: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+
+ref_impl = jax.jit(lambda mask, x, y: mask * (x @ y))
+time = benchmark(ref_impl, n_trials)(mask, X, Y)
+print("Reference: %.3f ms (avg over %d trials)" % (time * 1000, n_trials))
+```
diff --git a/docs/persistent_compilation_cache.md b/docs/persistent_compilation_cache.md
index af20aa7bba24..47a7587b620f 100644
--- a/docs/persistent_compilation_cache.md
+++ b/docs/persistent_compilation_cache.md
@@ -1,4 +1,4 @@
-# Persistent Compilation Cache
+# Persistent compilation cache
@@ -29,7 +29,7 @@ f(x)
 ### Setting cache directory
 
 The compilation cache is enabled when the
-[cache location](https://github.com/google/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
+[cache location](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/_src/config.py#L1206)
 is set. This should be done prior to the first compilation.
Set the location as follows: @@ -54,7 +54,7 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache" jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache") ``` -(3) Using [`set_cache_dir()`](https://github.com/google/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) +(3) Using [`set_cache_dir()`](https://github.com/jax-ml/jax/blob/jax-v0.4.26/jax/experimental/compilation_cache/compilation_cache.py#L18) ```python from jax.experimental.compilation_cache import compilation_cache as cc diff --git a/docs/profiling.md b/docs/profiling.md index 6eceec8f54b8..91f4d61b21b6 100644 --- a/docs/profiling.md +++ b/docs/profiling.md @@ -1,4 +1,4 @@ -# Profiling JAX programs +# Profiling computation diff --git a/docs/quickstart.md b/docs/quickstart.md index e071a7ce7555..77cbb9d46ab8 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python @@ -16,7 +16,7 @@ kernelspec: -**JAX a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. +**JAX is a library for array-oriented numerical computation (*à la* [NumPy](https://numpy.org/)), with automatic differentiation and JIT compilation to enable high-performance machine learning research**. This document provides a quick overview of essential JAX features, so you can get started with JAX quickly: @@ -88,8 +88,8 @@ _ = selu_jit(x) # compiles on first call %timeit selu_jit(x).block_until_ready() ``` -The above timing represent execution on CPU, but the same code can be run on GPU or TPU, -typically for an even greater speedup. +The above timing represents execution on CPU, but the same code can be run on GPU or +TPU, typically for an even greater speedup. For more on JIT compilation in JAX, check out {ref}`jit-compilation`. @@ -183,7 +183,7 @@ print('Naively batched') %timeit naively_batched_apply_matrix(batched_x).block_until_ready() ``` -A programmer familiar with the the `jnp.dot` function might recognize that `apply_matrix` can +A programmer familiar with the `jnp.dot` function might recognize that `apply_matrix` can be rewritten to avoid explicit looping, using the built-in batching semantics of `jnp.dot`: ```{code-cell} diff --git a/docs/random-numbers.md b/docs/random-numbers.md index 85bb5ce01974..2ad1eadb0968 100644 --- a/docs/random-numbers.md +++ b/docs/random-numbers.md @@ -5,7 +5,7 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python diff --git a/docs/sharded-computation.ipynb b/docs/sharded-computation.ipynb index 8fa2107795fd..cdfda63c6f13 100644 --- a/docs/sharded-computation.ipynb +++ b/docs/sharded-computation.ipynb @@ -5,7 +5,7 @@ "metadata": {}, "source": [ "(sharded-computation)=\n", - "# Introduction to sharded computation\n", + "# Introduction to parallel programming\n", "\n", "\n", "\n", @@ -60,7 +60,7 @@ "\n", "Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices.\n", "\n", - "How can JAX can understand how the data is laid out across devices? 
JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n", + "How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`.\n", "\n", "In the simplest cases, arrays are sharded on a single device, as demonstrated below:" ] @@ -188,15 +188,9 @@ } ], "source": [ - "# Pardon the boilerplate; constructing a sharding will become easier in future!\n", - "from jax.sharding import Mesh\n", - "from jax.sharding import PartitionSpec\n", - "from jax.sharding import NamedSharding\n", - "from jax.experimental import mesh_utils\n", + "from jax.sharding import PartitionSpec as P\n", "\n", - "P = jax.sharding.PartitionSpec\n", - "devices = mesh_utils.create_device_mesh((2, 4))\n", - "mesh = jax.sharding.Mesh(devices, ('x', 'y'))\n", + "mesh = jax.make_mesh((2, 4), ('x', 'y'))\n", "sharding = jax.sharding.NamedSharding(mesh, P('x', 'y'))\n", "print(sharding)" ] @@ -366,7 +360,7 @@ "\n", "## 2. Semi-automated sharding with constraints\n", "\n", - "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", + "If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. 
You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed.\n", "\n", "For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices:" ] @@ -405,9 +399,7 @@ "@jax.jit\n", "def f_contract_2(x):\n", " out = x.sum(axis=0)\n", - " # mesh = jax.create_mesh((8,), 'x')\n", - " devices = mesh_utils.create_device_mesh(8)\n", - " mesh = jax.sharding.Mesh(devices, 'x')\n", + " mesh = jax.make_mesh((8,), ('x',))\n", " sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", " return jax.lax.with_sharding_constraint(out, sharding)\n", "\n", @@ -460,8 +452,7 @@ ], "source": [ "from jax.experimental.shard_map import shard_map\n", - "P = jax.sharding.PartitionSpec\n", - "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n", + "mesh = jax.make_mesh((8,), ('x',))\n", "\n", "f_elementwise_sharded = shard_map(\n", " f_elementwise,\n", @@ -659,8 +650,7 @@ } ], "source": [ - "P = jax.sharding.PartitionSpec\n", - "mesh = jax.sharding.Mesh(jax.devices(), 'x')\n", + "mesh = jax.make_mesh((8,), ('x',))\n", "sharding = jax.sharding.NamedSharding(mesh, P('x'))\n", "\n", "x_sharded = jax.device_put(x, sharding)\n", diff --git a/docs/sharded-computation.md b/docs/sharded-computation.md index 345ca7987b41..84516a557166 100644 --- a/docs/sharded-computation.md +++ b/docs/sharded-computation.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 name: python3 --- (sharded-computation)= -# Introduction to sharded computation +# Introduction to parallel programming @@ -39,7 +39,7 @@ jax.devices() Key to all of the distributed computation approaches below is the concept of *data sharding*, which describes how data is laid out on the available devices. -How can JAX can understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`. +How can JAX understand how the data is laid out across devices? JAX's datatype, the {class}`jax.Array` immutable array data structure, represents arrays with physical storage spanning one or multiple devices, and helps make parallelism a core feature of JAX. The {class}`jax.Array` object is designed with distributed data and computation in mind. Every `jax.Array` has an associated {mod}`jax.sharding.Sharding` object, which describes which shard of the global data is required by each global device. When you create a {class}`jax.Array` from scratch, you also need to create its `Sharding`. In the simplest cases, arrays are sharded on a single device, as demonstrated below: @@ -72,15 +72,9 @@ Here, define a {class}`~jax.sharding.NamedSharding`, which specifies an N-dimens ```{code-cell} :outputId: 0b397dba-3ddc-4aca-f002-2beab7e6b8a5 -# Pardon the boilerplate; constructing a sharding will become easier in future! 
-from jax.sharding import Mesh -from jax.sharding import PartitionSpec -from jax.sharding import NamedSharding -from jax.experimental import mesh_utils +from jax.sharding import PartitionSpec as P -P = jax.sharding.PartitionSpec -devices = mesh_utils.create_device_mesh((2, 4)) -mesh = jax.sharding.Mesh(devices, ('x', 'y')) +mesh = jax.make_mesh((2, 4), ('x', 'y')) sharding = jax.sharding.NamedSharding(mesh, P('x', 'y')) print(sharding) ``` @@ -139,7 +133,7 @@ The result is partially replicated: that is, the first two elements of the array ## 2. Semi-automated sharding with constraints -If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of (func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. +If you'd like to have some control over the sharding used within a particular computation, JAX offers the {func}`~jax.lax.with_sharding_constraint` function. You can use {func}`jax.lax.with_sharding_constraint` (in place of {func}`jax.device_put()`) together with {func}`jax.jit` for more control over how the compiler constraints how the intermediate values and outputs are distributed. For example, suppose that within `f_contract` above, you'd prefer the output not to be partially-replicated, but rather to be fully sharded across the eight devices: @@ -149,9 +143,7 @@ For example, suppose that within `f_contract` above, you'd prefer the output not @jax.jit def f_contract_2(x): out = x.sum(axis=0) - # mesh = jax.create_mesh((8,), 'x') - devices = mesh_utils.create_device_mesh(8) - mesh = jax.sharding.Mesh(devices, 'x') + mesh = jax.make_mesh((8,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) return jax.lax.with_sharding_constraint(out, sharding) @@ -177,8 +169,7 @@ In the automatic parallelism methods explored above, you can write a function as :outputId: 435c32f3-557a-4676-c11b-17e6bab8c1e2 from jax.experimental.shard_map import shard_map -P = jax.sharding.PartitionSpec -mesh = jax.sharding.Mesh(jax.devices(), 'x') +mesh = jax.make_mesh((8,), ('x',)) f_elementwise_sharded = shard_map( f_elementwise, @@ -268,8 +259,7 @@ If you shard the leading axis of both `x` and `weights` in the same way, then th ```{code-cell} :outputId: 80be899e-8dbc-4bfc-acd2-0f3d554a0aa5 -P = jax.sharding.PartitionSpec -mesh = jax.sharding.Mesh(jax.devices(), 'x') +mesh = jax.make_mesh((8,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) x_sharded = jax.device_put(x, sharding) diff --git a/docs/sphinxext/jax_extensions.py b/docs/sphinxext/jax_extensions.py index 3a78557632a7..7cce8b88254d 100644 --- a/docs/sphinxext/jax_extensions.py +++ b/docs/sphinxext/jax_extensions.py @@ -26,14 +26,14 @@ def jax_issue_role(name, rawtext, text, lineno, inliner, options=None, :jax-issue:`1234` This will output a hyperlink of the form - `#1234 `_. These links work even + `#1234 `_. These links work even for PR numbers. 
""" text = text.lstrip('#') if not text.isdigit(): raise RuntimeError(f"Invalid content in {rawtext}: expected an issue or PR number.") options = {} if options is None else options - url = f"https://github.com/google/jax/issues/{text}" + url = f"https://github.com/jax-ml/jax/issues/{text}" node = nodes.reference(rawtext, '#' + text, refuri=url, **options) return [node], [] diff --git a/docs/stateful-computations.md b/docs/stateful-computations.md index 5a8af2b74142..2ff82e0431e2 100644 --- a/docs/stateful-computations.md +++ b/docs/stateful-computations.md @@ -5,14 +5,14 @@ jupytext: extension: .md format_name: myst format_version: 0.13 - jupytext_version: 1.16.1 + jupytext_version: 1.16.4 kernelspec: display_name: Python 3 language: python name: python3 --- -# Stateful Computations +# Stateful computations @@ -144,7 +144,7 @@ This is because, like the strategy we just applied, object-oriented programming In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class? -Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the :ref:`pseudorandom-numbers` section. +Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, {mod}`jax.random`, shown in the {ref}`pseudorandom-numbers` section. Unlike Numpy, which manages random state using implicitly updated stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key. @@ -234,4 +234,4 @@ Handling parameters manually seems fine if you're dealing with two parameters, b 2) Are we supposed to pipe all these things around manually? -The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples. +The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/jax-ml/jax#neural-network-libraries) for some examples. diff --git a/docs/tutorials.rst b/docs/tutorials.rst index 2f90e4226e50..a31517155e1a 100644 --- a/docs/tutorials.rst +++ b/docs/tutorials.rst @@ -1,7 +1,7 @@ .. _jax-tutorials: -JAX tutorials -============= +Tutorials +========= .. toctree:: :maxdepth: 1 @@ -16,3 +16,13 @@ JAX tutorials working-with-pytrees sharded-computation stateful-computations + +.. toctree:: + :maxdepth: 1 + :caption: Advanced tutorials + + advanced-autodiff + external-callbacks + gradient-checkpointing + jax-primitives + jaxpr diff --git a/docs/type_promotion.rst b/docs/type_promotion.rst index 103a8331df2b..d3724745fe08 100644 --- a/docs/type_promotion.rst +++ b/docs/type_promotion.rst @@ -218,7 +218,7 @@ Strict dtype promotion ---------------------- In some contexts it can be useful to disable implicit type promotion behavior, and instead require all promotions to be explicit. This can be done in JAX by setting the -``jax_numpy_dtype_promtion`` flag to ``'strict'``. Locally, it can be done with a\ +``jax_numpy_dtype_promotion`` flag to ``'strict'``. Locally, it can be done with a\ context manager: .. code-block:: python diff --git a/docs/user_guides.rst b/docs/user_guides.rst index 57913bf6d4c8..6481da7a31dd 100644 --- a/docs/user_guides.rst +++ b/docs/user_guides.rst @@ -1,6 +1,6 @@ .. 
_user-guides:
 
-User Guides
+User guides
 ===========
 
 User guides are deeper dives into particular topics within JAX
@@ -9,7 +9,7 @@ or deployed codebases.
 
 .. toctree::
    :maxdepth: 1
-   :caption: Debugging and Performance
+   :caption: Debugging and performance
 
    notebooks/thinking_in_jax
    profiling
@@ -20,25 +20,26 @@ or deployed codebases.
 
 .. toctree::
    :maxdepth: 1
-   :caption: Development
+   :caption: Interfaces
 
-   jaxpr
-   notebooks/external_callbacks
-   type_promotion
    pytrees
-
-.. toctree::
-   :maxdepth: 1
-   :caption: Run Time
-
+   errors
    aot
    export/index
-   errors
+   type_promotion
    transfer_guard
 
 .. toctree::
    :maxdepth: 1
-   :caption: Custom Operations
+   :caption: Custom operations
 
    pallas/index
    ffi
+
+.. toctree::
+   :caption: Example applications
+   :maxdepth: 1
+
+   notebooks/neural_network_with_tfds_data
+   notebooks/Neural_Network_and_Data_Loading
+   notebooks/vmapped_log_probs
diff --git a/docs/working-with-pytrees.md b/docs/working-with-pytrees.md
index 2bd1cc08ecdf..e41179996bc4 100644
--- a/docs/working-with-pytrees.md
+++ b/docs/working-with-pytrees.md
@@ -5,7 +5,7 @@ jupytext:
     extension: .md
     format_name: myst
     format_version: 0.13
-    jupytext_version: 1.16.1
+    jupytext_version: 1.16.4
 kernelspec:
   display_name: Python 3
   language: python
diff --git a/docs/xla_flags.md b/docs/xla_flags.md
new file mode 100644
index 000000000000..b332940ccb9d
--- /dev/null
+++ b/docs/xla_flags.md
@@ -0,0 +1,89 @@
+# List of XLA compiler flags
+
+
+
+## Introduction
+This guide gives a brief overview of XLA and how XLA relates to JAX; for in-depth details please refer to the [XLA documentation](https://openxla.org/xla). It then lists commonly-used XLA compiler flags that can help optimize the performance of JAX programs.
+
+## XLA: the powerhouse behind JAX
+XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that plays a pivotal role in JAX's performance and flexibility. It enables JAX to generate optimized code for various hardware backends (CPUs, GPUs, TPUs) by transforming and compiling your Python/NumPy-like code into efficient machine instructions.
+
+JAX uses XLA's JIT compilation capabilities to transform your Python functions into optimized XLA computations at runtime.
+
+## Configuring XLA in JAX
+You can influence XLA's behavior in JAX by setting the `XLA_FLAGS` environment variable before running your Python script or Colab notebook.
+
+For Colab notebooks:
+
+Provide flags using `os.environ['XLA_FLAGS']`:
+
+
+```python
+import os
+
+# Set multiple flags separated by spaces
+os.environ['XLA_FLAGS'] = '--flag1=value1 --flag2=value2'
+```
+
+For Python scripts:
+
+Specify `XLA_FLAGS` as part of the CLI command:
+
+```bash
+XLA_FLAGS='--flag1=value1 --flag2=value2' python3 source.py
+```
+
+**Important Notes:**
+
+* Set `XLA_FLAGS` before importing JAX or other relevant libraries. Changing `XLA_FLAGS` after the backend has been initialized will have no effect, and since the exact moment of backend initialization is not clearly defined, it is usually safest to set `XLA_FLAGS` before executing any JAX code.
+* Experiment with different flags to optimize performance for your specific use case.
+
+
+**For further information:**
+* Complete and up-to-date documentation about XLA can be found in the official [XLA documentation](https://openxla.org/xla).
+
+* For backends supported by the open-source version of XLA (CPU, GPU), XLA flags are defined with their default values in [xla/debug_options_flags.cc](https://github.com/openxla/xla/blob/main/xla/debug_options_flags.cc), and a complete list of flags can be found [here](https://github.com/openxla/xla/blob/main/xla/xla.proto).
+* TPU compiler flags are not part of [OpenXLA](https://github.com/openxla/xla), but commonly-used options are listed below.
+
+* Please note that this list of flags is not exhaustive and is subject to change. These flags are implementation details, and there is no guarantee that they will remain available or maintain their current behavior.
+
+### Common XLA flags
+| Flag | Type | Notes |
+| ---- | ---- | ----- |
+| `xla_dump_to` | String (filepath) | The folder where pre-optimization HLO files and other artifacts will be placed (see [XLA Tools](https://openxla.org/xla/tools)). |
+| `xla_enable_async_collective_permute` | TristateFlag (true/false/auto) | Rewrites all collective-permute operations to their asynchronous variants. When set to `auto`, XLA can turn on async collective based on other configurations or conditions automatically. |
+| `xla_enable_async_all_gather` | TristateFlag (true/false/auto) | If set to true, enables async all-gather. If `auto`, enables it only for platforms that implement async all-gather. The implementation (such as BC-offload or continuation fusion) is chosen based on other flag values. |
+| `xla_disable_hlo_passes` | String (comma-separated list of pass names) | Comma-separated list of HLO passes to be disabled. These names must exactly match the pass name (no whitespace around commas). |
+
+### TPU XLA flags
+| Flag | Type | Notes |
+| ---- | ---- | ----- |
+| `xla_tpu_enable_data_parallel_all_reduce_opt` | Boolean (true/false) | Optimization to increase overlap opportunities for DCN (data center networking) all-reduces used for data-parallel sharding. |
+| `xla_tpu_data_parallel_opt_different_sized_ops` | Boolean (true/false) | Enables pipelining of data-parallel ops across multiple iterations even if their output sizes don't match what can be saved in place in the stacked variables. Can increase memory pressure. |
+| `xla_tpu_enable_async_collective_fusion` | Boolean (true/false) | Enables the pass which fuses async collective communications with compute ops (output/loop-fusion or convolution) that are scheduled between their -start and -done instructions. |
+| `xla_tpu_enable_async_collective_fusion_fuse_all_gather` | TristateFlag (true/false/auto) | Enables fusing all-gathers within the AsyncCollectiveFusion pass. If set to `auto`, it will be enabled based on the target. |
+| `xla_tpu_enable_async_collective_fusion_multiple_steps` | Boolean (true/false) | Enables continuing the same async collective in multiple steps (fusions) in the AsyncCollectiveFusion pass. |
+| `xla_tpu_overlap_compute_collective_tc` | Boolean (true/false) | Enables the overlap of compute and communication on a single TensorCore, i.e., one core equivalent of MegaCore fusion. |
+| `xla_tpu_spmd_rng_bit_generator_unsafe` | Boolean (true/false) | Whether to run the RngBitGenerator HLO in a partitioned way, which is unsafe if deterministic results are expected with different shardings on different parts of the computation. |
+| `xla_tpu_megacore_fusion_allow_ags` | Boolean (true/false) | Allows fusing all-gathers with convolutions/all-reduces. |
+| `xla_tpu_enable_ag_backward_pipelining` | Boolean (true/false) | Pipelines all-gathers (currently megascale all-gathers) backwards through scan loops. |
+
+### GPU XLA flags
+| Flag | Type | Notes |
+| ---- | ---- | ----- |
+| `xla_gpu_enable_latency_hiding_scheduler` | Boolean (true/false) | This flag enables latency-hiding schedulers to overlap asynchronous communication with computation efficiently. The default value is False. |
+| `xla_gpu_enable_triton_gemm` | Boolean (true/false) | Use Triton-based matrix multiplication. |
+| `xla_gpu_graph_level` | Flag (0-3) | The legacy flag for setting the GPU graph level. Use xla_gpu_enable_command_buffer in new use cases. 0 = off; 1 = capture fusions and memcpys; 2 = capture gemms; 3 = capture convolutions. |
+| `xla_gpu_all_reduce_combine_threshold_bytes` | Integer (bytes) | These flags tune when to combine multiple small AllGather / ReduceScatter / AllReduce into one big AllGather / ReduceScatter / AllReduce to reduce time spent on cross-device communication. For example, for the AllGather / ReduceScatter thresholds on a Transformer-based workload, consider tuning them high enough so as to combine at least a Transformer Layer's weight AllGather / ReduceScatter. By default, the combine_threshold_bytes is set to 256. |
+| `xla_gpu_all_gather_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
+| `xla_gpu_reduce_scatter_combine_threshold_bytes` | Integer (bytes) | See xla_gpu_all_reduce_combine_threshold_bytes above. |
+| `xla_gpu_enable_pipelined_all_gather` | Boolean (true/false) | Enable pipelining of all-gather instructions. |
+| `xla_gpu_enable_pipelined_reduce_scatter` | Boolean (true/false) | Enable pipelining of reduce-scatter instructions. |
+| `xla_gpu_enable_pipelined_all_reduce` | Boolean (true/false) | Enable pipelining of all-reduce instructions. |
+| `xla_gpu_enable_while_loop_double_buffering` | Boolean (true/false) | Enable double-buffering for while loops. |
+| `xla_gpu_enable_triton_softmax_fusion` | Boolean (true/false) | Use Triton-based Softmax fusion. |
+| `xla_gpu_enable_all_gather_combine_by_dim` | Boolean (true/false) | Combine all-gather ops with the same gather dimension, or irrespective of their dimension. |
+| `xla_gpu_enable_reduce_scatter_combine_by_dim` | Boolean (true/false) | Combine reduce-scatter ops with the same dimension, or irrespective of their dimension.
| + +**Additional reading:** +* [GPU performance tips](https://jax.readthedocs.io/en/latest/gpu_performance_tips.html#xla-performance-flags) diff --git a/examples/ffi/CMakeLists.txt b/examples/ffi/CMakeLists.txt new file mode 100644 index 000000000000..8d9b811374d1 --- /dev/null +++ b/examples/ffi/CMakeLists.txt @@ -0,0 +1,15 @@ +cmake_minimum_required(VERSION 3.15...3.30) +project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX) + +find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module) +execute_process( + COMMAND "${Python_EXECUTABLE}" + "-c" "from jax.extend import ffi; print(ffi.include_dir())" + OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR) +message(STATUS "XLA include directory: ${XLA_DIR}") + +find_package(nanobind CONFIG REQUIRED) + +nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc") +target_include_directories(_rms_norm PUBLIC ${XLA_DIR}) +install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME}) diff --git a/examples/ffi/README.md b/examples/ffi/README.md new file mode 100644 index 000000000000..cc7018782a25 --- /dev/null +++ b/examples/ffi/README.md @@ -0,0 +1,9 @@ +# End-to-end example usage for JAX's foreign function interface + +This directory includes an example project demonstrating the use of JAX's +foreign function interface (FFI). The JAX docs provide more information about +this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html), +but the example in this directory explicitly demonstrates: + +1. One way to package and distribute FFI targets, and +2. Some more advanced use cases. diff --git a/examples/ffi/pyproject.toml b/examples/ffi/pyproject.toml new file mode 100644 index 000000000000..130dd91bbc70 --- /dev/null +++ b/examples/ffi/pyproject.toml @@ -0,0 +1,12 @@ +[build-system] +requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"] +build-backend = "scikit_build_core.build" + +[project] +name = "jax_ffi_example" +version = "0.0.1" +requires-python = ">=3.10" +dependencies = ["jax"] + +[project.optional-dependencies] +test = ["pytest", "absl-py"] diff --git a/examples/ffi/src/jax_ffi_example/__init__.py b/examples/ffi/src/jax_ffi_example/__init__.py new file mode 100644 index 000000000000..862a661e24b9 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc new file mode 100644 index 000000000000..2fb8d96c8461 --- /dev/null +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -0,0 +1,157 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <cmath>
+#include <cstdint>
+#include <functional>
+#include <numeric>
+#include <type_traits>
+#include <utility>
+
+#include "nanobind/nanobind.h"
+#include "xla/ffi/api/c_api.h"
+#include "xla/ffi/api/ffi.h"
+
+namespace nb = nanobind;
+namespace ffi = xla::ffi;
+
+// This is the example "library function" that we want to expose to JAX. This
+// isn't meant to be a particularly good implementation, it's just here as a
+// placeholder for the purposes of this tutorial.
+float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
+  float sm = 0.0f;
+  for (int64_t n = 0; n < size; ++n) {
+    sm += x[n] * x[n];
+  }
+  float scale = 1.0f / std::sqrt(sm / float(size) + eps);
+  for (int64_t n = 0; n < size; ++n) {
+    y[n] = x[n] * scale;
+  }
+  return scale;
+}
+
+// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
+// In this example, we treat all leading dimensions as batch dimensions, so this
+// function returns the total number of elements in the buffer, and the size of
+// the last dimension.
+template <ffi::DataType T>
+std::pair<int64_t, size_t> GetDims(const ffi::Buffer<T> &buffer) {
+  auto dims = buffer.dimensions();
+  if (dims.size() == 0) {
+    return std::make_pair(0, 0);
+  }
+  return std::make_pair(buffer.element_count(), dims.back());
+}
+
+// A wrapper function providing the interface between the XLA FFI call and our
+// library function `ComputeRmsNorm` above. This function handles the batch
+// dimensions by calling `ComputeRmsNorm` within a loop.
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::Result<ffi::Buffer<ffi::F32>> y) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
+                      "RmsNorm input must be an array");
+  }
+  for (int64_t n = 0; n < totalSize; n += lastDim) {
+    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
+  }
+  return ffi::Error::Success();
+}
+
+// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
+// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
+// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
+                              ffi::Ffi::Bind()
+                                  .Attr<float>("eps")
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+);
+
+ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
+                          ffi::Result<ffi::Buffer<ffi::F32>> y,
+                          ffi::Result<ffi::Buffer<ffi::F32>> res) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
+                      "RmsNormFwd input must be an array");
+  }
+  for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
+    res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
+                                            &(y->typed_data()[n]));
+  }
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
+                              ffi::Ffi::Bind()
+                                  .Attr<float>("eps")
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // res
+);
+
+void ComputeRmsNormBwd(int64_t size, float res, const float *x,
+                       const float *ct_y, float *ct_x) {
+  float ct_res = 0.0f;
+  for (int64_t n = 0; n < size; ++n) {
+    ct_res += x[n] * ct_y[n];
+  }
+  float factor = ct_res * res * res * res / float(size);
+  for (int64_t n = 0; n < size; ++n) {
+    ct_x[n] = res * ct_y[n] - factor * x[n];
+  }
+}
+
+ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
+                          ffi::Buffer<ffi::F32> ct_y,
+                          ffi::Result<ffi::Buffer<ffi::F32>> ct_x) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
+                      "RmsNormBwd inputs must be arrays");
+  }
+  for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
+    ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
+                      &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
+  }
+  return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
+                              ffi::Ffi::Bind()
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // res
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // ct_y
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // ct_x
+);
+
+template <typename T>
+nb::capsule EncapsulateFfiHandler(T *fn) {
+  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
+                "Encapsulated function must be an XLA FFI handler");
+  return nb::capsule(reinterpret_cast<void *>(fn));
+}
+
+NB_MODULE(_rms_norm, m) {
+  m.def("registrations", []() {
+    nb::dict registrations;
+    registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm);
+    registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd);
+    registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd);
+    return registrations;
+  });
+}
diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py
new file mode 100644
index 000000000000..4e0ed1d195b4
--- /dev/null
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.py
@@ -0,0 +1,99 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""An example demonstrating the basic end-to-end use of the JAX FFI.

+This example is exactly the same as the one in the `FFI tutorial
+<https://jax.readthedocs.io/en/latest/ffi.html>`_, so more details can be found
+on that page. But the high-level summary is that we implement our custom
+extension in ``rms_norm.cc``, then call it using ``jax.extend.ffi.ffi_call`` in
+this module. The behavior under autodiff is implemented using
+``jax.custom_vjp``.
+""" + +from functools import partial + +import numpy as np + +import jax +import jax.extend as jex +import jax.numpy as jnp + +from jax_ffi_example import _rms_norm + +for name, target in _rms_norm.registrations().items(): + jex.ffi.register_ffi_target(name, target) + + +@partial(jax.custom_vjp, nondiff_argnums=(1,)) +def rms_norm(x, eps=1e-5): + # We only implemented the `float32` version of this function, so we start by + # checking the dtype. This check isn't strictly necessary because type + # checking is also performed by the FFI when decoding input and output + # buffers, but it can be useful to check types in Python to raise more + # informative errors. + if x.dtype != jnp.float32: + raise ValueError("Only the float32 dtype is implemented by rms_norm") + + # In this case, the output of our FFI function is just a single array with the + # same shape and dtype as the input. + out_type = jax.ShapeDtypeStruct(x.shape, x.dtype) + + return jex.ffi.ffi_call( + # The target name must be the same string as we used to register the target + # above in `register_ffi_target` + "rms_norm", + out_type, + x, + # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for + # the attribute `eps`. Our FFI function expects this to have the C++ `float` + # type (which corresponds to numpy's `float32` type), and it must be a + # static parameter (i.e. not a JAX array). + eps=np.float32(eps), + # The `vectorized` parameter controls this function's behavior under `vmap`. + vectorized=True, + ) + + +def rms_norm_fwd(x, eps=1e-5): + y, res = jex.ffi.ffi_call( + "rms_norm_fwd", + ( + jax.ShapeDtypeStruct(x.shape, x.dtype), + jax.ShapeDtypeStruct(x.shape[:-1], x.dtype), + ), + x, + eps=np.float32(eps), + vectorized=True, + ) + return y, (res, x) + + +def rms_norm_bwd(eps, res, ct): + del eps + res, x = res + assert res.shape == ct.shape[:-1] + assert x.shape == ct.shape + return ( + jex.ffi.ffi_call( + "rms_norm_bwd", + jax.ShapeDtypeStruct(ct.shape, ct.dtype), + res, + x, + ct, + vectorized=True, + ), + ) + + +rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd) diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py new file mode 100644 index 000000000000..aad5562629ed --- /dev/null +++ b/examples/ffi/tests/rms_norm_test.py @@ -0,0 +1,46 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest + +import jax +import jax.numpy as jnp +from jax._src import test_util as jtu + +from jax_ffi_example import rms_norm + +jax.config.parse_flags_with_absl() + + +def rms_norm_ref(x, eps=1e-5): + scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps) + return x / scale + + +class RmsNormTests(jtu.JaxTestCase): + def test_basic(self): + x = jnp.linspace(-0.5, 0.5, 15) + self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x)) + + def test_batching(self): + x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) + self.assertAllClose(jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)) + + def test_grads(self): + x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5)) + jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",)) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/examples/jax_cpp/BUILD b/examples/jax_cpp/BUILD index fccf0cc37048..6e4647b5e491 100644 --- a/examples/jax_cpp/BUILD +++ b/examples/jax_cpp/BUILD @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -package(default_applicable_licenses = ["//third_party/py/jax:license"]) +package(default_applicable_licenses = ["//jax:license"]) licenses(["notice"]) @@ -21,13 +21,13 @@ cc_binary( srcs = ["main.cc"], tags = ["manual"], deps = [ - "//third_party/absl/status:statusor", + "@com_google_absl//absl/status:statusor", + "@tsl//tsl/platform:logging", + "@tsl//tsl/platform:platform_port", "@xla//xla:literal", "@xla//xla:literal_util", "@xla//xla/pjrt:pjrt_client", "@xla//xla/pjrt/cpu:cpu_client", "@xla//xla/tools:hlo_module_loader", - "@tsl//tsl/platform:logging", - "@tsl//tsl/platform:platform_port", ], ) diff --git a/jax/BUILD b/jax/BUILD index 66df0d2f7272..c25d0004e772 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -57,20 +57,6 @@ config_setting( }, ) -# When `build_cuda_plugin_from_source` is true, it assumes running `bazel test` without preinstalled -# cuda plugin. -bool_flag( - name = "build_cuda_plugin_from_source", - build_setting_default = False, -) - -config_setting( - name = "enable_build_cuda_plugin_from_source", - flag_values = { - ":build_cuda_plugin_from_source": "True", - }, -) - exports_files([ "LICENSE", "version.py", @@ -90,39 +76,32 @@ package_group( packages = [ # Intentionally avoid jax dependencies on jax.extend. # See https://jax.readthedocs.io/en/latest/jep/15856-jex.html - "//third_party/py/jax/tests/...", + "//tests/...", ] + jax_extend_internal_users, ) package_group( name = "mosaic_users", - packages = [ - "//...", - ] + mosaic_internal_users, + includes = [":internal"], + packages = mosaic_internal_users, ) package_group( name = "pallas_gpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + pallas_gpu_internal_users, + includes = [":internal"], + packages = pallas_gpu_internal_users, ) package_group( name = "pallas_tpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + pallas_tpu_internal_users, + includes = [":internal"], + packages = pallas_tpu_internal_users, ) package_group( name = "mosaic_gpu_users", - packages = [ - "//...", - "//learning/brain/research/jax", - ] + mosaic_gpu_internal_users, + includes = [":internal"], + packages = mosaic_gpu_internal_users, ) # JAX-private test utilities. 
@@ -295,6 +274,7 @@ py_library_providing_imports_info( ":dtypes", ":effects", ":environment_info", + ":internal_mesh_utils", ":jaxpr_util", ":layout", ":lazy_loader", @@ -319,6 +299,7 @@ py_library_providing_imports_info( ":version", ":xla", ":xla_bridge", + ":xla_metadata", "//jax/_src/lib", ] + py_deps("numpy") + py_deps("scipy") + py_deps("opt_einsum") + py_deps("flatbuffers") + jax_extra_deps, ) @@ -471,6 +452,7 @@ pytype_strict_library( ":tree_util", ":typing", ":util", + ":xla_metadata", "//jax/_src/lib", ] + py_deps("numpy"), ) @@ -670,9 +652,9 @@ pytype_strict_library( ":pallas_gpu_users", ], deps = [ - ":pallas", "//jax/_src/pallas/mosaic_gpu:core", # build_cleaner: keep "//jax/_src/pallas/mosaic_gpu:pallas_call_registration", # build_cleaner: keep + "//jax/_src/pallas/triton:core", "//jax/_src/pallas/triton:pallas_call_registration", # build_cleaner: keep "//jax/_src/pallas/triton:primitives", ], @@ -723,6 +705,7 @@ pytype_strict_library( ":state_types", ":tree_util", ":util", + ":xla_metadata", ] + py_deps("numpy"), ) @@ -785,7 +768,15 @@ pytype_strict_library( pytype_strict_library( name = "compute_on", srcs = ["_src/compute_on.py"], - deps = [], + deps = [":config"], +) + +pytype_strict_library( + name = "xla_metadata", + srcs = ["_src/xla_metadata.py"], + deps = [ + ":config", + ], ) pytype_strict_library( @@ -805,6 +796,7 @@ pytype_strict_library( deps = [ ":config", ":core", + ":internal_mesh_utils", ":mesh", ":op_shardings", ":partition_spec", @@ -828,6 +820,14 @@ pytype_strict_library( ] + py_deps("numpy"), ) +pytype_library( + name = "internal_mesh_utils", + srcs = ["_src/mesh_utils.py"], + deps = [ + ":xla_bridge", + ], +) + pytype_strict_library( name = "source_info_util", srcs = ["_src/source_info_util.py"], @@ -848,6 +848,7 @@ pytype_strict_library( ], deps = [ ":core", + ":dtypes", ":effects", ":pretty_printer", ":tree_util", @@ -943,6 +944,7 @@ pytype_strict_library( "_src/clusters/__init__.py", "_src/clusters/cloud_tpu_cluster.py", "_src/clusters/cluster.py", + "_src/clusters/k8s_cluster.py", "_src/clusters/mpi4py_cluster.py", "_src/clusters/ompi_cluster.py", "_src/clusters/slurm_cluster.py", @@ -1085,7 +1087,7 @@ pytype_library( srcs = ["experimental/mesh_utils.py"], visibility = ["//visibility:public"], deps = [ - ":xla_bridge", + ":internal_mesh_utils", ], ) diff --git a/jax/__init__.py b/jax/__init__.py index d9c4de6bb617..c6e073699b0c 100644 --- a/jax/__init__.py +++ b/jax/__init__.py @@ -29,7 +29,7 @@ # Defensively swallow any exceptions to avoid making jax unimportable from warnings import warn as _warn _warn(f"cloud_tpu_init failed: {exc!r}\n This a JAX bug; please report " - f"an issue at https://github.com/google/jax/issues") + f"an issue at https://github.com/jax-ml/jax/issues") del _warn del _cloud_tpu_init @@ -38,7 +38,7 @@ del _core # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.basearray import Array as Array from jax import tree as tree @@ -56,6 +56,7 @@ debug_nans as debug_nans, debug_infs as debug_infs, log_compiles as log_compiles, + no_tracing as no_tracing, explain_cache_misses as explain_cache_misses, default_device as default_device, default_matmul_precision as default_matmul_precision, @@ -119,14 +120,15 @@ from jax._src.api import pmap as pmap from jax._src.xla_bridge import process_count as process_count from jax._src.xla_bridge import process_index as process_index +from jax._src.xla_bridge import process_indices as process_indices from jax._src.callback import pure_callback as pure_callback from jax._src.ad_checkpoint import checkpoint_wrapper as remat from jax._src.api import ShapeDtypeStruct as ShapeDtypeStruct from jax._src.api import value_and_grad as value_and_grad from jax._src.api import vjp as vjp from jax._src.api import vmap as vmap -from jax._src.api import xla_computation as _deprecated_xla_computation from jax._src.sharding_impls import NamedSharding as NamedSharding +from jax._src.sharding_impls import make_mesh as make_mesh # Force import, allowing jax.interpreters.* to be used after import jax. from jax.interpreters import ad, batching, mlir, partial_eval, pxla, xla @@ -179,11 +181,6 @@ import jax.experimental.compilation_cache.compilation_cache as _ccache del _ccache -from jax._src.deprecations import register as _register_deprecation -_register_deprecation('jax-scipy-beta-args') -_register_deprecation('tracer-hash') -del _register_deprecation - _deprecations = { # Added July 2022 "treedef_is_leaf": ( @@ -226,20 +223,18 @@ "jax.clear_backends is deprecated.", _deprecated_clear_backends ), - # Added Jun 16, 2024 + # Remove after jax 0.4.35 release. "xla_computation": ( - "jax.xla_computation is deprecated. Please use the AOT APIs; see " + "jax.xla_computation is deleted. Please use the AOT APIs; see " "https://jax.readthedocs.io/en/latest/aot.html. For example, replace " "xla_computation(f)(*xs) with jit(f).lower(*xs).compiler_ir('hlo'). 
See " - "CHANGELOG.md for 0.4.30 for more examples.", - _deprecated_xla_computation + "CHANGELOG.md for 0.4.30 for more examples.", None ), } import typing as _typing if _typing.TYPE_CHECKING: from jax._src.api import clear_backends as clear_backends - from jax._src.api import xla_computation as xla_computation from jax._src.tree_util import treedef_is_leaf as treedef_is_leaf from jax._src.tree_util import tree_flatten as tree_flatten from jax._src.tree_util import tree_leaves as tree_leaves diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 6bf481ef0496..bd7482eb50cf 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -27,7 +27,6 @@ from jax._src import api from jax._src import config from jax._src import core -from jax._src import dispatch from jax._src import dtypes from jax._src import linear_util as lu from jax._src import effects @@ -474,7 +473,7 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]: if v in res_vars: if eqn.primitive is name_p or v in named_vars and (eqn := named_vars[v]): results.append((v.aval, f"named '{eqn.params['name']}' from {src}")) - elif str(eqn.primitive) == 'xla_call': + elif str(eqn.primitive) == 'pjit': results.append((v.aval, f"output of jitted function '{eqn.params['name']}' " f"from {src}")) @@ -515,7 +514,7 @@ def remat_jvp(primals, tangents, jaxpr, prevent_cse, differentiated, policy): prevent_cse=prevent_cse, differentiated=differentiated, policy=policy) out_primals, out_tangents_ = split_list(outs, [len(jaxpr.outvars)]) out_tangents_ = iter(out_tangents_) - out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents ad.primitive_jvps[remat_p] = remat_jvp @@ -547,7 +546,7 @@ def remat_partial_eval(trace, *tracers, jaxpr, **params): # To avoid precision mismatches in fwd and bwd passes due to XLA excess # precision, insert explicit x = reduce_precision(x, **finfo(x.dtype)) calls - # on producers of any residuals. See https://github.com/google/jax/pull/22244. + # on producers of any residuals. See https://github.com/jax-ml/jax/pull/22244. jaxpr_known_ = _insert_reduce_precision(jaxpr_known, num_res) # compute known outputs and residuals (hoisted out of remat primitive) @@ -755,7 +754,7 @@ def remat_expansion(*args, jaxpr: core.Jaxpr, prevent_cse: bool, return api.named_call(translation_rule, name="checkpoint")(*args, jaxpr=jaxpr) def _remat_translation_using_opt_barrier(*args, jaxpr: core.Jaxpr): - args = _optimization_barrier(args) + args = lax_internal.optimization_barrier(args) return core.eval_jaxpr(jaxpr, (), *args) # TODO(mattjj): add core utility for 'create dummy value for this type'? 
@@ -837,27 +836,6 @@ def _remat_lowering(ctx, *args, jaxpr: core.Jaxpr, prevent_cse: bool, mlir.register_lowering(remat_p, partial(_remat_lowering, is_gpu_platform=True), platform="gpu") -def _optimization_barrier_abstract_eval(*args): - return args - -def _optimization_barrier_lowering_rule(ctx, *args): - barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in) - flat_args = mlir.flatten_ir_values(args) - barrier_op = hlo.OptimizationBarrierOp(flat_args) - return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types) - -def _optimization_barrier(arg): - flat_args, treedef = tree_flatten(arg) - return tree_unflatten(treedef, optimization_barrier_p.bind(*flat_args)) - -optimization_barrier_p = core.Primitive('optimization_barrier') -optimization_barrier_p.multiple_results = True -optimization_barrier_p.def_impl( - partial(dispatch.apply_primitive, optimization_barrier_p)) -optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) -mlir.register_lowering(optimization_barrier_p, - _optimization_barrier_lowering_rule) - def checkpoint_name(x, name): return name_p.bind(x, name=name) @@ -936,3 +914,6 @@ def checkpoint_wrapper( raise NotImplementedError(msg) return checkpoint(fun, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums) + +# TODO(phawkins): update users to refer to the public name. +_optimization_barrier = lax_internal.optimization_barrier diff --git a/jax/_src/ad_util.py b/jax/_src/ad_util.py index 90ae6c1413ec..bd1427f59e01 100644 --- a/jax/_src/ad_util.py +++ b/jax/_src/ad_util.py @@ -20,7 +20,7 @@ from jax._src import core from jax._src import traceback_util from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval -from jax._src.tree_util import register_pytree_node +from jax._src.tree_util import register_pytree_node, tree_map from jax._src.typing import Array, ArrayLike from jax._src.util import safe_map @@ -31,7 +31,6 @@ map = safe_map def add_jaxvals(x: ArrayLike, y: ArrayLike) -> Array: - dtype = core.get_aval(x).dtype return add_jaxvals_p.bind(x, y) add_jaxvals_p = Primitive('add_any') @@ -66,8 +65,8 @@ def __init__(self, aval: core.AbstractValue): def __repr__(self) -> str: return f'Zero({self.aval})' @staticmethod - def from_value(val: Any) -> Zero: - return Zero(raise_to_shaped(get_aval(val))) + def from_primal_value(val: Any) -> Zero: + return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval()) register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval)) @@ -83,6 +82,7 @@ def _stop_gradient_impl(x: T) -> T: stop_gradient_p.def_abstract_eval(lambda x: x) +# User-facing version of `Zero` class SymbolicZero: def __init__(self, aval: core.AbstractValue) -> None: self.aval = aval @@ -109,6 +109,19 @@ def __getattr__(self, name): else: return attr + @staticmethod + def from_primal_value(val: Any) -> SymbolicZero: + return SymbolicZero(get_aval(val).to_tangent_aval()) + +def zero_from_primal(val, symbolic_zeros=False): + def f(x): + tangent_aval = get_aval(x).to_tangent_aval() + if symbolic_zeros: + return SymbolicZero(tangent_aval) + else: + return zeros_like_aval(tangent_aval) + return tree_map(f, val) + JaxTypeOrTracer = Any def replace_internal_symbolic_zeros( diff --git a/jax/_src/api.py b/jax/_src/api.py index 493f48a88624..6d2fc4143066 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -46,9 +46,9 @@ from jax._src import config from jax._src import core from jax._src import dispatch -from jax._src import effects from jax._src import array from jax._src import basearray 
+from jax._src import distributed from jax._src import dtypes from jax._src import sharding_impls from jax._src import sharding_specs @@ -59,7 +59,7 @@ from jax._src.core import eval_jaxpr, ShapedArray, ConcreteArray from jax._src.api_util import ( flatten_fun, flatten_fun_nokwargs, flatten_fun_nokwargs2, argnums_partial, - argnums_partial_except, flatten_axes, donation_vector, + flatten_axes, donation_vector, rebase_donate_argnums, _ensure_index, _ensure_index_tuple, shaped_abstractify, apply_flat_fun_nokwargs, check_callable, debug_info, result_paths, flat_out_axes, debug_info_final, fun_sourceinfo) @@ -72,13 +72,11 @@ from jax._src.layout import Layout, AutoLayout from jax._src.traceback_util import api_boundary from jax._src import tree_util -from jax._src.util import (unzip2, safe_map, safe_zip, wrap_name, wraps, - split_list) +from jax._src.util import unzip2, safe_map, safe_zip, wraps, split_list from jax._src import util from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla from jax._src.interpreters import xla @@ -336,244 +334,6 @@ def disable_jit(disable: bool = True): yield -def xla_computation(fun: Callable, - static_argnums: int | Iterable[int] = (), - axis_env: Sequence[tuple[AxisName, int]] | None = None, - in_parts=None, out_parts=None, - backend: str | None = None, - tuple_args: bool = False, - instantiate_const_outputs: bool | None = None, - return_shape: bool = False, - donate_argnums: int | Iterable[int] = ()) -> Callable: - """Creates a function that produces its XLA computation given example args. - - .. warning:: - - This function is deprecated as of JAX v0.4.30, and will be removed in a future - JAX release. You can replace it with :ref:`ahead-of-time-lowering` APIs; for - example, ``jax.xla_computation(fn)(*args)`` can be replaced with - ``jax.jit(fn).lower(*args).compiler_ir('hlo')``. - See the `JAX 0.4.30 Change log`_ for more examples. - - Args: - fun: Function from which to form XLA computations. - static_argnums: See the :py:func:`jax.jit` docstring. - axis_env: Optional, a sequence of pairs where the first element is an axis - name and the second element is a positive integer representing the size of - the mapped axis with that name. This parameter is useful when lowering - functions that involve parallel communication collectives, and it - specifies the axis name/size environment that would be set up by - applications of :py:func:`jax.pmap`. See the examples below. - in_parts: Optional, how each argument to ``fun`` should be partitioned or - replicated. This is used to specify partitioned XLA computations, see - ``sharded_jit`` for more info. - out_parts: Optional, how each output of ``fun`` should be partitioned or - replicated. This is used to specify partitioned XLA computations, see - ``sharded_jit`` for more info. - backend: This is an experimental feature and the API is likely to change. - Optional, a string representing the XLA backend: ``'cpu'``, ``'gpu'``, or - ``'tpu'``. - tuple_args: Optional bool, defaults to ``False``. If ``True``, the resulting - XLA computation will have a single tuple argument that is unpacked into - the specified function arguments. If `None`, tupling will be enabled when - there are more than 100 arguments, since some platforms have limits on - argument arity. - instantiate_const_outputs: Deprecated argument, does nothing. 
- return_shape: Optional boolean, defaults to ``False``. If ``True``, the - wrapped function returns a pair where the first element is the XLA - computation and the second element is a pytree with the same structure as - the output of ``fun`` and where the leaves are objects with ``shape`` and - ``dtype`` attributes representing the corresponding types of the output - leaves. - donate_argnums: Specify which arguments are "donated" to the computation. - It is safe to donate arguments if you no longer need them once the - computation has finished. In some cases XLA can make use of donated - buffers to reduce the amount of memory needed to perform a computation, - for example recycling one of your input buffers to store a result. You - should not reuse buffers that you donate to a computation, JAX will raise - an error if you try to. - - Returns: - A wrapped version of ``fun`` that when applied to example arguments returns - a built XLA Computation (see xla_client.py), from which representations of - the unoptimized XLA HLO computation can be extracted using methods like - ``as_hlo_text``, ``as_serialized_hlo_module_proto``, and - ``as_hlo_dot_graph``. If the argument ``return_shape`` is ``True``, then the - wrapped function returns a pair where the first element is the XLA - Computation and the second element is a pytree representing the structure, - shapes, dtypes, and named shapes of the output of ``fun``. - - Concrete example arguments are not always necessary. For those arguments not - indicated by ``static_argnums``, any object with ``shape`` and ``dtype`` - attributes is acceptable (excepting namedtuples, which are treated as Python - containers). - - For example: - - >>> import jax - >>> - >>> def f(x): return jax.numpy.sin(jax.numpy.cos(x)) - >>> c = jax.xla_computation(f)(3.) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule xla_computation_f.6 - - ENTRY xla_computation_f.6 { - constant.2 = pred[] constant(false) - parameter.1 = f32[] parameter(0) - cosine.3 = f32[] cosine(parameter.1) - sine.4 = f32[] sine(cosine.3) - ROOT tuple.5 = (f32[]) tuple(sine.4) - } - - - - - Alternatively, the assignment to ``c`` above could be written: - - >>> import types - >>> scalar = types.SimpleNamespace(shape=(), dtype=np.dtype(np.float32)) - >>> c = jax.xla_computation(f)(scalar) # doctest: +SKIP - - - Here's an example that involves a parallel collective and axis name: - - >>> def f(x): return x - jax.lax.psum(x, 'i') - >>> c = jax.xla_computation(f, axis_env=[('i', 4)])(2) # doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule jaxpr_computation.9 - primitive_computation.3 { - parameter.4 = s32[] parameter(0) - parameter.5 = s32[] parameter(1) - ROOT add.6 = s32[] add(parameter.4, parameter.5) - } - ENTRY jaxpr_computation.9 { - tuple.1 = () tuple() - parameter.2 = s32[] parameter(0) - all-reduce.7 = s32[] all-reduce(parameter.2), replica_groups={{0,1,2,3}}, to_apply=primitive_computation.3 - ROOT subtract.8 = s32[] subtract(parameter.2, all-reduce.7) - } - - - - Notice the ``replica_groups`` that were generated. Here's an example that - generates more interesting ``replica_groups``: - - >>> from jax import lax - >>> def g(x): - ... rowsum = lax.psum(x, 'i') - ... colsum = lax.psum(x, 'j') - ... allsum = lax.psum(x, ('i', 'j')) - ... return rowsum, colsum, allsum - ... - >>> axis_env = [('i', 4), ('j', 2)] - >>> c = jax.xla_computation(g, axis_env=axis_env)(5.) 
# doctest: +SKIP - >>> print(c.as_hlo_text()) # doctest: +SKIP - HloModule jaxpr_computation__1.19 - [removed uninteresting text here] - ENTRY jaxpr_computation__1.19 { - tuple.1 = () tuple() - parameter.2 = f32[] parameter(0) - all-reduce.7 = f32[] all-reduce(parameter.2), replica_groups={{0,2,4,6},{1,3,5,7}}, to_apply=primitive_computation__1.3 - all-reduce.12 = f32[] all-reduce(parameter.2), replica_groups={{0,1},{2,3},{4,5},{6,7}}, to_apply=primitive_computation__1.8 - all-reduce.17 = f32[] all-reduce(parameter.2), replica_groups={{0,1,2,3,4,5,6,7}}, to_apply=primitive_computation__1.13 - ROOT tuple.18 = (f32[], f32[], f32[]) tuple(all-reduce.7, all-reduce.12, all-reduce.17) - } - - .. _JAX 0.4.30 Change log: https://jax.readthedocs.io/en/latest/changelog.html#jax-0-4-30-june-18-2024 - """ - if instantiate_const_outputs is not None: - raise ValueError( - "instantiate_const_outputs has been deprecated. Please use the ahead of" - " time APIs. You can read more here:" - " https://jax.readthedocs.io/en/latest/aot.html") - if in_parts is not None: - raise ValueError( - "in_parts has been deprecated. Please use the ahead of time APIs. You" - " can read more here: https://jax.readthedocs.io/en/latest/aot.html") - if out_parts is not None: - raise ValueError( - "out_parts has been deprecated. Please use the ahead of time APIs. You" - " can read more here: https://jax.readthedocs.io/en/latest/aot.html") - - check_callable(fun) - static_argnums = _ensure_index_tuple(static_argnums) - donate_argnums = _ensure_index_tuple(donate_argnums) - donate_argnums = rebase_donate_argnums(donate_argnums, static_argnums) - - fun_name = getattr(fun, "__name__", "unknown") - - platform = backend if backend is not None else xb.get_backend().platform - - def make_axis_env(nreps): - if axis_env is None: - return sharding_impls.AxisEnv(nreps, (), ()) - else: - nreps = nreps * math.prod(size for name, size in axis_env) - names, sizes = unzip2(axis_env) - return sharding_impls.AxisEnv(nreps, names, sizes) - - @wraps(fun) - @api_boundary - def computation_maker(*args, **kwargs): - if max(static_argnums + donate_argnums, default=-1) >= len(args): - raise ValueError(f"jitted function has {static_argnums=}, {donate_argnums=} but " - f"was called with only {len(args)} positional arguments.") - - f = lu.wrap_init(fun) - f, dyn_args = argnums_partial_except(f, static_argnums, args, allow_invalid=False) - args_flat, in_tree = tree_flatten((dyn_args, kwargs)) - if donate_argnums: - donated_invars = donation_vector(donate_argnums, (), in_tree) - else: - donated_invars = (False,) * len(args_flat) - - jaxtree_fun, out_tree = flatten_fun(f, in_tree) - avals = map(shaped_abstractify, args_flat) - with ExitStack() as stack: - for axis_name, size in axis_env or []: - stack.enter_context(core.extend_axis_env(axis_name, size, None)) - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(jaxtree_fun, avals) - jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) - if axis_env: - jaxpr = core.remove_named_axis_effects( - jaxpr, {axis_name for axis_name, _ in axis_env} - ) - axis_env_ = make_axis_env(dispatch.jaxpr_replicas(jaxpr)) - ordered_effects = list( - effects.ordered_effects.filter_in(jaxpr.effects)) - lowering_result = mlir.lower_jaxpr_to_module( - f"xla_computation_{fun_name}", - core.ClosedJaxpr(jaxpr, consts), - ordered_effects=ordered_effects, - backend_or_name=backend, - platforms=[platform], - axis_context=sharding_impls.ReplicaAxisContext(axis_env_), - name_stack=source_info_util.new_name_stack( - wrap_name(fun_name, 
"xla_computation")), - donated_args=donated_invars, - arg_shardings=None, - result_shardings=None, - lowering_parameters=mlir.LoweringParameters()) - - m = mlir.module_to_bytecode(lowering_result.module) - built = xc._xla.mlir.mlir_module_to_xla_computation( - m, use_tuple_args=tuple_args, return_tuple=True) - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] - out_shapes_flat = [ShapeDtypeStruct(a.shape, a.dtype) for a in out_avals] - out_shape = tree_unflatten(out_tree(), out_shapes_flat) - for out_aval in out_avals: - if not isinstance(out_aval, ShapedArray): - raise RuntimeError("As we want to propagate the weak_type, we need " - "to get a ShapedArray, otherwise this " - "information is lost") - - if return_shape: - return built, out_shape - else: - return built - - return computation_maker - def grad(fun: Callable, argnums: int | Sequence[int] = 0, has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False, @@ -920,7 +680,13 @@ def jacfun(*args, **kwargs): return jac_tree, aux return jacfun -jacobian = jacrev + + +def jacobian(fun: Callable, argnums: int | Sequence[int] = 0, + has_aux: bool = False, holomorphic: bool = False, allow_int: bool = False) -> Callable: + """Alias of :func:`jax.jacrev`.""" + return jacrev(fun, argnums=argnums, has_aux=has_aux, holomorphic=holomorphic, allow_int=allow_int) + _check_input_dtype_jacrev = partial(_check_input_dtype_revderiv, "jacrev") _check_output_dtype_jacrev = partial(_check_output_dtype_revderiv, "jacrev") @@ -1196,7 +962,7 @@ def vmap(fun: F, # list: if in_axes is not a leaf, it must be a tuple of trees. However, # in cases like these users expect tuples and lists to be treated # essentially interchangeably, so we canonicalize lists to tuples here - # rather than raising an error. https://github.com/google/jax/issues/2367 + # rather than raising an error. 
https://github.com/jax-ml/jax/issues/2367 in_axes = tuple(in_axes) if not (in_axes is None or type(in_axes) in {int, tuple, *batching.spec_types}): @@ -1827,7 +1593,7 @@ def cache_miss(*args, **kwargs): cpp_mapped_f = pmap_lib.pmap( fun, cache_miss, static_broadcasted_tuple, - lambda x, s: pxla.shard_args([s], [x])[0], + lambda x, s: pxla.shard_args([s], [None], [x])[0], pytree_registry=tree_util.default_registry) _pmap_cache_clears.add(cpp_mapped_f) @@ -2066,7 +1832,7 @@ def _lift_linearized(jaxpr, primal_avals, io_tree, out_pvals, consts, *py_args): def fun(*tangents): tangent_avals = list(map(core.get_aval, tangents)) for primal_aval, tangent_aval in zip(primal_avals, tangent_avals): - if not core.typecompat(primal_aval.at_least_vspace(), tangent_aval): + if not core.typecompat(primal_aval.to_tangent_aval(), tangent_aval): raise ValueError("linearized function called on tangent values inconsistent with " "the original primal values: " f"got {tangent_aval} for primal aval {primal_aval}") @@ -2109,7 +1875,7 @@ def _vjp_pullback_wrapper(name, out_primal_avals, io_tree, fun, *py_args_): f"got {in_tree}, but expected to match {in_tree_expected}") for arg, aval in zip(args, out_primal_avals): ct_aval = shaped_abstractify(arg) - ct_aval_expected = aval.at_least_vspace() + ct_aval_expected = aval.to_tangent_aval() if (not core.typecompat(ct_aval, ct_aval_expected) and not _temporary_dtype_exception(ct_aval, ct_aval_expected)): raise ValueError( @@ -2693,11 +2459,9 @@ class ShapeDtypeStruct: dtype: a dtype-like object sharding: (optional) a :class:`jax.Sharding` object """ - __slots__ = ["shape", "dtype", "sharding", "_dll"] - named_shape = {} # type: ignore + __slots__ = ["shape", "dtype", "sharding", "_dll", "weak_type"] - def __init__(self, shape, dtype, named_shape=None, sharding=None): - del named_shape # ignored, vestigial + def __init__(self, shape, dtype, *, sharding=None, weak_type=False): self.shape = tuple(shape) if dtype is None: raise ValueError("ShapeDtypeStruct: dtype must be specified.") @@ -2714,6 +2478,7 @@ def __init__(self, shape, dtype, named_shape=None, sharding=None): f" layout in a `ShapeDtypeStruct`. 
Got {sharding}") self.sharding = sharding.sharding if isinstance(sharding, Layout) else sharding self._dll = sharding.device_local_layout if isinstance(sharding, Layout) else None + self.weak_type = weak_type size = property(lambda self: math.prod(self.shape)) ndim = property(lambda self: len(self.shape)) @@ -2731,8 +2496,9 @@ def __len__(self): def __repr__(self): sh = f", sharding={self.sharding}" if self.sharding is not None else "" l = f", layout={self.layout}" if self._dll is not None else "" + wt = f", weak_type={self.weak_type}" if self.weak_type else "" return (f"{type(self).__name__}(shape={self.shape}, " - f"dtype={self.dtype.name}{sh}{l})") + f"dtype={self.dtype.name}{sh}{l}{wt})") __str__ = __repr__ @@ -2740,17 +2506,19 @@ def __eq__(self, other): if not isinstance(other, ShapeDtypeStruct): return False else: - return ((other.shape, other.dtype, other.sharding, other.layout) == - (self.shape, self.dtype, self.sharding, self.layout)) + return ((self.shape, self.dtype, self.sharding, self.layout, self.weak_type) == + (other.shape, other.dtype, other.sharding, other.layout, other.weak_type)) def __hash__(self): # TODO(frostig): avoid the conversion from dict by addressing - # https://github.com/google/jax/issues/8182 - return hash((self.shape, self.dtype, self.sharding, self.layout)) + # https://github.com/jax-ml/jax/issues/8182 + return hash((self.shape, self.dtype, self.sharding, self.layout, self.weak_type)) -core.pytype_aval_mappings[ShapeDtypeStruct] = ( - lambda x: ShapedArray(x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), - weak_type=False)) +def _sds_aval_mapping(x): + return ShapedArray( + x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True), + weak_type=x.weak_type) +core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping @api_boundary @@ -2952,7 +2720,6 @@ def try_to_block(x): return x - def clear_backends(): """ Clear all backend clients so that new backend clients can be created later. @@ -2965,14 +2732,16 @@ def clear_backends(): pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() @atexit.register def clean_up(): - db = xb._default_backend - if db is not None and db.platform == "cpu": # pytype: disable=attribute-error + if xb._default_backend is not None: clear_backends() + # Shut down distributed system if it exists. Otherwise, this is a no-op. + distributed.shutdown() def live_arrays(platform=None): """Return all live arrays in the backend for `platform`. @@ -2993,7 +2762,8 @@ def clear_caches(): util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index dd1cdcbe6bb8..329abd6b7570 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -556,6 +556,26 @@ def _assert_no_intersection(static_argnames, donate_argnames): f"{out} appear in both static_argnames and donate_argnames") +def resolve_kwargs(fun: Callable, args, kwargs) -> tuple[Any, ...]: + """Resolve input arguments to positional following a function's signature. 
+ + This will raise a TypeError if any keyword-only arguments were passed by the + caller. + """ + if isinstance(fun, partial): + # functools.partial should have an opaque signature. + fun = lambda *args, **kwargs: None + ba = inspect.signature(fun).bind(*args, **kwargs) + ba.apply_defaults() + if ba.kwargs: + passed_kwargs = [k for k in ba.kwargs if k in kwargs] + if passed_kwargs: + raise TypeError( + f"keyword arguments ({passed_kwargs}) could not be resolved to " + "positions") + return ba.args + + def _dtype(x): try: return dtypes.result_type(x) diff --git a/jax/_src/array.py b/jax/_src/array.py index 03b0e49d3201..83be3d418c50 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -28,6 +28,7 @@ from jax._src import basearray from jax._src import config from jax._src import core +from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import errors @@ -41,7 +42,7 @@ from jax._src.lib import xla_extension as xe from jax._src.sharding import Sharding from jax._src.sharding_impls import ( - PmapSharding, SingleDeviceSharding, + PmapSharding, SingleDeviceSharding, NamedSharding, device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable from jax._src.typing import ArrayLike, DLDeviceType from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache @@ -115,6 +116,16 @@ def _reconstruct_array(fun, args, arr_state, aval_state): np_value = fun(*args) np_value.__setstate__(arr_state) jnp_value = api.device_put(np_value) + # TODO(slebedev): Remove this branch after December 10th 2024. + if "named_shape" in aval_state: + deprecations.warn( + "jax-aval-named-shape", + "Pickled array contains an aval with a named_shape attribute. This is" + " deprecated and the code path supporting such avals will be removed." + " Please re-pickle the array.", + stacklevel=2, + ) + del aval_state["named_shape"] jnp_value.aval = jnp_value.aval.update(**aval_state) return jnp_value @@ -339,18 +350,18 @@ def __getitem__(self, idx): except ValueError: arr_idx = None if arr_idx is not None: - a = self._arrays[arr_idx] - out = ArrayImpl( - a.aval, SingleDeviceSharding(_get_device(a)), [a], committed=False, - _skip_checks=True) + out = self._arrays[arr_idx] + sharding = SingleDeviceSharding(_get_device(out)) if config.pmap_no_rank_reduction.value: # If cidx was the index of a single shard, then it corresponds to one # shard of the chunked dimension. dims = tuple(i for i, x in enumerate(cidx) if isinstance(x, int)) - return lax.squeeze(out, dimensions=dims) - else: - return out + # Squeeze on committed arrays to avoid data movement to shard 0. 
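A rough illustration of the resolve_kwargs helper added to api_util.py above. The import path is internal and the toy functions are only for illustration:

from jax._src.api_util import resolve_kwargs

def f(x, y, z=3):
  return x + y + z

resolve_kwargs(f, (1,), {"y": 2})     # (1, 2, 3): kwargs folded into positions, defaults applied
resolve_kwargs(f, (1, 2), {"z": 10})  # (1, 2, 10)

def g(x, *, flag=False):
  return x

# Keyword-only arguments cannot be moved to positions, so this raises TypeError:
# resolve_kwargs(g, (1,), {"flag": True})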
+ out = lax.squeeze(out, dimensions=dims) + + return ArrayImpl( + out.aval, sharding, [out], committed=False, _skip_checks=True) return lax_numpy._rewriting_take(self, idx) @@ -489,7 +500,7 @@ def on_device_size_in_bytes(self): """Returns the total global on-device size of the array in bytes.""" arr = self._arrays[0] per_shard_size = arr.on_device_size_in_bytes() - return per_shard_size * len(self.sharding.device_set) + return per_shard_size * self.sharding.num_devices def devices(self) -> set[Device]: self._check_if_deleted() @@ -751,8 +762,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray: and sharding.is_fully_replicated and first_value.is_fully_replicated and first_value.sharding._device_assignment == tuple(devices) - and (first_value.layout.device_local_layout == - pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))): + and first_value.layout.device_local_layout == dll): return first_value if dtypes.issubdtype(aval.dtype, dtypes.extended): @@ -882,17 +892,20 @@ def make_array_from_process_local_data( setting it to (4, 4) in this case. Args: - sharding: sharding of the global tensor. - host_local_data: data on the host to be placed on local devices. Each + sharding: Sharding of the global array. + local_data: Data on the host to be placed on local devices. Each dimension should either match global_shape, or match num_addressable_indices(dim). - global_shape: the target shape of the global tensor. If None, - will infer from host_local_data and sharding. + global_shape: The target shape of the global array. If None, + will infer from local_data and sharding. Returns: Tensor that will have sharding=sharding and of shape global_shape. """ # pyformat: enable + if xla_bridge.process_count() == 1: + return api.device_put(local_data, sharding) + # TODO(sandler): consider supporting partially specified global_shape or # making local_to_global_shape available in the api. local_shape = local_data.shape @@ -1013,7 +1026,13 @@ def make_array_from_single_device_arrays( core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity -api_util._shaped_abstractify_handlers[ArrayImpl] = op.attrgetter('aval') +def _get_aval_array(self): + if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): + return self.aval.update(sharding=NamedSharding( + self.sharding.mesh.abstract_mesh, self.sharding.spec)) + else: + return self.aval +api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. basearray.Array.register(ArrayImpl) @@ -1086,9 +1105,8 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): # Look up all buffers that contain the correct slice of the logical array. candidates_list = candidates[hashed_index(idx)] if not candidates_list: - # This array isn't sharded correctly. Reshard it via host roundtrip. - # TODO(skye): more efficient reshard? - return pxla.shard_args([sharding], [x._value], canonicalize=False)[0] + return pxla.shard_args([sharding], [None], [x._value], + canonicalize=False)[0] # Try to find a candidate buffer already on the correct device, # otherwise copy one of them. 
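The single-process early return added above makes jax.make_array_from_process_local_data behave like a plain device_put when only one process is present. A minimal sketch, assuming the leading dimension is divisible by the number of local devices:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())
mesh = Mesh(devices, ("x",))
sharding = NamedSharding(mesh, P("x"))

local = np.arange(len(devices) * 4 * 3, dtype=np.float32).reshape(len(devices) * 4, 3)
arr = jax.make_array_from_process_local_data(sharding, local)
print(arr.shape, arr.sharding)  # same as local.shape in the single-process case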
for buf in candidates_list: @@ -1097,7 +1115,6 @@ def shard_sharded_device_array_slow_path(x, devices, indices, sharding): break else: bufs.append(buf) - return pxla.batched_device_put(x.aval, sharding, bufs, devices) @@ -1108,23 +1125,25 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding): return dst_indices, tuple(src_indices) == tuple(dst_indices) -def _array_shard_arg(xs, shardings): +def _array_shard_arg(xs, shardings, layouts): results = [] batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], [] - for i, (x, sharding) in enumerate(safe_zip(xs, shardings)): + + for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)): x._check_if_deleted() + indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding) + same_layout = (True if layout is None else + x.layout.device_local_layout == layout) - indices, same_indices = _sharding_indices_and_eq( - x.sharding, x.shape, sharding) if not x.is_fully_addressable: - if same_indices: + if same_indices and same_layout: results.append(x) else: raise NotImplementedError( "Cannot reshard an input that is not fully addressable") else: devices = sharding._addressable_device_assignment - if same_indices: + if same_indices and same_layout: # Add a placeholder result that will be filled in later. results.append(None) # Accumulate arguments to `batched_copy_array_to_devices_with_sharding`. @@ -1133,6 +1152,8 @@ def _array_shard_arg(xs, shardings): batch_shardings.append(sharding) batch_indices.append(i) # Resharding starts here: + elif not same_layout: + results.append(api.device_put(x, Layout(layout, sharding))) elif dispatch.is_single_device_sharding(x.sharding): results.append(shard_device_array(x, devices, indices, sharding)) else: @@ -1145,8 +1166,6 @@ def _array_shard_arg(xs, shardings): assert results[i] is None results[i] = copy_out return results - - pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg @@ -1178,8 +1197,8 @@ def _array_local_result_handler(aval, sharding, indices): # Token handlers -def _token_shard_arg(xs, shardings): - return _array_shard_arg([x._buf for x in xs], shardings) +def _token_shard_arg(xs, shardings, layouts): + return _array_shard_arg([x._buf for x in xs], shardings, layouts) pxla.shard_arg_handlers[core.Token] = _token_shard_arg diff --git a/jax/_src/basearray.py b/jax/_src/basearray.py index 4848d83d5315..c0b4f9f51c8b 100644 --- a/jax/_src/basearray.py +++ b/jax/_src/basearray.py @@ -124,7 +124,19 @@ def device(self) -> Device | Sharding: @abc.abstractmethod def copy_to_host_async(self): - """Copies jax.Array to host asynchronously.""" + """Copies an ``Array`` to the host asynchronously. + + For arrays that live on an accelerator, such as a GPU or a TPU, JAX may + cache the value of the array on the host. Normally this happens + behind the scenes when the value of an on-device array is requested by the + user, but waiting to initiate a device-to-host copy until the value is + requested requires that JAX block the caller while waiting for the copy to + complete. + + ``copy_to_host_async`` requests that JAX populate its on-host cache of an + array, but does not wait for the copy to complete. This may speed up a + future on-host access to the array's contents.
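A short usage sketch for the behaviour documented above; on CPU backends the call may effectively be a no-op:

import numpy as np
import jax
import jax.numpy as jnp

x = jax.device_put(jnp.ones((4096, 4096), jnp.float32))
x.copy_to_host_async()       # start the device-to-host transfer without blocking
y = jnp.sum(x * 2.0)         # overlap other device work with the transfer
host_copy = np.asarray(x)    # likely served from the now-populated host cache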
+ """ Array.__module__ = "jax" diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index 32e1d27dcdf5..23389b392414 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -14,17 +14,29 @@ import abc from collections.abc import Callable, Sequence from types import ModuleType -from typing import Any, Union +from typing import Any, Protocol, Union, runtime_checkable import numpy as np from jax._src.sharding import Sharding +# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py. +# We redefine these here to prevent circular imports. +@runtime_checkable +class SupportsDType(Protocol): + @property + def dtype(self) -> np.dtype: ... +DTypeLike = Union[str, type[Any], np.dtype, SupportsDType] + +Axis = Union[int, Sequence[int], None] Shard = Any # TODO: alias this to xla_client.Traceback Device = Any Traceback = Any +# TODO(jakevdp): fix import cycles and import this from jax._src.lax. +PrecisionLike = Any + class Array(abc.ABC): aval: Any @@ -117,73 +129,92 @@ class Array(abc.ABC): def __release_buffer__(self, view: memoryview) -> None: ... # np.ndarray methods: - def all(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... - def any(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, *, where: ArrayLike | None = ...) -> Array: ... - def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... - def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ... - def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ... - def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... - def astype(self, dtype) -> Array: ... - def choose(self, choices, out=None, mode='raise') -> Array: ... - def clip(self, min=None, max=None, out=None) -> Array: ... - def compress(self, condition, axis: int | None = None, out=None) -> Array: ... + def all(self, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... + def any(self, axis: Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ... + def argmax(self, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: ... + def argmin(self, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: ... + def argpartition(self, kth: int, axis: int = -1) -> Array: ... + def argsort(self, axis: int | None = -1, *, kind: None = None, order: None = None, + stable: bool = True, descending: bool = False) -> Array: ... + def astype(self, dtype: DTypeLike | None = None, copy: bool = False, + device: Device | Sharding | None = None) -> Array: ... + def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ... + def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ... + def compress(self, condition: ArrayLike, + axis: int | None = None, *, out: None = None, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: ... def conj(self) -> Array: ... def conjugate(self) -> Array: ... def copy(self) -> Array: ... - def cumprod(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None) -> Array: ... - def cumsum(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None) -> Array: ... - def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ... - def dot(self, b, *, precision=None) -> Array: ... 
- def flatten(self) -> Array: ... + def cumprod(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def cumsum(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ... + def dot(self, b: ArrayLike, *, precision: PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: ... + def flatten(self, order: str = "C") -> Array: ... @property def imag(self) -> Array: ... - def item(self, *args) -> Any: ... - def max(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, initial=None, where=None) -> Array: ... - def mean(self, axis: int | Sequence[int] | None = None, dtype=None, - out=None, keepdims=False, *, where=None,) -> Array: ... - def min(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=None, initial=None, where=None) -> Array: ... + def item(self, *args: int) -> Any: ... + def max(self, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def mean(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, *, + where: ArrayLike | None = None) -> Array: ... + def min(self, axis: Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... @property def nbytes(self) -> int: ... - def nonzero(self, *, size=None, fill_value=None) -> Array: ... - def prod(self, axis: int | Sequence[int] | None = None, dtype=None, - out=None, keepdims=None, initial=None, where=None) -> Array: ... - def ptp(self, axis: int | Sequence[int] | None = None, out=None, - keepdims=False,) -> Array: ... - def ravel(self, order='C') -> Array: ... + def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, + size: int | None = None) -> tuple[Array, ...]: ... + def prod(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None, + promote_integers: bool = True) -> Array: ... + def ptp(self, axis: Axis = None, out: None = None, + keepdims: bool = False) -> Array: ... + def ravel(self, order: str = 'C') -> Array: ... @property def real(self) -> Array: ... - def repeat(self, repeats, axis: int | None = None, *, - total_repeat_length=None) -> Array: ... - def reshape(self, *args, order='C') -> Array: ... - def round(self, decimals=0, out=None) -> Array: ... - def searchsorted(self, v, side='left', sorter=None) -> Array: ... - def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ... - def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ... - def std(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... - def sum(self, axis: int | Sequence[int] | None = None, dtype=None, - out=None, keepdims=None, initial=None, where=None) -> Array: ... + def repeat(self, repeats: ArrayLike, axis: int | None = None, *, + total_repeat_length: int | None = None) -> Array: ... + def reshape(self, *args: Any, order: str = "C") -> Array: ... + def round(self, decimals: int = 0, out: None = None) -> Array: ... + def searchsorted(self, v: ArrayLike, side: str = 'left', + sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ... 
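The tightened stubs above spell out keyword-only options that the runtime array methods already accept; a few representative calls with illustrative values:

import jax.numpy as jnp

x = jnp.array([3.0, 1.0, 2.0])
x.argsort(descending=True)        # stable, descending argsort
x.astype(jnp.float16)             # dtype cast; a `device=` argument is also typed now
x.nonzero(size=2, fill_value=0)   # static-size nonzero, usable under jit

xs = jnp.array([1.0, 2.0, 3.0])
xs.searchsorted(jnp.array([1.5]), method="scan")  # index 1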
+ def sort(self, axis: int | None = -1, *, kind: None = None, + order: None = None, stable: bool = True, descending: bool = False) -> Array: ... + def squeeze(self, axis: Axis = None) -> Array: ... + def std(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... + def sum(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ... def swapaxes(self, axis1: int, axis2: int) -> Array: ... - def take(self, indices, axis: int | None = None, out=None, - mode=None) -> Array: ... - def tobytes(self, order='C') -> bytes: ... + def take(self, indices: ArrayLike, axis: int | None = None, out: None = None, + mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, + fill_value: StaticScalar | None = None) -> Array: ... + def tobytes(self, order: str = 'C') -> bytes: ... def tolist(self) -> list[Any]: ... - def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None, - out=None) -> Array: ... - def transpose(self, *args) -> Array: ... + def trace(self, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, + dtype: DTypeLike | None = None, out: None = None) -> Array: ... + def transpose(self, *args: Any) -> Array: ... @property def T(self) -> Array: ... @property def mT(self) -> Array: ... - def var(self, axis: int | Sequence[int] | None = None, - dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ... - def view(self, dtype=None, type=None) -> Array: ... + def var(self, axis: Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ... + def view(self, dtype: DTypeLike | None = None, type: None = None) -> Array: ... # Even though we don't always support the NumPy array protocol, e.g., for # tracer types, for type checking purposes we must declare support so we diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py index 6fdf0c600b7d..9bce9d0e4308 100644 --- a/jax/_src/cache_key.py +++ b/jax/_src/cache_key.py @@ -83,7 +83,8 @@ def get(module: ir.Module, 'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf' """ entries = [ - ("computation", lambda hash_obj: _hash_computation(hash_obj, module)), + ("computation", + lambda hash_obj: _hash_computation(hash_obj, module)), ("jax_lib version", lambda hash_obj: hash_obj.update( bytes(jaxlib_version_str.encode("utf-8")))), @@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn): ) +def _remove_custom_partitioning_ptr(m: ir.Module): + """ + Removes custom_partitioning callback pointer from precompiled IR. + Python function pointers are not deterministic across executions. 
+ """ + def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult: + if (op.name == "stablehlo.custom_call" and + op.attributes["call_target_name"].value == "CustomSPMDPartitioning"): + op.attributes["backend_config"] = ir.StringAttr.get("REMOVED") + return ir.WalkResult.ADVANCE + + m.operation.walk(_update_bc_attribute) + return m + + def _serialize_ir(m: ir.Module) -> bytes: output = io.BytesIO() + if config.remove_custom_partitioning_ptr_from_cache_key.value: + m = _remove_custom_partitioning_ptr(type_cast(ir.Module, + m.operation.clone())) m.operation.write_bytecode(file=output) return output.getvalue() diff --git a/jax/_src/callback.py b/jax/_src/callback.py index 453a4eba47bf..3a18dcdfa2ac 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -196,7 +196,7 @@ def _callback_op_sharding(axis_context, sharding: SingleDeviceSharding | None): device_assignment = axis_context.device_assignment if device_assignment is None: raise AssertionError( - "Please file a bug at https://github.com/google/jax/issues") + "Please file a bug at https://github.com/jax-ml/jax/issues") try: device_index = device_assignment.index(device) except IndexError as e: diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 1167914e51c9..32cc4feb9054 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -25,6 +25,7 @@ from jax import dtypes from jax import lax +from jax.experimental import shard_map from jax._src import api from jax._src import linear_util as lu from jax._src import config @@ -931,6 +932,64 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr, return tree_unflatten(out_tree, err_and_out) error_checks[pjit.pjit_p] = pjit_error_check + +def shard_map_error_check( + error, enabled_errors, *vals_in, jaxpr, in_names, out_names, **kwargs +): + if (mesh := kwargs.get('mesh')) is None: + raise ValueError('Mesh must be provided for shard_map with checkify.') + + err_vals, err_tree = jtu.tree_flatten(error) + num_error_vals = len(err_vals) + # Replicated sharding for in errors. + new_in_names = (*([{}] * num_error_vals), *in_names) + new_vals_in = [*err_vals, *vals_in] + in_avals = list(map(get_shaped_aval, new_vals_in)) + for i, v in enumerate(in_avals): + if not (sharder := core.shard_aval_handlers.get(type(v))): + raise ValueError(f'Unsupported aval type: {type(v)}') + in_avals[i] = sharder(mesh, new_in_names[i], v) + + if not isinstance(jaxpr, core.ClosedJaxpr): + jaxpr = core.ClosedJaxpr(jaxpr, ()) + with core.extend_axis_env_nd(mesh.shape.items()): + # jaxpr to checked_jaxpr + checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr( + jaxpr, enabled_errors, err_tree, *in_avals + ) + num_out_error_vals = out_tree.num_leaves - len(out_names) + + @lu.wrap_init + def expand_errors_leading_dim(*xs): + outs = core.eval_jaxpr(checked_jaxpr.jaxpr, checked_jaxpr.consts, *xs) + errs, outs = split_list(outs, [num_out_error_vals]) + errs = [lax.expand_dims(e, [0]) for e in errs] + return *errs, *outs + + with core.extend_axis_env_nd(mesh.shape.items()): + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( + expand_errors_leading_dim, checked_jaxpr.in_avals + ) + checked_jaxpr = core.ClosedJaxpr(jaxpr, consts) + + # Update shard_map params to account for extra error values. + # Use fully sharded partitioning for out errors. 
+ new_out_names = (*([{0: mesh.axis_names}] * num_out_error_vals), *out_names) + subfun = lu.hashable_partial( + lu.wrap_init(core.eval_jaxpr), checked_jaxpr.jaxpr, checked_jaxpr.consts + ) + new_params = dict( + jaxpr=checked_jaxpr.jaxpr, + in_names=new_in_names, + out_names=new_out_names, + **kwargs, + ) + _, new_params = shard_map.shard_map_p.get_bind_params(new_params) + + err_and_out = shard_map.shard_map_p.bind(subfun, *new_vals_in, **new_params) + return tree_unflatten(out_tree, err_and_out) +error_checks[shard_map.shard_map_p] = shard_map_error_check + def custom_jvp_call_rule(in_err, enabled_errors, *in_vals, num_consts, jvp_jaxpr_thunk, call_jaxpr, **params): # The types to have in mind are: @@ -980,7 +1039,7 @@ def jvp(*xs): out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents) out_primals, nz_out_tangents = split_list(out, [len(out_zeros)]) nz_out_tangents_ = iter(nz_out_tangents) - out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace()) + out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval()) if z else next(nz_out_tangents_) for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 5b39994c7523..6033e1bbb928 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -80,8 +80,7 @@ def cloud_tpu_init() -> None: os.environ['TPU_ML_PLATFORM'] = 'JAX' os.environ['TPU_ML_PLATFORM_VERSION'] = version.__version__ os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1') - if hardware_utils.tpu_enhanced_barrier_supported(): - os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" + os.environ["LIBTPU_INIT_ARGS"] = os.environ.get("LIBTPU_INIT_ARGS","") + " --xla_tpu_use_enhanced_launch_barrier=true" # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') diff --git a/jax/_src/clusters/__init__.py b/jax/_src/clusters/__init__.py index 73e4ac9412f7..9abb628f8ae3 100644 --- a/jax/_src/clusters/__init__.py +++ b/jax/_src/clusters/__init__.py @@ -25,3 +25,4 @@ from .mpi4py_cluster import Mpi4pyCluster from .cloud_tpu_cluster import GkeTpuCluster from .cloud_tpu_cluster import GceTpuCluster +from .k8s_cluster import K8sCluster diff --git a/jax/_src/clusters/k8s_cluster.py b/jax/_src/clusters/k8s_cluster.py new file mode 100644 index 000000000000..1274724b8ebd --- /dev/null +++ b/jax/_src/clusters/k8s_cluster.py @@ -0,0 +1,124 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from contextlib import contextmanager +from functools import cache +import os +import socket +import textwrap +import warnings +from jax._src import clusters + + +class K8sCluster(clusters.ClusterEnv): + + # Use an arbitrarily chosen port for the coordinator since we cannot + # rely on communication to choose one in real time. 
+ _coordinator_port = '55527' + + @classmethod + def is_env_present(cls) -> bool: + if 'KUBERNETES_SERVICE_HOST' in os.environ: + try: + import kubernetes as k8s # pytype: disable=import-error + except ImportError as e: + warnings.warn(textwrap.fill( + "Kubernetes environment detected, but the `kubernetes` package is " + "not installed to enable automatic bootstrapping in this " + "environment. To enable automatic boostrapping, please install " + "jax with the [k8s] extra. For example:" + " pip install jax[k8s]" + " OR" + " pip install jax[k8s,]" + )) + return False + + k8s.config.load_incluster_config() + cls._core_api = k8s.client.CoreV1Api() + cls._batch_api = k8s.client.BatchV1Api() + cls._ApiException = k8s.client.exceptions.ApiException + return True + else: + return False + + @classmethod + @contextmanager + def _handle_api_exception(cls): + try: + yield + except cls._ApiException as e: + err_msg = [f"Kubernetes API Error: {e.status} - {e.reason}"] + if e.status == 403: + err_msg.append(textwrap.fill( + "It appears that the Kubernetes service account (SA) associated with " + "this job does not have the permission for pod introspection. Please " + "either grant the default SA permission to read pod info, or create a " + "dedicated service account with the permission and associated with " + "the job. For more details, see .", + width=80 + )) + raise RuntimeError('\n'.join(err_msg)) from e + + @classmethod + @cache + def _namespace(cls): + return open( + '/var/run/secrets/kubernetes.io/serviceaccount/namespace' + ).read().strip() + + @classmethod + @cache + def _pod(cls): + with cls._handle_api_exception(): + ip = socket.gethostbyname(os.getenv('HOSTNAME')) + pods = cls._core_api.list_namespaced_pod( + namespace=cls._namespace(), + field_selector=f'status.podIP={ip}' + ).items + assert len(pods) == 1, \ + f"Exactly 1 Kubernetes pod should have IP {ip}, got {len(pods)}." + return pods[0] + + @classmethod + @cache + def _job(cls): + with cls._handle_api_exception(): + return cls._batch_api.read_namespaced_job( + name=cls._pod().metadata.labels['job-name'], namespace=cls._namespace() + ) + + @classmethod + def get_coordinator_address(cls, timeout_secs: int | None) -> str: + return '{job_name}-0.{jobset_name}:{port}'.format( + job_name=cls._pod().metadata.labels['job-name'], + jobset_name=cls._job().metadata.labels['jobset.sigs.k8s.io/jobset-name'], + port=cls._coordinator_port + ) + + @classmethod + def get_process_count(cls) -> int: + # https://kubernetes.io/docs/concepts/workloads/controllers/job/#controlling-parallelism + return cls._job().spec.parallelism + + @classmethod + def get_process_id(cls) -> int: + # https://kubernetes.io/docs/concepts/workloads/controllers/job/#completion-mode + try: + return int(os.environ['JOB_COMPLETION_INDEX']) + except KeyError: + raise RuntimeError( + 'K8s job must be run with `completionMode: "Indexed"`.' + ) diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py index 8117f871a969..c75d1783f356 100644 --- a/jax/_src/compilation_cache.py +++ b/jax/_src/compilation_cache.py @@ -72,7 +72,7 @@ def is_cache_used(backend: xla_client.Client) -> bool: # backend that supports serialization of executables. 
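With the K8sCluster detector above, a process running inside an indexed Kubernetes Job (completionMode: "Indexed", with the kubernetes package installed, for example via the jax[k8s] extra) can bootstrap multi-process JAX without passing any addresses. A sketch:

import jax

# Coordinator address, process count and process id are all discovered from
# the Job metadata by the cluster detector, so no arguments are required.
jax.distributed.initialize()

print(jax.process_index(), "of", jax.process_count(), jax.local_devices())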
# TODO(skye): add warning when initializing cache on unsupported default # platform - supported_platforms = ["tpu", "gpu", "cpu"] + supported_platforms = ["tpu", "gpu", "cpu", "neuron"] if not _is_cache_enabled(): monitoring.record_event('/jax/compilation_cache/task_disabled_cache') @@ -265,7 +265,9 @@ def put_executable_and_time( cache.put(cache_key, executable_and_time) -def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options, +def get_cache_key(module: ir.Module, + devices: np.ndarray, + compile_options, backend) -> str: return cache_key.get(module, devices, compile_options, backend, "zstandard" if zstandard is not None else "zlib") diff --git a/jax/_src/compilation_cache_interface.py b/jax/_src/compilation_cache_interface.py index 95d557c5531e..480457871a2f 100644 --- a/jax/_src/compilation_cache_interface.py +++ b/jax/_src/compilation_cache_interface.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import abstractmethod +from __future__ import annotations + +import abc from jax._src import path as pathlib from jax._src import util @@ -21,10 +23,10 @@ class CacheInterface(util.StrictABC): _path: pathlib.Path - @abstractmethod + @abc.abstractmethod def get(self, key: str): pass - @abstractmethod + @abc.abstractmethod def put(self, key: str, value: bytes): pass diff --git a/jax/_src/compiler.py b/jax/_src/compiler.py index 8cad5a8fe9a3..108741b5f8fd 100644 --- a/jax/_src/compiler.py +++ b/jax/_src/compiler.py @@ -21,7 +21,7 @@ import os import tempfile import time -from typing import Any +from typing import Any, Callable import warnings from jax._src import compilation_cache @@ -33,7 +33,6 @@ from jax._src import traceback_util from jax._src.interpreters import mlir from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir import numpy as np @@ -157,8 +156,7 @@ def get_compile_options( build_options = compile_options.executable_build_options build_options.use_spmd_partitioning = use_spmd_partitioning build_options.use_auto_spmd_partitioning = use_auto_spmd_partitioning - if xla_extension_version >= 280: - build_options.use_shardy_partitioner = use_shardy_partitioner + build_options.use_shardy_partitioner = use_shardy_partitioner if fdo_profile is not None: build_options.fdo_profile = fdo_profile if use_auto_spmd_partitioning: @@ -253,15 +251,45 @@ def backend_compile( else: built_c = module - # we use a separate function call to ensure that XLA compilation appears - # separately in Python profiling results - if host_callbacks: - return backend.compile(built_c, compile_options=options, - host_callbacks=host_callbacks) - # Some backends don't have `host_callbacks` option yet - # TODO(sharadmv): remove this fallback when all backends allow `compile` - # to take in `host_callbacks` - return backend.compile(built_c, compile_options=options) + try: + # we use a separate function call to ensure that XLA compilation appears + # separately in Python profiling results + if host_callbacks: + return backend.compile( + built_c, compile_options=options, host_callbacks=host_callbacks + ) + # Some backends don't have `host_callbacks` option yet + # TODO(sharadmv): remove this fallback when all backends allow `compile` + # to take in `host_callbacks` + return backend.compile(built_c, compile_options=options) + except xc.XlaRuntimeError as e: + for error_handler in _XLA_RUNTIME_ERROR_HANDLERS: + handler_result = error_handler(e) + if 
handler_result is not None: + raise handler_result from e + raise e + + +_XLA_RUNTIME_ERROR_HANDLERS = [] + + +def register_xla_runtime_error_handler( + handler_fn: Callable[[xc.XlaRuntimeError], Exception | None], +): + """Registers a custom exception handler for XLA runtime errors. + + Registering a custom handler allows re-raising a more informative exception + after encountering an XLARuntimeError. + + Args: + handler_fn: A function which returns a new exception to replace the original + XLA runtime error, or None if the original error should be propagated. + + Returns: + A new exception or None. + """ + _XLA_RUNTIME_ERROR_HANDLERS.append(handler_fn) + def compile_or_get_cached( backend: xc.Client, diff --git a/jax/_src/compute_on.py b/jax/_src/compute_on.py index 25b2be78d287..4495d38f9da8 100644 --- a/jax/_src/compute_on.py +++ b/jax/_src/compute_on.py @@ -15,6 +15,7 @@ from __future__ import annotations import threading from contextlib import contextmanager +from jax._src import config class ComputeOnContext(threading.local): @@ -28,6 +29,8 @@ def __init__(self): @contextmanager def extend_compute_type(c_type: str): compute_on_context.stack.append(c_type) + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) try: if len(set(filter(lambda x: x is not None, set(compute_on_context.stack)))) > 1: raise NotImplementedError( @@ -36,6 +39,8 @@ def extend_compute_type(c_type: str): yield compute_on_context.stack[-1] finally: compute_on_context.stack.pop() + config.update_thread_local_jit_state( + compute_on_context_manager=tuple(compute_on_context.stack)) def current_compute_type() -> str | None: return compute_on_context.stack[-1] if compute_on_context.stack else None diff --git a/jax/_src/config.py b/jax/_src/config.py index 5b4226f8fa33..b21d2f35f9a4 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -202,17 +202,26 @@ def trace_context(): tls = jax_jit.thread_local_state() axis_env_state = () mesh_context_manager = () + xla_metadata_context_manager = () + compute_on_context_manager = () + context: Any = tls.extra_jit_context if context and context.axis_env_state is not None: axis_env_state = context.axis_env_state if context and context.mesh_context_manager: mesh_context_manager = context.mesh_context_manager - return (axis_env_state, mesh_context_manager, enable_x64.value, + if context and context.xla_metadata_context_manager: + xla_metadata_context_manager = context.xla_metadata_context_manager + if context and context.compute_on_context_manager: + compute_on_context_manager = context.compute_on_context_manager + return (axis_env_state, mesh_context_manager, xla_metadata_context_manager, + compute_on_context_manager, enable_x64.value, numpy_rank_promotion.value, default_matmul_precision.value, dynamic_shapes.value, numpy_dtype_promotion.value, default_device.value, random_seed_offset.value, threefry_partitionable.value, threefry_gpu_kernel_lowering.value, + sharding_in_types.value, softmax_custom_jvp.value, enable_memories.value, disable_jit.value, @@ -826,6 +835,7 @@ class _GlobalExtraJitContext(NamedTuple): random_seed_offset: int = 0 threefry_partitionable: bool = False threefry_gpu_kernel_lowering: bool = False + sharding_in_types: bool = False softmax_custom_jvp: bool = False xla_profile_version: int = 0 pgle_profiling_runs: int = 0 @@ -851,6 +861,8 @@ class _ThreadLocalExtraJitContext(NamedTuple): dynamic_trace_state: Any | None = None axis_env_state: Hashable = () mesh_context_manager: Hashable = () + 
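A sketch of the register_xla_runtime_error_handler hook added in compiler.py above. The import path is internal, and matching on RESOURCE_EXHAUSTED is only an illustrative heuristic:

from jax._src import compiler
from jax._src.lib import xla_client as xc

def _oom_to_memory_error(e: xc.XlaRuntimeError):
  if "RESOURCE_EXHAUSTED" in str(e):
    return MemoryError(f"XLA compilation ran out of device memory: {e}")
  return None  # None lets other handlers run and the original error propagate

compiler.register_xla_runtime_error_handler(_oom_to_memory_error)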
compute_on_context_manager: Hashable = () + xla_metadata_context_manager: Hashable = () # Values set by _StateContextManager context managers. # CAUTION: these must be initialized to `None`! The state context manager @@ -864,10 +876,12 @@ class _ThreadLocalExtraJitContext(NamedTuple): random_seed_offset: int | None = None threefry_partitionable: bool | None = None threefry_gpu_kernel_lowering: bool | None = None + sharding_in_types: bool | None = None softmax_custom_jvp: bool | None = None xla_profile_version: int | None = None pgle_profiling_runs: int | None = None enable_pgle: bool | None = None + use_shardy_partitioner: bool | None = None class _ThreadLocalStateCache(threading.local): @@ -1054,7 +1068,7 @@ def _update_jax_memories_thread_local(val): enable_memories = bool_state( 'jax_enable_memories', - default=False, + default=True, upgrade=True, update_global_hook=_update_jax_memories_global, update_thread_local_hook=_update_jax_memories_thread_local, @@ -1139,6 +1153,16 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: update_thread_local_jit_state( threefry_gpu_kernel_lowering=val)) +sharding_in_types = bool_state( + name='jax_sharding_in_types', + default=False, + help=('When True, enables forward only sharding propagation in JAX and ' + 'avals have sharding on them.'), + update_global_hook=lambda val: _update_global_jit_state( + sharding_in_types=val), + update_thread_local_hook=lambda val: update_thread_local_jit_state( + sharding_in_types=val)) + softmax_custom_jvp = bool_state( name='jax_softmax_custom_jvp', @@ -1146,7 +1170,7 @@ def _update_jax_memories_thread_local(val): upgrade=True, help=('Use a new custom_jvp rule for jax.nn.softmax. The new rule should ' 'improve memory usage and stability. Set True to use new ' - 'behavior. See https://github.com/google/jax/pull/15677'), + 'behavior. See https://github.com/jax-ml/jax/pull/15677'), update_global_hook=lambda val: _update_global_jit_state( softmax_custom_jvp=val), update_thread_local_hook=lambda val: update_thread_local_jit_state( @@ -1323,6 +1347,16 @@ def _update_jax_memories_thread_local(val): 'size to grow indefinitely.'), ) +remove_custom_partitioning_ptr_from_cache_key = bool_state( + name='jax_remove_custom_partitioning_ptr_from_cache_key', + default=False, + help=('If set to True, remove the custom partitioning pointer ' + 'present in the precompiled stableHLO before hashing ' + 'during cache key computation. This is a potentially ' + 'unsafe flag to set and only users who are sure of ' + 'what they are trying to achieve should set it.'), +) + default_dtype_bits = enum_state( name='jax_default_dtype_bits', enum_values=['32', '64'], @@ -1344,6 +1378,15 @@ def _update_jax_memories_thread_local(val): update_thread_local_hook=lambda val: \ update_thread_local_jit_state(numpy_dtype_promotion=val)) +disallow_mesh_context_manager = bool_state( + name='jax_disallow_mesh_context_manager', + default=False, + help=( + 'If set to True, trying to use a mesh as a context manager will' + ' result in a RuntimeError.' 
+ ), +) + def _update_x64_global(val): lib.jax_jit.global_state().enable_x64 = val @@ -1501,6 +1544,11 @@ def _update_disable_jit_thread_local(val): upgrade=True, help='Enable eager-mode pmap when jax_disable_jit is activated.') +no_tracing = bool_state( + name='jax_no_tracing', + default=False, + help='Disallow tracing for JIT compilation.') + disable_vmap_shmap_error = bool_state( name='jax_disable_vmap_shmap_error', default=False, @@ -1681,10 +1729,8 @@ def _update_debug_log_modules(module_names_str: str | None): pmap_no_rank_reduction = bool_state( name='jax_pmap_no_rank_reduction', - default=False, - help=( - "If True, pmap shards have a the same rank as their enclosing array." - ) + default=True, + help='If True, pmap shards have a the same rank as their enclosing array.', ) use_shardy_partitioner = bool_state( diff --git a/jax/_src/core.py b/jax/_src/core.py index ebf29cf0b253..9ef19fbeccdc 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -55,6 +55,7 @@ from jax._src import traceback_util from jax._src.typing import Array, DimSize, Shape from jax._src import typing +from jax._src import xla_metadata as xla_metadata_lib traceback_util.register_exclusion(__file__) @@ -261,12 +262,15 @@ def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): class JaxprEqnContext: - def __init__(self, compute_type: str | None, threefry_partitionable: bool): + def __init__(self, compute_type: str | None, threefry_partitionable: bool, + xla_metadata=None): self.compute_type = compute_type self.threefry_partitionable = threefry_partitionable + self.xla_metadata = xla_metadata self._managers = [ (compute_on.extend_compute_type, self.compute_type), (config.threefry_partitionable.__call__, self.threefry_partitionable), + (xla_metadata_lib.set_xla_metadata, self.xla_metadata), ] @property @@ -278,8 +282,11 @@ def manager(self): yield def __repr__(self): - return (f"JaxprEqnContext(compute_type={self.compute_type}," - f"threefry_partitionable={self.threefry_partitionable})") + return ( + f"JaxprEqnContext(compute_type={self.compute_type}, " + f"threefry_partitionable={self.threefry_partitionable}, " + f"xla_metadata={self.xla_metadata})" + ) class JaxprEqn: @@ -333,8 +340,10 @@ def replace( def new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info=None, ctx=None): source_info = source_info or source_info_util.new_source_info() - ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(), - config.threefry_partitionable.value) + ctx = ctx or JaxprEqnContext( + compute_on.current_compute_type(), + config.threefry_partitionable.value, + xla_metadata_lib.current_xla_metadata()) if config.enable_checks.value: assert all(isinstance(x, (Var, Literal)) for x in invars) assert all(isinstance(v, Var) for v in outvars) @@ -926,7 +935,7 @@ def unsafe_buffer_pointer(self): class EvalTrace(Trace): - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 def pure(self, x): return x lift = sublift = pure @@ -989,7 +998,7 @@ def with_cur_sublevel(self): return self.trace_type(self, cur_sublevel(), **self.payload) class TraceStack: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack: list[MainTrace] dynamic: MainTrace @@ -1158,7 +1167,7 @@ def _why_alive(ignore_ids: set[int], x: Any) -> str: # parent->child jump. We do that by setting `parent` here to be a # grandparent (or great-grandparent) of `child`, and then handling that case # in _why_alive_container_info. 
See example: - # https://github.com/google/jax/pull/13022#discussion_r1008456599 + # https://github.com/jax-ml/jax/pull/13022#discussion_r1008456599 # To prevent this collapsing behavior, just comment out this code block. if (isinstance(parent, dict) and getattr(parents(parent)[0], '__dict__', None) is parents(child)[0]): @@ -1204,7 +1213,7 @@ def _why_alive_container_info(container, obj_id) -> str: @contextmanager def new_main(trace_type: type[Trace], dynamic: bool = False, **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack level = stack.next_level() main = MainTrace(level, trace_type, **payload) @@ -1245,7 +1254,7 @@ def dynamic_level() -> int: @contextmanager def new_base_main(trace_type: type[Trace], **payload) -> Generator[MainTrace, None, None]: - # See comments in https://github.com/google/jax/pull/3370 + # See comments in https://github.com/jax-ml/jax/pull/3370 stack = thread_local_state.trace_state.trace_stack main = MainTrace(0, trace_type, **payload) prev_dynamic, stack.dynamic = stack.dynamic, main @@ -1268,7 +1277,7 @@ def new_base_main(trace_type: type[Trace], @contextmanager def pop_level(level: int): if level == 0: - return (yield) + return (yield) # noqa: B901 prev, thread_local_state.trace_state.trace_stack.stack = \ thread_local_state.trace_state.trace_stack.stack, \ thread_local_state.trace_state.trace_stack.stack[:level] @@ -1310,7 +1319,7 @@ def f(x): else: return jnp.cos(x) - Here's a real-world example from https://github.com/google/jax/issues/3974:: + Here's a real-world example from https://github.com/jax-ml/jax/issues/3974:: import jax import jax.numpy as jnp @@ -1405,9 +1414,13 @@ def definitely_equal(x, y): class AbstractValue: __slots__: list[str] = [] - def at_least_vspace(self): + def to_tangent_aval(self): raise NotImplementedError("must override") + # TODO(dougalm): deprecate this alias + def at_least_vspace(self): + return self.to_tangent_aval() + def __repr__(self): try: kv_pairs = (f'{k}={v}' for k, v in self.__dict__.items()) @@ -1515,6 +1528,12 @@ def get_aval(x): else: return concrete_aval(x) +def get_type(x): + aval = get_aval(x) + if isinstance(aval, ConcreteArray): + return raise_to_shaped(aval) + else: + return aval def concretization_function_error(fun, suggest_astype=False): fname = getattr(fun, "__name__", fun) @@ -1577,23 +1596,18 @@ def physical_aval(aval: DShapedArray) -> DShapedArray: ... def physical_aval(aval: AbstractValue) -> AbstractValue: ... 
def physical_aval(aval): - aval_dtype = getattr(aval, 'dtype', None) - if aval_dtype and isinstance(aval_dtype, dtypes.ExtendedDType): - ctor = type(aval) - aval_shape = getattr(aval, 'shape', None) - assert aval_shape is not None, (ctor, aval) - elt_aval = aval_dtype._rules.physical_element_aval(aval_dtype) - assert type(elt_aval) is ShapedArray - return ctor((*aval_shape, *elt_aval.shape), elt_aval.dtype) # pytype: disable=wrong-arg-count - else: - return aval + if (isinstance(aval, (ShapedArray, DShapedArray)) and + isinstance(aval.dtype, dtypes.ExtendedDType)): + elt_aval = physical_element_aval(aval.dtype) + if isinstance(aval, ShapedArray): + return ShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) + return DShapedArray((*aval.shape, *elt_aval.shape), elt_aval.dtype) + return aval + +def physical_element_aval(edtype: dtypes.ExtendedDType) -> ShapedArray: + duck = edtype._rules.physical_element_aval(edtype) # type: ignore + return ShapedArray(duck.shape, dtypes.dtype(duck.dtype)) -def _short_dtype_name(dtype) -> str: - if isinstance(dtype, dtypes.ExtendedDType): - return str(dtype) - else: - return (dtype.name.replace('float', 'f').replace('uint' , 'u') - .replace('int' , 'i').replace('complex', 'c')) def _dtype_object(dtype): return dtype if isinstance(dtype, dtypes.ExtendedDType) else np.dtype(dtype) @@ -1638,7 +1652,7 @@ def __repr__(self): _oct = concretization_function_error(oct) _index = concretization_function_error(operator.index) - def at_least_vspace(self) -> AbstractValue: + def to_tangent_aval(self) -> AbstractValue: return UnshapedArray(primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1652,7 +1666,7 @@ def join(self, other): raise TypeError(self, other) def str_short(self, short_dtypes=False) -> str: - return _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + return dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name def strip_weak_type(self): """Returns a copy of the aval with weak_type=False.""" @@ -1661,7 +1675,7 @@ def strip_weak_type(self): @property def shape(self): msg = ("UnshapedArray has no shape. 
Please open an issue at " - "https://github.com/google/jax/issues because it's unexpected for " + "https://github.com/jax-ml/jax/issues because it's unexpected for " "UnshapedArray instances to ever be produced.") raise TypeError(msg) @@ -1733,25 +1747,26 @@ def _invalid_shape_error(shape: Shape, context: str=""): return TypeError(msg) class ShapedArray(UnshapedArray): - __slots__ = ['shape'] + __slots__ = ['shape', 'sharding'] # inherits slots from parent array_abstraction_level = 2 - named_shape = {} # type: ignore - def __init__(self, shape, dtype, weak_type=False, named_shape=None): - del named_shape # unused, vestigial + def __init__(self, shape, dtype, weak_type=False, sharding=None): self.shape = canonicalize_shape(shape) self.dtype = _dtype_object(dtype) self.weak_type = weak_type + if config.sharding_in_types.value: + self.sharding = sharding - def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): - del named_shape # unused, vestigial + def update(self, shape=None, dtype=None, weak_type=None, sharding=None): if shape is None: shape = self.shape if dtype is None: dtype = self.dtype if weak_type is None: weak_type = self.weak_type - return ShapedArray(shape, dtype, weak_type) + if sharding is None: + sharding = getattr(self, 'sharding', None) + return ShapedArray(shape, dtype, weak_type, sharding=sharding) ndim = property(lambda self: len(self.shape)) size = property(lambda self: @@ -1766,15 +1781,17 @@ def update(self, shape=None, dtype=None, weak_type=None, named_shape=None): def __eq__(self, other): return (type(self) is type(other) and self.dtype == other.dtype and self.shape == other.shape - and self.weak_type == other.weak_type) + and self.weak_type == other.weak_type + and getattr(self, 'sharding', None) == getattr(other, 'sharding', None)) def __hash__(self): # can use hash(self.dtype) and rely on the fact that numpy reuses base dtype # objects, e.g. 
`np.zeros(3).dtype is np.zeros(4).dtype`, or we can use # the unique character code via hash(self.dtype.char) - return hash((self.shape, self.dtype, self.weak_type)) + return hash((self.shape, self.dtype, self.weak_type, + getattr(self, 'sharding', None))) - def at_least_vspace(self): + def to_tangent_aval(self): return ShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1788,10 +1805,13 @@ def join(self, other): raise TypeError(self, other) def str_short(self, short_dtypes=False): - dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name dt_str = dt_str.replace('void', 'float0') shapestr = ','.join(map(str, self.shape)) - return f'{dt_str}[{shapestr}]' + if hasattr(self, 'sharding'): + return f'{dt_str}[{shapestr}]({self.sharding})' + else: + return f'{dt_str}[{shapestr}]' def _len(self, ignored_tracer): try: @@ -1846,7 +1866,7 @@ def join(self, other) -> AbstractValue: raise TypeError(self, other) def str_short(self, short_dtypes=False) -> str: - dt_str = _short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name return f'{self.val}, dtype={dt_str}' _bool = partialmethod(_forward_to_value, bool) @@ -1896,7 +1916,7 @@ def __init__(self, shape, dtype, weak_type=False): def str_short(self, short_dtypes=False) -> str: del short_dtypes # ignored shape = f'{",".join(str(d) for d in self.shape)}' if self.shape else '' - dtype = _short_dtype_name(self.dtype) + dtype = dtypes.short_dtype_name(self.dtype) return f'{dtype}[{shape}]' __str__ = __repr__ = str_short @@ -1930,7 +1950,7 @@ def join(self, other): else: raise TypeError(self, other) - def at_least_vspace(self): + def to_tangent_aval(self): return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype), self.weak_type) @@ -1954,6 +1974,7 @@ def __init__(self, aval, data): assert data.shape == pad_shape self._aval = aval self._data = data + shape = property(lambda self: self._aval.shape) dtype = property(lambda self: self._aval.dtype) aval = property(lambda self: self._aval) @@ -1962,23 +1983,40 @@ def __repr__(self) -> str: # special-case scalar bints return f'{int(self._data)}{{≤{self.dtype.bound}}}' - dtypestr = _short_dtype_name(self._aval.dtype) + dtypestr = dtypes.short_dtype_name(self._aval.dtype) shapestr = ','.join(map(str, self.shape)) - slices = tuple(slice(int(d._data)) if type(d) is DArray and - type(d.dtype) is bint else slice(None) for d in self.shape) - data = self._data[slices] + data = self.data return f'{dtypestr}[{shapestr}] with value: {data}' + def __hash__(self) -> int: if not self.shape: return hash((self._aval, int(self._data))) raise TypeError("unhashable type: DArray") + def __eq__(self, other): if isinstance(other, DArray) and self._aval == other._aval: return self._data == other._data return False + def __len__(self): return self.shape[0] + @property + def data(self): + if not self.shape and type(self.dtype) is bint: + # special-case scalar bints + return self._data + + slices = tuple( + slice(int(d._data)) + if type(d) is DArray and type(d.dtype) is bint + else slice(None) + for d in self.shape + ) + data = self._data[slices] + return data + + pytype_aval_mappings[DArray] = \ lambda x: DConcreteArray(x._aval.shape, x._aval.dtype, x._aval.weak_type, x._data) @@ -2043,7 +2081,7 @@ def join(self, other): else: assert False, f"Cannot join {self} with {other}" def str_short(self, 
short_dtypes=False): return 'Tok' - def at_least_vspace(self): return self + def to_tangent_aval(self): return self abstract_token: AbstractToken = AbstractToken() # Singleton shaped array used by all abstract tokens when shape/dtype is needed. @@ -3159,7 +3197,7 @@ def pp_var(v: Var | Literal, context: JaxprPpContext) -> str: def pp_aval(a: AbstractValue, context: JaxprPpContext) -> str: if isinstance(a, DShapedArray): shape = [pp_var(d, context) if type(d) is Var else str(d) for d in a.shape] - dtype = _short_dtype_name(a.dtype) + dtype = dtypes.short_dtype_name(a.dtype) return f'{dtype}[{",".join(shape)}]' else: return a.str_short(short_dtypes=True) @@ -3366,3 +3404,53 @@ def clean_up_dead_vars(eqn: JaxprEqn, env: dict[Var, Any], # Used in shard_map for converting avals shard_aval_handlers = {} # type: ignore unshard_aval_handlers = {} # type: ignore + +# ----------------- external APIs for querying tracing context ----------------- + +# TODO(dougalm, jakevdp): expose these via jax.extend + +# Comparable object for checking whether JAX's trace state has changed. +class OpaqueTraceState: + def __init__(self, trace_info, convention): + self._trace_info = trace_info + self._convention = convention + + def __eq__(self, other): + if isinstance(other, OpaqueTraceState): + if self._convention in ["nnx"]: + return self._trace_info is other._trace_info + elif self._convention in ["haiku", "flax"]: + return self._trace_info == other._trace_info + else: + raise Exception(f"unrecognized convention: {self._convention}") + + +# Each library has its own opinion about what the important fragment of jax's +# internal state is. TODO: reconcile the differences and remove the flag. +def get_opaque_trace_state(convention="flax"): + if convention == "flax": + trace_info = find_top_trace(()).level + elif convention == "haiku": + trace_stack = thread_local_state.trace_state.trace_stack.stack + top_type = trace_stack[0].trace_type + level = trace_stack[-1].level + sublevel = cur_sublevel() + trace_info = (top_type, level, sublevel) + elif convention == "nnx": + trace_info = thread_local_state.trace_state.trace_stack.dynamic + else: + raise Exception(f"unrecognized convention: {convention}") + + return OpaqueTraceState(trace_info, convention) + +def nonempty_axis_env() -> bool: + return bool(thread_local_state.trace_state.axis_env) + +def unsafe_am_i_under_a_jit() -> bool: + return 'DynamicJaxprTrace' in str(thread_local_state.trace_state.trace_stack) + +def unsafe_am_i_under_a_vmap() -> bool: + return 'BatchTrace' in str(thread_local_state.trace_state.trace_stack) + +def unsafe_get_axis_names() -> list[str]: + return [axis.name for axis in thread_local_state.trace_state.axis_env] diff --git a/jax/_src/cudnn/__init__.py b/jax/_src/cudnn/__init__.py index 862a661e24b9..23d1fa28ff43 100644 --- a/jax/_src/cudnn/__init__.py +++ b/jax/_src/cudnn/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from .fusion import cudnn_fusion diff --git a/jax/_src/cudnn/fused_attention_stablehlo.py b/jax/_src/cudnn/fused_attention_stablehlo.py index 262b8e2c140a..e20271f66301 100644 --- a/jax/_src/cudnn/fused_attention_stablehlo.py +++ b/jax/_src/cudnn/fused_attention_stablehlo.py @@ -109,6 +109,7 @@ def create_dot_product_attention_backend_config(batch, dropout_rate, mask_type, layout, + sliding_window_length, is_bwd): # Q, K, V: query, key, value in shape of BT(S)NH or BNT(S)H # P: BMM1 output in shape of BNTS @@ -119,7 +120,8 @@ def create_dot_product_attention_backend_config(batch, # BMM1Grad2: dP @ K -> dQ # BMM2Grad1: P @ dO -> dV # BMM2Grad2: dO @ V -> dP - + if sliding_window_length is None: + sliding_window_length = 0 cudnn_fmha_backend_config = { "algorithm": { "algo_id": "0", @@ -151,6 +153,7 @@ def create_dot_product_attention_backend_config(batch, "seed": seed, "is_flash_attention": True, "mask_type": convert_mask_type_to_string(mask_type), + "sliding_window_length": sliding_window_length, } # We define the contracting and batch dims in the format of @@ -284,8 +287,8 @@ def check_eq(a, b, c, msg): raise ValueError(f"kv_seqlen must have same batch as Q, got {kv_seq_b}") def check_is_flash_attention( - query, key, layout, cudnn_version, has_bias, is_training): - if layout == AttentionLayout.BNTH: + query, key, layout: int, cudnn_version, has_bias, is_training): + if layout == AttentionLayout.BNTH.value: _, _, T, H = query.shape _, _, S, _ = key.shape else: @@ -319,34 +322,38 @@ def check_compute_capability(capability): def _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, cudnn_version): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, cudnn_version): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, False) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=False) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=False) output = outputs[0] return output def _dot_product_attention_fwd_rule( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, cudnn_version): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, cudnn_version): # check if flash attention is supported for this attention pattern check_is_flash_attention( query, key, layout, cudnn_version, bias is not None, True) outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=True) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=True) res = (query, key, value, bias, q_seqlen, kv_seqlen, outputs[1], outputs[0]) return outputs[0], res def _dot_product_attention_bwd_rule( - scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, - res, grad_output): + scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training, res, grad_output): (query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output) = res grads = _dot_product_attention_bwd_p_wrapper.bind( 
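The hunk above tightens check_is_flash_attention to take layout as a plain int and compare it against AttentionLayout.BNTH.value rather than the enum member itself, and it maps sliding_window_length=None to 0 in the backend config (0 meaning no sliding window). A minimal sketch of why the .value comparison matters, assuming AttentionLayout is a regular enum.Enum as the change implies (the enum definition itself is not part of this diff):

import enum

class AttentionLayout(enum.Enum):  # illustrative stand-in, not the real class
  BTNH = 0
  BNTH = 1

layout = AttentionLayout.BNTH.value          # callers now pass the integer value
print(layout == AttentionLayout.BNTH)        # False: an int never equals a plain Enum member
print(layout == AttentionLayout.BNTH.value)  # True: int compared with int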
@@ -354,33 +361,39 @@ def _dot_product_attention_bwd_rule( fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length ) grads = (*grads,) + (None,) * (6 - len(grads)) return grads def _dot_product_attention_fwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training): # args: {Q, K, V, mask*, bias*} outputs = _dot_product_attention_fwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training) return outputs def _dot_product_attention_bwd_impl( query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, - grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout): + grad_output, scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length): grads = _dot_product_attention_bwd_p.bind( query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length) return grads def _dot_product_attention_fwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, *, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training): query_dtype = dtypes.canonicalize_dtype(query.dtype) if layout == AttentionLayout.BNTH.value: B, N, T, _ = query.shape @@ -404,7 +417,7 @@ def _dot_product_attention_fwd_abstract( def _dot_product_attention_bwd_abstract( query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, grad_output, *, scale, seed, dropout_rate, variadic_args, mask_type, - layout): + layout, sliding_window_length): query_dtype = dtypes.canonicalize_dtype(query.dtype) key_dtype = dtypes.canonicalize_dtype(key.dtype) value_dtype = dtypes.canonicalize_dtype(value.dtype) @@ -442,7 +455,8 @@ def _dot_product_attention_bwd_abstract( def _dot_product_attention_fwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout, is_training): + dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, is_training): query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) @@ -465,7 +479,7 @@ def _dot_product_attention_fwd_cuda_lowering( workspace_type = ir.IntegerType.get_unsigned(8) backend_config = create_dot_product_attention_backend_config( B, N, T, S, query_type.element_type, scale, seed, dropout_rate, - mask_type, layout, is_bwd=False, + mask_type, layout, sliding_window_length, is_bwd=False, ) # {Q, K, V, bias*, q_seqlen*, kv_seqlen*} # {output, activation*, workspace} @@ -512,7 +526,7 @@ def _dot_product_attention_fwd_cuda_lowering( def _dot_product_attention_bwd_cuda_lowering( ctx, query, key, value, bias, q_seqlen, kv_seqlen, activation, fwd_output, grad_output, scale, seed, dropout_rate, variadic_args, - mask_type, layout): + mask_type, layout, sliding_window_length): 
query_type = ir.RankedTensorType(query.type) query_shape = query_type.shape key_type = ir.RankedTensorType(key.type) @@ -538,7 +552,7 @@ def _dot_product_attention_bwd_cuda_lowering( grad_value_shape = (B, k_N, S, H) backend_config = create_dot_product_attention_backend_config( B, q_N, T, S, query_type.element_type, scale, seed, dropout_rate, - mask_type, layout, is_bwd=True, + mask_type, layout, sliding_window_length, is_bwd=True, ) # {Q, K, V, activation, dO, bias*, O, q_seqlen*, kv_seqlen*} # {dQ, dK, dV, dbias*, workspace} @@ -601,7 +615,7 @@ def _check_valid_batch_dims(bdims): def _dot_product_attention_fwd_batcher( batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, - mask_type, layout, is_training): + mask_type, layout, sliding_window_length, is_training): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, kv_seqlen = batched_args query_bdim = batch_dims[0] @@ -618,11 +632,12 @@ def _dot_product_attention_fwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, _ = variadic_args + original_shape = query.shape # reshape to 4D shape query = jnp.reshape(query, (B,) + query.shape[-3:]) key = jnp.reshape(key, (B,) + key.shape[-3:]) value = jnp.reshape(value, (B,) + key.shape[-3:]) - if has_bias: + if has_bias and batch_dims[3] is not None: bias = jnp.reshape(bias, (B, N, T, S)) if has_padding(mask_type): q_seqlen = jnp.reshape(q_seqlen, (B, )) @@ -631,11 +646,12 @@ def _dot_product_attention_fwd_batcher( outputs = _dot_product_attention_fwd_p_wrapper.bind( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, is_training=is_training) + mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training) # reshape to original shape output = outputs[0] - output = jnp.reshape(output, query.shape) + output = jnp.reshape(output, original_shape) if is_training: activation = outputs[1] activation = jnp.reshape(activation, (*Bs, N, T)) @@ -645,7 +661,7 @@ def _dot_product_attention_fwd_batcher( def _dot_product_attention_bwd_batcher( batched_args, batch_dims, *, scale, seed, dropout_rate, variadic_args, - mask_type, layout): + mask_type, layout, sliding_window_length): _check_valid_batch_dims(batch_dims) query, key, value, bias, q_seqlen, \ kv_seqlen, activation, fwd_output, grad_output = batched_args @@ -660,11 +676,20 @@ def _dot_product_attention_bwd_batcher( *_, S, _, _ = key.shape B = math.prod(Bs) has_bias, has_dbias = variadic_args + # Reset the has_dbias if the combined batch size is not 1, because cuDNN only + # supports dbias with a single batch. In this case, an all-zero dbias will be + # appended instead. 
+ if B > 1: + variadic_args = (has_bias, False) + original_query_shape = query.shape + original_key_shape = key.shape + original_value_shape = value.shape + original_bias_shape = bias.shape if has_bias else None # reshape to 4D shape query = jnp.reshape(query, (B,) + query.shape[-3:]) key = jnp.reshape(key, (B,) + key.shape[-3:]) value = jnp.reshape(value, (B,) + key.shape[-3:]) - if has_bias: + if has_bias and batch_dims[3] is not None: bias = jnp.reshape(bias, (B, N, T, S)) if has_padding(mask_type): q_seqlen = jnp.reshape(q_seqlen, (B, )) @@ -679,17 +704,20 @@ def _dot_product_attention_bwd_batcher( fwd_output, grad_output, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, ) - grad_query, grad_key, grad_value = grads[:3] # reshape to original shape - grad_query = jnp.reshape(grad_query, query.shape) - grad_key = jnp.reshape(grad_key, key.shape) - grad_value = jnp.reshape(grad_value, value.shape) + grads[0] = jnp.reshape(grads[0], original_query_shape) + grads[1] = jnp.reshape(grads[1], original_key_shape) + grads[2] = jnp.reshape(grads[2], original_value_shape) if has_dbias: - grad_bias = grads[3] - grad_bias = jnp.reshape(grad_bias, bias.shape) - return grads + (grad_bias,), out_bdims + (query_bdim,) + assert has_bias + if variadic_args[1]: + grads[3] = jnp.reshape(grads[3], original_bias_shape) + else: + grads.append(jnp.zeros(original_bias_shape, bias.dtype)) + out_bdims += (batch_dims[3],) return grads, out_bdims # custom partitioning @@ -745,18 +773,18 @@ def _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training): return [out_sharding] _dot_product_attention_fwd_lower = custom_partitioning( - _dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12)) + _dot_product_attention_fwd_impl, static_argnums=(6, 7, 8, 9, 10, 11, 12, 13)) def _dot_product_attention_fwd_infer_sharding_from_operands( - scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, - mesh, arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, + is_training, mesh, arg_shapes, result_shape): return _infer_fwd_output_sharding(mesh, arg_shapes, variadic_args, is_training) def _dot_product_attention_fwd_partition( - scale, seed, dropout_rate, variadic_args, mask_type, layout, is_training, - mesh, arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, sliding_window_length, + is_training, mesh, arg_shapes, result_shape): # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) out_shardings = _infer_fwd_output_sharding( mesh, arg_shapes, variadic_args, is_training) impl = functools.partial( @@ -767,6 +795,7 @@ def _dot_product_attention_fwd_partition( variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, is_training=is_training, ) return mesh, impl, out_shardings, arg_shardings @@ -793,20 +822,20 @@ def _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args): return out_shardings _dot_product_attention_bwd_lower = custom_partitioning( - _dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14) + _dot_product_attention_bwd_impl, static_argnums=(9, 10, 11, 12, 13, 14, 15) ) def _dot_product_attention_bwd_infer_sharding_from_operands( - scale, seed, dropout_rate, variadic_args, mask_type, layout, mesh, - 
arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, mesh, arg_shapes, result_shape): return _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) def _dot_product_attention_bwd_partition( - scale, seed, dropout_rate, variadic_args, mask_type, layout, mesh, - arg_shapes, result_shape): + scale, seed, dropout_rate, variadic_args, mask_type, layout, + sliding_window_length, mesh, arg_shapes, result_shape): out_shardings = _infer_bwd_output_sharding(mesh, arg_shapes, variadic_args) # args sharding - arg_shardings = tuple([arg_i.sharding for arg_i in arg_shapes]) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_shapes) def sharded_impl(*args): impl = functools.partial( _dot_product_attention_bwd_impl, @@ -816,6 +845,7 @@ def sharded_impl(*args): variadic_args=variadic_args, mask_type=mask_type, layout=layout, + sliding_window_length=sliding_window_length, ) grads = impl(*args) _, has_dbias = variadic_args @@ -913,7 +943,7 @@ def sharded_impl(*args): ) -@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12)) +@functools.partial(jax.custom_vjp, nondiff_argnums=(6, 7, 8, 9, 10, 11, 12, 13)) def _dot_product_attention(query: Array, key: Array, value: Array, @@ -926,11 +956,13 @@ def _dot_product_attention(query: Array, variadic_args: tuple[bool, ...], mask_type: bool, layout: int, + sliding_window_length: int | None, cudnn_version: int): output = _dot_product_attention_fwd( query, key, value, bias, q_seqlen, kv_seqlen, scale=scale, seed=seed, dropout_rate=dropout_rate, variadic_args=variadic_args, - mask_type=mask_type, layout=layout, cudnn_version=cudnn_version) + mask_type=mask_type, layout=layout, sliding_window_length=sliding_window_length, + cudnn_version=cudnn_version) return output # _dot_product_attention_fwd must have the same func signature as _dot_product_attention @@ -949,7 +981,8 @@ def dot_product_attention(query: Array, mask_type: MaskType = MaskType.NO_MASK, seed: int = 42, dropout_rate: float = 0., - qkv_layout: str = "BTNH"): + qkv_layout: str = "BTNH", + sliding_window_length: int | None = None): """Computes dot-product attention given query (Q), key (K), and value (V). This function serves as the core operation for applying attention @@ -980,7 +1013,11 @@ def dot_product_attention(query: Array, scale: Scale for the query. dropout_rate: Dropout rate. qkv_layout: Layout string, with supported formats being BTNH, BNTH, BSNH, - BNSH. + BNSH. + sliding_window_length: Window size to make attention only attend to each + token's left local window (pos - sliding_window_length, pos] where `pos` + is the index of each token. E.g., if sliding_window_length == 3 and the + sequence is [0, 1, 2, 3, c, 4, 5], token `c` can attend to [4, 5, c]. Returns: Output of the same shape as the query. 
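Since the docstring above defines the new sliding_window_length argument, here is a hedged usage sketch (not part of the diff). It assumes an NVIDIA GPU with a cuDNN build supported by this module, the default BTNH layout, and that the optional bias/seqlen arguments keep their defaults; shapes and the window size are illustrative only.

import jax
import jax.numpy as jnp
from jax._src.cudnn.fused_attention_stablehlo import dot_product_attention

B, T, N, H = 2, 128, 8, 64  # batch, sequence length, heads, head dim (BTNH)
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, T, N, H), dtype=jnp.bfloat16)
k = jax.random.normal(kk, (B, T, N, H), dtype=jnp.bfloat16)
v = jax.random.normal(kv, (B, T, N, H), dtype=jnp.bfloat16)

# Each query position attends only to the key positions in the window (pos - 16, pos].
out = dot_product_attention(q, k, v, sliding_window_length=16)
print(out.shape)  # (2, 128, 8, 64)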
@@ -993,6 +1030,9 @@ def dot_product_attention(query: Array, layout = _normalize_layout(qkv_layout) if has_padding(mask_type) and (q_seqlen is None or kv_seqlen is None): raise ValueError("Require q_seqlen and kv_seqlen to generate padding mask") + if sliding_window_length is not None and sliding_window_length <= 0: + raise ValueError( + f"Require sliding_window_length > 0, got {sliding_window_length}") if bias is not None: # reshape bias to have 4D shape @@ -1028,6 +1068,6 @@ def dot_product_attention(query: Array, kv_seqlen = jnp.zeros(0, dtype=query.dtype) output = _dot_product_attention( query, key, value, bias, q_seqlen, kv_seqlen, scale, seed, - dropout_rate, variadic_args, mask_type, layout.value, cudnn_version - ) + dropout_rate, variadic_args, mask_type, layout.value, sliding_window_length, + cudnn_version) return output diff --git a/jax/_src/cudnn/fusion.py b/jax/_src/cudnn/fusion.py new file mode 100644 index 000000000000..8a13399e3d63 --- /dev/null +++ b/jax/_src/cudnn/fusion.py @@ -0,0 +1,91 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import jax +from jax import core as jax_core +from jax.interpreters import mlir +from jax.interpreters.mlir import hlo +from jax.interpreters.mlir import ir + + + +def _cudnn_fusion_impl(*args, jaxpr, **unused_kwargs): + del unused_kwargs + return jax_core.jaxpr_as_fun(jaxpr)(*args) + + +def _custom_abstract_eval(*args, jaxpr, **unused_kwargs): + del unused_kwargs + del args + return jaxpr.out_avals + + +cudnn_fusion_p = jax_core.Primitive("cudnn_fusion") +cudnn_fusion_p.multiple_results = True +cudnn_fusion_p.def_abstract_eval(_custom_abstract_eval) +cudnn_fusion_p.def_impl(_cudnn_fusion_impl) + + +def call_cudnn_fusion(f, *args, **kwargs): + """Creates a new cudnn_fusion corresponding to calling + the given function f with args and kwargs.""" + jaxpr, out_shapes = jax.make_jaxpr( + functools.partial(f, **kwargs), return_shape=True + )(*args) + flat_args = jax.tree.leaves(args) + out_tree = jax.tree.structure(out_shapes) + out_flat = cudnn_fusion_p.bind(*flat_args, name=f.__name__, jaxpr=jaxpr) + return jax.tree.unflatten(out_tree, out_flat) + + +def _cudnn_fusion_stablehlo_lowering( + ctx, + *args, + name, + jaxpr, +): + """Make cudnn_fusion which calls the implementation function. + Currently this leaks a CallOp since we're using the `core_call_lowering` + function, but this should get cleaned up by DCE easily. + """ + impl = mlir.core_call_lowering( + ctx, *args, name=name + ".impl", call_jaxpr=jaxpr + ) + call_op = impl[0].owner + called_fn = call_op.attributes["callee"] + cudnn_fusion = hlo.CustomCallOp( + [r.type for r in call_op.results], + call_op.operands, + call_target_name="__cudnn$fusion", + called_computations=ir.ArrayAttr.get([called_fn]), + ) + return cudnn_fusion.results + + +mlir.register_lowering( + cudnn_fusion_p, _cudnn_fusion_stablehlo_lowering, platform="cuda" + ) + + +def cudnn_fusion(f): + """Makes a function become a cuDNN kernel. 
Relies on XLA's handling of + custom fusions with __cudnn$fusion backend. Currently limited to GEMM + fusions. For example - batch matmul with mixed types and addition: + + @cudnn_fusion + def fn(x, y, z): + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + """ + return functools.partial(call_cudnn_fusion, f) diff --git a/jax/_src/custom_batching.py b/jax/_src/custom_batching.py index 4d41849b75d3..35e7d33430bd 100644 --- a/jax/_src/custom_batching.py +++ b/jax/_src/custom_batching.py @@ -27,7 +27,7 @@ from jax._src import traceback_util from jax._src import tree_util from jax._src import util -from jax._src.api_util import flatten_fun_nokwargs +from jax._src.api_util import flatten_fun_nokwargs, resolve_kwargs from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters.batching import not_mapped @@ -64,7 +64,12 @@ def def_vmap(self, vmap_rule: Callable) -> Callable: @traceback_util.api_boundary def __call__(self, *args, **kwargs): - assert not kwargs + args = resolve_kwargs(self.fun, args, kwargs) + fun_name = getattr(self.fun, "__name__", str(self.fun)) + if not self.vmap_rule: + raise AttributeError( + f"No batching rule defined for custom_vmap function {fun_name} " + "using def_vmap.") args_flat, in_tree = tree_flatten(args) flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree) in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat] @@ -186,7 +191,7 @@ def jvp_of_rule_rule(axis_size, in_batched, primals, tangents): # TODO(frostig): assert these also equal: # treedef_tuple((in_tree, in_tree)) - # once https://github.com/google/jax/issues/9066 is fixed + # once https://github.com/jax-ml/jax/issues/9066 is fixed assert tree_ps_ts == tree_ps_ts2 del tree_ps_ts2 diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index d27b0efc7e5e..f5ecdfcda286 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -17,7 +17,6 @@ from collections.abc import Callable, Sequence import dataclasses from functools import update_wrapper, reduce, partial, wraps -import inspect from typing import Any, Generic, TypeVar from jax._src import config @@ -30,7 +29,8 @@ from jax._src import traceback_util from jax._src.ad_util import ( stop_gradient_p, SymbolicZero, Zero, zeros_like_aval) -from jax._src.api_util import argnums_partial, flatten_fun_nokwargs +from jax._src.api_util import ( + argnums_partial, flatten_fun_nokwargs, resolve_kwargs) from jax._src.core import raise_to_shaped from jax._src.errors import UnexpectedTracerError from jax._src.interpreters import ad @@ -56,17 +56,6 @@ ### util -def _resolve_kwargs(fun, args, kwargs): - if isinstance(fun, partial): - # functools.partial should have an opaque signature. 
- fun = lambda *args, **kwargs: None - ba = inspect.signature(fun).bind(*args, **kwargs) - ba.apply_defaults() - if ba.kwargs: - raise TypeError("keyword arguments could not be resolved to positions") - else: - return ba.args - def _initial_style_jaxpr(fun, in_avals): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(fun, in_avals) return jaxpr, consts @@ -78,7 +67,7 @@ def _sum_tangents(_, x, *xs): return reduce(ad.add_tangents, xs, x) def _zeros_like_pytree(x): - return tree_map(Zero.from_value, x) + return tree_map(Zero.from_primal_value, x) _stop_gradient = partial( tree_map, @@ -163,10 +152,11 @@ def defjvp(self, ``nondiff_argnums``, the ``jvp`` function should accept two arguments, where the first is a tuple of primal inputs and the second is a tuple of tangent inputs. The lengths of both tuples are equal to the number of - parameters of the ``custom_jvp`` function. The ``jvp`` function should - produce as output a pair where the first element is the primal output - and the second element is the tangent output. Elements of the input and - output tuples may be arrays or any nested tuples/lists/dicts thereof. + parameters of the :class:`~jax.custom_jvp` function. The ``jvp`` function + should produce as output a pair where the first element is the primal + output and the second element is the tangent output. Elements of the + input and output tuples may be arrays or any nested tuples/lists/dicts + thereof. symbolic_zeros: boolean, indicating whether the rule should be passed objects representing static symbolic zeros in its tangent argument in correspondence with unperturbed values; otherwise, only standard JAX @@ -177,48 +167,60 @@ def defjvp(self, ``False``. Returns: - None. + Returns ``jvp`` so that ``defjvp`` can be used as a decorator. Examples: - @jax.custom_jvp - def f(x, y): - return jnp.sin(x) * y - - @f.defjvp - def f_jvp(primals, tangents): - x, y = primals - x_dot, y_dot = tangents - primal_out = f(x, y) - tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot - return primal_out, tangent_out + >>> @jax.custom_jvp + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> @f.defjvp + ... def f_jvp(primals, tangents): + ... x, y = primals + ... x_dot, y_dot = tangents + ... primal_out = f(x, y) + ... tangent_out = jnp.cos(x) * x_dot * y + jnp.sin(x) * y_dot + ... return primal_out, tangent_out + + >>> x = jnp.float32(1.0) + >>> y = jnp.float32(2.0) + >>> with jnp.printoptions(precision=2): + ... print(jax.value_and_grad(f)(x, y)) (Array(1.68, dtype=float32), Array(1.08, dtype=float32)) """ self.jvp = jvp self.symbolic_zeros = symbolic_zeros return jvp - def defjvps(self, *jvps: Callable[..., ReturnValue] | None): + def defjvps(self, *jvps: Callable[..., ReturnValue] | None) -> None: """Convenience wrapper for defining JVPs for each argument separately. This convenience wrapper cannot be used together with ``nondiff_argnums``. Args: *jvps: a sequence of functions, one for each positional argument of the - ``custom_jvp`` function. Each function takes as arguments the tangent - value for the corresponding primal input, the primal output, and the - primal inputs. See the example below. + :class:`~jax.custom_jvp` function. Each function takes as arguments + the tangent value for the corresponding primal input, the primal + output, and the primal inputs. See the example below. Returns: None.
Examples: - @jax.custom_jvp - def f(x, y): - return jnp.sin(x) * y - - f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, - lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) + >>> @jax.custom_jvp + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> f.defjvps(lambda x_dot, primal_out, x, y: jnp.cos(x) * x_dot * y, + ... lambda y_dot, primal_out, x, y: jnp.sin(x) * y_dot) + + >>> x = jnp.float32(1.0) + >>> y = jnp.float32(2.0) + >>> with jnp.printoptions(precision=2): + ... print(jax.value_and_grad(f)(x, y)) + (Array(1.68, dtype=float32), Array(1.08, dtype=float32)) """ if self.nondiff_argnums: raise TypeError("Can't use ``defjvps`` with ``nondiff_argnums``.") @@ -240,7 +242,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No JVP defined for custom_jvp function {primal_name} using defjvp." raise AttributeError(msg) jvp_name = getattr(self.jvp, '__name__', str(self.jvp)) - args = _resolve_kwargs(self.fun, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) if self.nondiff_argnums: nondiff_argnums = set(self.nondiff_argnums) args = tuple(_stop_gradient(x) if i in nondiff_argnums else x @@ -325,24 +327,27 @@ def _flatten_jvp(primal_name, jvp_name, in_tree, maybe_out_type, *args): "shapes/dtypes of:\n" f""" {str(ty_tree_).replace("'", "")}""") raise TypeError(m) - # TODO(mattjj): compare primals' tangent types to tangent objects' types - primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) - for x in primals_out] + primal_avals_out = [raise_to_shaped(core.get_aval(x), weak_type=False) for x in primals_out] + expected_tangent_avals_out = [ + raise_to_shaped(core.get_aval(x), weak_type=False).to_tangent_aval() + for x in primals_out] tangent_avals_out = [raise_to_shaped(core.get_aval(t), weak_type=False) if type(t) is not SymbolicZero else t.aval.strip_weak_type() for t in tangents_out] - if primal_avals_out != tangent_avals_out: - if len(primal_avals_out) == 1: - (av1,), (av2,) = primal_avals_out, tangent_avals_out + if expected_tangent_avals_out != tangent_avals_out: + if len(expected_tangent_avals_out) == 1: + (av_p,), (av_et,), (av_t,) = primal_avals_out, expected_tangent_avals_out, tangent_avals_out msg = ("Custom JVP rule must produce primal and tangent outputs with " - "equal shapes and dtypes, but got {} and {} respectively.") - raise TypeError(msg.format(av1.str_short(), av2.str_short())) + "corresponding shapes and dtypes. 
Expected {} (tangent type of {}) but got {}.") + raise TypeError(msg.format(av_et.str_short(), av_p.str_short(), av_t.str_short())) else: msg = ("Custom JVP rule must produce primal and tangent outputs with " - "equal shapes and dtypes, but got:\n{}") + "corresponding shapes and dtypes, but got:\n{}") disagreements = ( - f" primal {av1.str_short()} for tangent {av2.str_short()}" - for av1, av2 in zip(primal_avals_out, tangent_avals_out) if av1 != av2) + f" primal {av_p.str_short()} with tangent {av_t.str_short()}, expecting tangent {av_et}" + for av_p, av_et, av_t in zip(primal_avals_out, expected_tangent_avals_out, tangent_avals_out) + if av_et != av_t) + raise TypeError(msg.format('\n'.join(disagreements))) yield primals_out + tangents_out, (out_tree, primal_avals) @@ -390,7 +395,7 @@ def jvp(*xs): out = core.eval_jaxpr(jvp_jaxpr, jvp_consts, *primals, *nonzero_tangents) out_primals, nz_out_tangents = split_list(out, [len(out_zeros)]) nz_out_tangents_ = iter(nz_out_tangents) - out_tangents = [SymbolicZero(core.get_aval(p).at_least_vspace()) + out_tangents = [SymbolicZero(core.get_aval(p).to_tangent_aval()) if z else next(nz_out_tangents_) for p, z in zip(out_primals, out_zeros)] assert next(nz_out_tangents_, None) is None @@ -571,18 +576,24 @@ def defvjp(self, Examples: - @jax.custom_vjp - def f(x, y): - return jnp.sin(x) * y - - def f_fwd(x, y): - return f(x, y), (jnp.cos(x), jnp.sin(x), y) - - def f_bwd(res, g): - cos_x, sin_x, y = res - return (cos_x * g * y, sin_x * g) - - f.defvjp(f_fwd, f_bwd) + >>> @jax.custom_vjp + ... def f(x, y): + ... return jnp.sin(x) * y + ... + >>> def f_fwd(x, y): + ... return f(x, y), (jnp.cos(x), jnp.sin(x), y) + ... + >>> def f_bwd(res, g): + ... cos_x, sin_x, y = res + ... return (cos_x * g * y, sin_x * g) + ... + >>> f.defvjp(f_fwd, f_bwd) + + >>> x = jnp.float32(1.0) + >>> y = jnp.float32(2.0) + >>> with jnp.printoptions(precision=2): + ... print(jax.value_and_grad(f)(x, y)) + (Array(1.68, dtype=float32), Array(1.08, dtype=float32)) """ self.fwd = fwd self.bwd = bwd @@ -599,7 +610,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable msg = f"No VJP defined for custom_vjp function {primal_name} using defvjp." 
raise AttributeError(msg) fwd_name = getattr(self.fwd, '__name__', str(self.fwd)) - args = _resolve_kwargs(self.fun, args, kwargs) + args = resolve_kwargs(self.fun, args, kwargs) if self.optimize_remat: fwd = optimize_remat_of_custom_vjp_fwd( self.fun, self.fwd, nondiff_argnums=self.nondiff_argnums, @@ -772,10 +783,10 @@ def append(x, d): raise TypeError(msg.format(in_tree2, in_tree)) from None results = [] for kp, a, ct in zip(keypaths, in_avals, cts_in_flat): - if ct is zero or a != a.at_least_vspace(): - results.append(Zero(a.at_least_vspace())) + if ct is zero or getattr(a.to_tangent_aval(), 'dtype') == dtypes.float0: + results.append(Zero(a.to_tangent_aval())) elif type(ct) is SymbolicZero: - if not core.typecompat(a.at_least_vspace(), a_ := ct.aval): + if not core.typecompat(a.to_tangent_aval(), a_ := ct.aval): msg = ("Custom VJP bwd rule produced a SymbolicZero with a shape/dtype " "that does not match the corresponding input tangent shape/dtype: " f"at output{keystr(kp)} the SymbolicZero had shape/dtype " @@ -786,7 +797,7 @@ def append(x, d): raise ValueError(msg) results.append(Zero(ct.aval)) else: - if (not core.typecompat(a.at_least_vspace(), a_ := core.get_aval(ct)) + if (not core.typecompat(a.to_tangent_aval(), a_ := core.get_aval(ct)) and not (_temporary_dtype_exception(a, a_) or _temporary_shape_exception(a, a_))): msg = ("Custom VJP bwd rule must produce an output with the same " @@ -900,16 +911,12 @@ def _custom_vjp_call_jaxpr_jvp( _, res_tree = out_trees() res_and_primals_out = core.eval_jaxpr(fwd_jaxpr, fwd_consts, *args) res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] args_dot = map(ad.instantiate_zeros, args_dot) - # Cast float0 to zeros with the primal dtype because custom vjp rules don't - # currently handle float0s - args_dot = map(ad.replace_float0s, args, args_dot) tangents_out = ad.custom_lin_p.bind( *res, *args_dot, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) - tangents_out = map(ad.recast_to_float0, primals_out, tangents_out) return primals_out, tangents_out ad.primitive_jvps[custom_vjp_call_jaxpr_p] = _custom_vjp_call_jaxpr_jvp @@ -1031,7 +1038,7 @@ def fwd(*args, **kwargs): ans, rule = fun(*args, **kwargs) ans_flat, out_tree = tree_flatten((ans,)) rule, in_tree = flatten_fun_nokwargs(lu.wrap_init(rule), out_tree) - ans_avals = [core.get_aval(x).at_least_vspace() for x in ans_flat] + ans_avals = [core.get_aval(x).to_tangent_aval() for x in ans_flat] jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(rule, ans_avals) return ans, Residuals(jaxpr, in_tree(), out_tree, consts) @@ -1137,7 +1144,7 @@ def rev(objective_fn, res, g): def _maybe_perturbed(x: Any) -> bool: # False if x can't represent an AD-perturbed value (i.e. a value # with a nontrivial tangent attached), up to heuristics, and True otherwise. - # See https://github.com/google/jax/issues/6415 for motivation. + # See https://github.com/jax-ml/jax/issues/6415 for motivation. x = core.full_lower(x) if not isinstance(x, core.Tracer): # If x is not a Tracer, it can't be perturbed. @@ -1145,7 +1152,7 @@ def _maybe_perturbed(x: Any) -> bool: elif isinstance(x, pe.DynamicJaxprTracer): # If x is a DynamicJaxprTracer then we're staging out; differentiation could # happen later, but some types always have trivial tangents. 
- vspace = x.aval.at_least_vspace() + vspace = x.aval.to_tangent_aval() return not (vspace is core.abstract_token or getattr(vspace, 'dtype', None) == dtypes.float0) elif not isinstance(x, ad.JVPTracer): @@ -1168,7 +1175,12 @@ def converted_fun(*args_hconsts): args, hoisted_consts = split_list(args_hconsts, [num_args]) consts = merge(closure_consts, hoisted_consts) all_args, in_tree2 = tree_flatten(tuple(args)) - assert in_tree == in_tree2 + if in_tree != in_tree2: + msg = ("The inputs to the closure produced by closure_convert must have " + "the same Pytree structure as the example arguments passed when " + f"closure_convert was called. Expected {in_tree}, but got " + f"{in_tree2}") + raise TypeError(msg) out_flat = core.eval_jaxpr(jaxpr, consts, *all_args) return tree_unflatten(out_tree, out_flat) @@ -1412,7 +1424,7 @@ def custom_vjp_by_custom_transpose(fun, fwd, bwd): @fun.defjvp def jvp(primals, tangents): outs, residuals = fwd(*primals) - tan_out_types = tree_map(lambda o: core.get_aval(o).at_least_vspace(), outs) + tan_out_types = tree_map(lambda o: core.get_aval(o).to_tangent_aval(), outs) tan_fn = custom_transpose(partial(disallow_jvp, out_avals=tan_out_types)) tan_fn.def_transpose(bwd) return outs, tan_fn(tan_out_types, residuals, tangents) @@ -1451,7 +1463,9 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]: # above and it would be good to consolidate it. primal_name = getattr(fun, "__name__", str(fun)) fwd_name = getattr(fwd, "__name__", str(fwd)) - args = _resolve_kwargs(fwd, args, kwargs) + # Note: we use `fun` instead of `fwd` here for consistency with + # custom_vjp.__call__ above. + args = resolve_kwargs(fun, args, kwargs) if nondiff_argnums: for i in nondiff_argnums: _check_for_tracers(args[i]) nondiff_argnums_ = set(nondiff_argnums) @@ -1532,6 +1546,9 @@ def _remat_opt_vmap( batched_fwd_jaxpr, out_batched = batching.batch_jaxpr( fwd_jaxpr, axis_size, in_batched, False, axis_name, spmd_axis_name, main_type) + extra_consts = batched_fwd_jaxpr.consts + batched_fwd_jaxpr = pe.close_jaxpr( + pe.convert_constvars_jaxpr(batched_fwd_jaxpr.jaxpr)) out_dims = [0 if b else not_mapped for b in out_batched] _, prim_batched = split_list(in_batched, [num_consts]) @@ -1544,7 +1561,8 @@ def batched_fun_jaxpr_thunk(): main_type) return batched_fun_jaxpr.jaxpr, batched_fun_jaxpr.consts - batched_outs = remat_opt_p.bind(*args, num_consts=num_consts, + batched_outs = remat_opt_p.bind(*extra_consts, *args, + num_consts=num_consts + len(extra_consts), num_res=num_res, fwd_jaxpr=batched_fwd_jaxpr, fun_jaxpr_thunk=batched_fun_jaxpr_thunk) @@ -1612,6 +1630,7 @@ def _remat_opt_dce(used_outs: list[bool], eqn: core.JaxprEqn): instantiate += [True] * (len(eqn.invars) - eqn.params["num_consts"]) new_jaxpr, used_ins = pe.dce_jaxpr(eqn.params["fwd_jaxpr"].jaxpr, used_outs, instantiate=instantiate) + assert not new_jaxpr.constvars closed_jaxpr = pe.close_jaxpr(new_jaxpr) invars = [v for used, v in zip(used_ins, eqn.invars) if used] new_params = dict(eqn.params) diff --git a/jax/_src/custom_partitioning.py b/jax/_src/custom_partitioning.py index 8f48746dda37..984d55fe2f6b 100644 --- a/jax/_src/custom_partitioning.py +++ b/jax/_src/custom_partitioning.py @@ -492,7 +492,7 @@ def _custom_partitioning_lowering_rule(ctx: mlir.LoweringRuleContext, *values, devices = axis_context.device_assignment if devices is None: raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') + 'Please file a bug at https://github.com/jax-ml/jax/issues') if axis_context.mesh_shape is not 
None: ma, ms = list(zip(*axis_context.mesh_shape)) mesh = mesh_lib.Mesh(np.array(devices).reshape(ms), ma) diff --git a/jax/_src/debugger/core.py b/jax/_src/debugger/core.py index f6b0a81baf92..1efeed73cbc8 100644 --- a/jax/_src/debugger/core.py +++ b/jax/_src/debugger/core.py @@ -112,6 +112,11 @@ def from_frameinfo(cls, frame_info) -> DebuggerFrame: # then we subtract it off from the `lineno` and don't need to subtract 1 # since both start and lineno are 1-indexed. offset = frame_info.lineno - max(start, 1) + if offset >= len(source): + # Sometimes we don't get a valid source/offset pair. This seems to + # happen sometimes when code uses eval(). If that happens, give up. + source = [] + offset = None except OSError: source = [] offset = None diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 62bfdf031c7c..465dc90e21da 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -46,6 +46,7 @@ from jax._src.lib.mlir.dialects import hlo from jax._src.sharding import Sharding from jax._src.sharding_impls import NamedSharding, parse_flatten_op_sharding +from jax._src.api_util import shaped_abstractify from jax._src.state import discharge as state_discharge logger = logging.getLogger(__name__) @@ -256,15 +257,36 @@ def debug_callback(callback: Callable[..., None], *args: Any, raise TypeError("first argument to jax.debug.callback must be callable, " f"but got an object of type {type(callback)}") flat_args, in_tree = tree_util.tree_flatten((args, kwargs)) - effect = ordered_debug_effect if ordered else debug_effect - def _flat_callback(*flat_args): - args, kwargs = tree_util.tree_unflatten(in_tree, flat_args) + static_args, dyn_args = {}, [] + for i, a in enumerate(flat_args): + try: + shaped_abstractify(a) + dyn_args.append(a) + except (AssertionError, TypeError): + static_args[i] = a + + def _flat_callback(*dyn_args): + all_args = [None] * (len(static_args) + len(dyn_args)) + di = iter(dyn_args) + for i in range(len(all_args)): + if i in static_args: + all_args[i] = static_args[i] + else: + all_args[i] = next(di) + assert next(di, None) is None + args, kwargs = tree_util.tree_unflatten(in_tree, all_args) callback(*args, **kwargs) return () - debug_callback_p.bind(*flat_args, callback=_flat_callback, effect=effect) + + effect = ordered_debug_effect if ordered else debug_effect + debug_callback_p.bind(*dyn_args, callback=_flat_callback, effect=effect) class _DebugPrintFormatChecker(string.Formatter): + def format_field(self, value, format_spec): + del value, format_spec + return "" # No formatting is done. + def check_unused_args(self, used_args, args, kwargs): unused_args = [arg for i, arg in enumerate(args) if i not in used_args] unused_kwargs = [k for k in kwargs if k not in used_args] @@ -314,7 +336,7 @@ def debug_print(fmt: str, *args, **kwargs): **kwargs: Additional keyword arguments to be formatted, as if passed to ``fmt.format``. """ - # Check that we provide the correct arguments to be formatted + # Check that we provide the correct arguments to be formatted. 
formatter.format(fmt, *args, **kwargs) debug_callback(functools.partial(_format_print_callback, fmt), *args, @@ -367,7 +389,7 @@ def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *, devices = axis_context.device_assignment if devices is None: raise AssertionError( - 'Please file a bug at https://github.com/google/jax/issues') + 'Please file a bug at https://github.com/jax-ml/jax/issues') elif isinstance(axis_context, sharding_impls.SPMDAxisContext): devices = axis_context.mesh._flat_devices_tuple else: diff --git a/jax/_src/deprecations.py b/jax/_src/deprecations.py index fb346ca9b372..5f1d132bcbb3 100644 --- a/jax/_src/deprecations.py +++ b/jax/_src/deprecations.py @@ -117,3 +117,19 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None: else: warnings.warn(message, category=DeprecationWarning, stacklevel=stacklevel + 1) + + +# Register a number of deprecations: we do this here to ensure they're +# always registered by the time `accelerate` and `is_acelerated` are called. +register('jax-aval-named-shape') +register('jax-dlpack-import-legacy') +register("jax-numpy-astype-complex-to-real") +register("jax-numpy-array-none") +register('jax-scipy-beta-args') +register('tracer-hash') +register('jax-numpy-reshape-newshape') +register('jax-numpy-clip-args') +register('jax-numpy-linalg-matrix_rank-tol') +register('jax-numpy-linalg-pinv-rcond') +register('jax-numpy-quantile-interpolation') +register('jax-numpy-trimzeros-not-1d-array') diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e7fd8657ccdb..59739f4130f3 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -44,6 +44,7 @@ from jax._src.interpreters import xla from jax._src.interpreters import pxla from jax._src import lib +from jax._src.mesh import AbstractMesh from jax._src.lib import xla_client as xc from jax._src.monitoring import record_event_duration_secs from jax._src.partition_spec import PartitionSpec @@ -133,7 +134,7 @@ def get_token_input( # We only use replicated sharding for the first time when the token for the # order effect hasn't been created. s = jax.sharding.GSPMDSharding.get_replicated(devices) - sharded_tok = core.Token(pxla.shard_args([s], [tok])[0]) + sharded_tok = core.Token(pxla.shard_args([s], [None], [tok])[0]) self.current_tokens[eff] = sharded_tok return sharded_tok @@ -204,7 +205,7 @@ def jaxpr_has_primitive(jaxpr: core.Jaxpr, prim_name: str) -> bool: # stablehlo is oblivious of physical devices. 
prim_requires_devices_during_lowering: set[core.Primitive] = set() -def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr): +def jaxpr_has_prim_requiring_devices(jaxpr: core.Jaxpr) -> bool: for eqn in jaxpr.eqns: if eqn.primitive in prim_requires_devices_during_lowering: return True @@ -227,8 +228,11 @@ def get_intermediate_shardings( for eqn in jaxpr.eqns: if eqn.primitive is pjit.sharding_constraint_p: + s = eqn.params['sharding'] + if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + continue source_info = SourceInfo(eqn.source_info, eqn.primitive.name) - yield (eqn.params['sharding'], source_info) + yield (s, source_info) elif eqn.primitive is pjit.pjit_p: source_info = SourceInfo(eqn.source_info, eqn.primitive.name) yield from ((i, source_info) for i in eqn.params['in_shardings']) @@ -414,7 +418,7 @@ def _device_put_sharding_impl(x, aval, device): return _different_device_order_reshard(x, s) if (s.is_fully_addressable and isinstance(x, array.ArrayImpl) and - x.is_fully_addressable and len(s.device_set) > 1 and + x.is_fully_addressable and s.num_devices > 1 and s._internal_device_list != x.sharding._internal_device_list and # pytype: disable=attribute-error s.device_set == x.sharding.device_set): assert isinstance(s, Sharding) @@ -511,7 +515,10 @@ def _batched_device_put_impl( if shard_arg_xs: # Batch shard_arg calls. Helps improve efficiency for backends that support # efficient batch transfer. - shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs) + # device_put handles `Layout` via a different path, so just pass `None` as + # the layout here. + shard_arg_results = pxla.shard_args( + shard_arg_shardings, [None] * len(shard_arg_xs), shard_arg_xs) for i, shard_arg_result in zip(shard_arg_indices, shard_arg_results): assert isinstance(ys[i], _DeferredShardArg) ys[i] = ys[i].result_handler(shard_arg_result) @@ -548,6 +555,10 @@ def _device_put_batcher(batched_args, batch_dims, **params): batching.primitive_batchers[device_put_p] = _device_put_batcher def _tpu_gpu_device_put_lowering(ctx, *xs, devices, srcs): + # TODO(yashkatariya): Maybe we should add the custom calls anyways if it's + # being used inside jit? Atleast for now, this preserves the old behavior. 
+ if ctx.module_context.all_default_mem_kind: + return xs def lower(x, device, src, aval, out_aval): if (isinstance(device, (Sharding, TransferToMemoryKind)) and device.memory_kind is not None): @@ -558,6 +569,7 @@ def lower(x, device, src, aval, out_aval): return x return x return list(map(lower, xs, devices, srcs, ctx.avals_in, ctx.avals_out)) + mlir.register_lowering( device_put_p, _tpu_gpu_device_put_lowering, platform='tpu') mlir.register_lowering( @@ -565,12 +577,6 @@ def lower(x, device, src, aval, out_aval): def _common_device_put_lowering(ctx, *xs, devices, srcs): - for device in devices: - if (isinstance(device, (Sharding, TransferToMemoryKind)) and - device.memory_kind is not None): - raise NotImplementedError( - "Passing memory_kind to device_put via Shardings is not supported on" - f" platforms {ctx.module_context.platforms}") return xs mlir.register_lowering(device_put_p, _common_device_put_lowering) diff --git a/jax/_src/distributed.py b/jax/_src/distributed.py index 5e8e956cf98b..3ea9304b67aa 100644 --- a/jax/_src/distributed.py +++ b/jax/_src/distributed.py @@ -14,7 +14,6 @@ from __future__ import annotations -import atexit from collections.abc import Sequence import logging import os @@ -45,10 +44,12 @@ def initialize(self, initialization_timeout: int = 300, coordinator_bind_address: str | None = None): coordinator_address = (coordinator_address or - os.environ.get('JAX_COORDINATOR_ADDRESS', None)) + os.environ.get('JAX_COORDINATOR_ADDRESS')) if isinstance(local_device_ids, int): local_device_ids = [local_device_ids] + if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')): + local_device_ids = list(map(int, env_ids.split(","))) (coordinator_address, num_processes, process_id, local_device_ids) = ( clusters.ClusterEnv.auto_detect_unset_distributed_params( @@ -230,11 +231,11 @@ def initialize(coordinator_address: str | None = None, global_state.initialize(coordinator_address, num_processes, process_id, local_device_ids, cluster_detection_method, initialization_timeout, coordinator_bind_address) - atexit.register(shutdown) def shutdown(): """Shuts down the distributed system. - Does nothing if the distributed system is not running.""" + Does nothing if the distributed system is not running. + """ global_state.shutdown() diff --git a/jax/_src/dlpack.py b/jax/_src/dlpack.py index 386123ae61f0..ac976234eda5 100644 --- a/jax/_src/dlpack.py +++ b/jax/_src/dlpack.py @@ -16,14 +16,17 @@ from typing import Any -from jax._src.api import device_put from jax import numpy as jnp from jax._src import array +from jax._src import deprecations from jax._src import xla_bridge +from jax._src.api import device_put from jax._src.lax.lax import _array_copy from jax._src.lib import xla_client -from jax._src.typing import Array, DLDeviceType from jax._src.sharding import Sharding +from jax._src.typing import Array +from jax._src.typing import DLDeviceType + DLPACK_VERSION = (0, 8) MIN_DLPACK_VERSION = (0, 5) @@ -237,21 +240,19 @@ def from_dlpack(external_array, device transfer or copy was requested. Args: - external_array: An array object that has __dlpack__ and __dlpack_device__ - methods, or a DLPack tensor on either CPU or GPU (legacy API). - + external_array: An array object that has ``__dlpack__`` and + ``__dlpack_device__`` methods. device: The (optional) :py:class:`Device`, representing the device on which - the returned array should be placed. If given, then the result is committed - to the device.
If unspecified, the resulting array will be unpacked onto the - same device it originated from. Setting ``device`` to a device different from - the source of ``external_array`` will require a copy, meaning ``copy`` must be - set to either ``True`` or ``None``. - + the returned array should be placed. If given, then the result is + committed to the device. If unspecified, the resulting array will be + unpacked onto the same device it originated from. Setting ``device`` to a + device different from the source of ``external_array`` will require a + copy, meaning ``copy`` must be set to either ``True`` or ``None``. copy: An (optional) boolean, controlling whether or not a copy is performed. - If ``copy=True`` then a copy is always performed, even if unpacked onto the - same device. If ``copy=False`` then the copy is never performed and will raise - an error if necessary. When ``copy=None`` then a copy may be performed if - needed for a device transfer. + If ``copy=True`` then a copy is always performed, even if unpacked onto + the same device. If ``copy=False`` then the copy is never performed and + will raise an error if necessary. When ``copy=None`` then a copy may be + performed if needed for a device transfer. Returns: A jax.Array @@ -274,5 +275,15 @@ def from_dlpack(external_array, if hasattr(external_array, "__dlpack__"): return _from_dlpack(external_array, device, copy) - # Legacy path + # Deprecated legacy path. + # TODO(slebedev): Remove on or after December 3rd 2023. + deprecations.warn( + "jax-dlpack-import-legacy", + ( + "Calling from_dlpack with a DLPack tensor is deprecated. The argument" + " to from_dlpack should be an array from another framework that" + " implements the __dlpack__ protocol." + ), + stacklevel=2, + ) return _legacy_from_dlpack(external_array, device, copy) diff --git a/jax/_src/dtypes.py b/jax/_src/dtypes.py index 81f4180a1c12..82be38d1cb57 100644 --- a/jax/_src/dtypes.py +++ b/jax/_src/dtypes.py @@ -23,7 +23,9 @@ import abc import builtins +import dataclasses import functools +import types from typing import cast, overload, Any, Literal, Union import warnings @@ -784,7 +786,7 @@ def check_user_dtype_supported(dtype, fun_name=None): uint2, uint4, ] - if np_dtype.kind not in "biufc" and not is_custom_dtype: + if np_dtype.kind not in "biufc" and not is_custom_dtype and not dtype == float0: msg = f"JAX only supports number and bool dtypes, got dtype {dtype}" msg += f" in {fun_name}" if fun_name else "" raise TypeError(msg) @@ -793,7 +795,7 @@ def check_user_dtype_supported(dtype, fun_name=None): "and will be truncated to dtype {}. To enable more dtypes, set the " "jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell " "environment variable. " - "See https://github.com/google/jax#current-gotchas for more.") + "See https://github.com/jax-ml/jax#current-gotchas for more.") fun_name = f"requested in {fun_name}" if fun_name else "" truncated_dtype = canonicalize_dtype(np_dtype).name warnings.warn(msg.format(dtype, fun_name, truncated_dtype), stacklevel=3) @@ -834,3 +836,32 @@ def safe_to_cast(input_dtype_or_value: Any, # We deliberately use output_dtype rather than output_dtype_or_value here: # this effectively treats the output dtype as always strongly-typed. 
return result_type(input_dtype_or_value, output_dtype) == output_dtype + +def primal_tangent_dtype(primal_dtype, tangent_dtype, + name: str | None = None) -> ExtendedDType: + primal_dtype, tangent_dtype = map(dtype, (primal_dtype, tangent_dtype)) + name_ = name or (f'PrimalTangentDType{{{short_dtype_name(primal_dtype)}' + f'/{short_dtype_name(tangent_dtype)}}}') + rules = types.SimpleNamespace( + physical_element_aval= + lambda dtype: types.SimpleNamespace(shape=(), dtype=primal_dtype), + tangent_dtype=lambda dtype: tangent_dtype, + allow_conversion=True) + + class primal_tangent_dtype_scalar(extended): ... + + @dataclasses.dataclass(frozen=True) + class PrimalTangentDType(ExtendedDType): + name = name_ + _rules = rules + type = primal_tangent_dtype_scalar + __repr__ = lambda _: name_ + + return PrimalTangentDType() + +def short_dtype_name(dtype) -> str: + if isinstance(dtype, ExtendedDType): + return str(dtype) + else: + return (dtype.name.replace('float', 'f').replace('uint' , 'u') + .replace('int' , 'i').replace('complex', 'c')) diff --git a/jax/_src/earray.py b/jax/_src/earray.py index 36c8dc80c8ca..6598df01330a 100644 --- a/jax/_src/earray.py +++ b/jax/_src/earray.py @@ -104,11 +104,12 @@ def global_shards(self): # TODO(mattjj): _set_array_base_attributes -def _earray_shard_arg_handler(xs, shardings): +def _earray_shard_arg_handler(xs, shardings, layouts): arrs = [x._data for x in xs] phys_shardings = [sharding_impls.physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] - return pxla.shard_args(phys_shardings, arrs) + # TODO(yashkatariya): `layouts` should be converted to physical layouts. + return pxla.shard_args(phys_shardings, layouts, arrs) pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler api_util._shaped_abstractify_handlers[EArray] = lambda self: self.aval diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 4ee5dca86455..7f7773acbd39 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -58,7 +58,7 @@ zip = util.safe_zip DType = Any -Shape = jax._src.core.Shape +Shape = core.Shape # The values of input and output sharding from the lowering. LoweringSharding = Union[sharding.Sharding, pxla.UnspecifiedValue] HloSharding = xla_client.HloSharding @@ -609,6 +609,10 @@ def _export_lowered( f"disabled_checks={disabled_checks}") logging.info("Exported JAX function: %s\n", logmsg) logging.info(mlir.dump_module_message(mlir_module, "export")) + logging.info( + "Size of mlir_module_serialized: %d byte", + len(mlir_module_serialized), + ) _check_module(mlir_module, disabled_checks=disabled_checks) @@ -677,8 +681,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: def _module_to_bytecode(module: ir.Module) -> bytes: mlir_str = mlir.module_to_bytecode(module) # `target_version` is used to manage situations when a StableHLO producer - # (in this case, jax2tf) and a StableHLO consumer were built using - # different versions of StableHLO. + # and a StableHLO consumer were built using different versions of StableHLO. # # Each StableHLO version `producer_version` has a compatibility window, # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], @@ -687,12 +690,19 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md # for the exact extent of these compatibility guarantees. # - # `hlo.get_minimum_version()` returns `consumer_version_min` - # for the current version of StableHLO. 
We are using it here to maximize - # forward compatibility, i.e. to maximize how far into the past we can go - # and still have the payloads produced by `serialize_portable_artifact` - # compatible with potential consumers from the past. - target_version = hlo.get_minimum_version() + # `hlo.get_version_from_compatibility_requirement(WEEK_4)` returns a version + # of StableHLO >= 4w old. This allows new StableHLO features to be used after + # ~4w and be compatible with any consumer that is updated on at least a + # monthly cadence. + # + # Note that this does not verify any JAX custom calls, which are only + # guaranteed 3w of forward compatibility, and only prevents use of new + # StableHLO features from failing on older hardware. + if hlo.get_api_version() < 9: + target_version = hlo.get_minimum_version() + else: + target_version = hlo.get_version_from_compatibility_requirement( + hlo.StablehloCompatibilityRequirement.WEEK_4) module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore mlir_str, target_version) return module_serialized @@ -920,6 +930,11 @@ def _check_lowering(lowering) -> None: _CPU_FFI_KERNELS = [ "lapack_spotrf_ffi", "lapack_dpotrf_ffi", "lapack_cpotrf_ffi", "lapack_zpotrf_ffi", + "lapack_sgeqrf_ffi", "lapack_dgeqrf_ffi", "lapack_cgeqrf_ffi", "lapack_zgeqrf_ffi", + "lapack_sorgqr_ffi", "lapack_dorgqr_ffi", "lapack_cungqr_ffi", "lapack_zungqr_ffi", + "lapack_ssyevd_ffi", "lapack_dsyevd_ffi", "lapack_cheevd_ffi", "lapack_zheevd_ffi", + "lapack_sgeev_ffi", "lapack_dgeev_ffi", "lapack_cgeev_ffi", "lapack_zgeev_ffi", + "lapack_sgesdd_ffi", "lapack_dgesdd_ffi", "lapack_cgesdd_ffi", "lapack_zgesdd_ffi", "lapack_sgetrf_ffi", "lapack_dgetrf_ffi", "lapack_cgetrf_ffi", "lapack_zgetrf_ffi", ] # These are the JAX custom call target names that are guaranteed to be stable. @@ -960,12 +975,11 @@ def _check_lowering(lowering) -> None: "lapack_sgetrf", "lapack_dgetrf", "lapack_cgetrf", "lapack_zgetrf", # schur on CPU "lapack_sgees", "lapack_dgees", "lapack_cgees", "lapack_zgees", - # # lu on GPU + # lu on GPU + "cu_lu_pivots_to_permutation", # "cublas_getrf_batched", "cusolver_getrf", # "hipblas_getrf_batched", "hipsolver_getrf", - # TODO(b/357034884): This can be added once the mimimum version of jaxlib - # (v0.4.32) includes this new FFI call. - # "cusolver_getrf_ffi", + "cusolver_getrf_ffi", # lu on TPU "LuDecomposition", # ApproxTopK on TPU @@ -1113,7 +1127,7 @@ def flattened_primal_fun_jax(*args_flat): vjp_in_avals = list( itertools.chain(in_avals, - map(lambda a: a.at_least_vspace(), out_avals))) + map(lambda a: a.to_tangent_aval(), out_avals))) if apply_jit: assert device_assignment is not None diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index 0173df4fd345..77786cbf1a9d 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -80,6 +80,8 @@ def __init__(self, message: str): # https://github.com/python/mypy/issues/5887 super().__init__(error_msg) +class UnexpectedDimVar(Exception): + pass class Comparator(Enum): EQ = 1 @@ -87,12 +89,14 @@ class Comparator(Enum): @dataclasses.dataclass(frozen=True) class _SymbolicConstraint: + # Either e1 == e2 if cmp == Comparator.EQ else e1 >= e2 cmp: Comparator debug_str: str # The form in which the user expressed it, for error messages - diff: _DimExpr # For GEQ: diff >= 0, and for EQ: diff == 0 + e1: DimSize # This has been normalized w.r.t. previous constraints only + e2: DimSize # This has been normalized w.r.t. 
previous constraints only def __repr__(self): - return f"Constraint({self.debug_str}: {self.diff})" + return f"Constraint({self.debug_str})" class _DimFactor: @@ -209,15 +213,22 @@ def __ge__(self, other: _DimFactor): """Lexicographic comparison""" return self._syntactic_cmp(other) >= 0 - def evaluate(self, env: DimVarEnv): + def evaluate(self, env: DimVarEnv, scope: SymbolicScope): if self.var is not None: try: return env[self.var] except KeyError: + # Perhaps there is a normalization rule for this variable + normalized_var = _DimExpr._from_var(self.var, scope) + if core.is_constant_dim(normalized_var): + return normalized_var + non_trivial_normalization = (v1 := normalized_var._to_var()) is None or v1 != self.var # type: ignore + if non_trivial_normalization: + return normalized_var._evaluate(env) # type: ignore err_msg = ( f"Encountered dimension variable '{self.var}' that is not appearing in the shapes of the function arguments.\n" "Please see https://jax.readthedocs.io/en/latest/export/shape_poly.html#dimension-variables-must-be-solvable-from-the-input-shapes for more details.") - raise KeyError(err_msg) + raise UnexpectedDimVar(err_msg) else: operand_values = [opnd._evaluate(env) for opnd in self.operands] if self.operation == _DimFactor.FLOORDIV: @@ -370,11 +381,11 @@ def divide(self, divisor: _DimTerm) -> _DimTerm: raise InconclusiveDimensionOperation(f"Cannot divide {self} by {divisor}.") return _DimTerm(new_factors) - def evaluate(self, env: DimVarEnv): + def evaluate(self, env: DimVarEnv, scope: SymbolicScope): prod = lambda xs: functools.reduce(_evaluate_multiply, xs) if xs else core.dim_constant(1) def pow_opt(v, p: int): return v if p == 1 else prod([v] * p) - return prod([pow_opt(f.evaluate(env), exp) for f, exp in self._factors]) + return prod([pow_opt(f.evaluate(env, scope), exp) for f, exp in self._factors]) def __deepcopy__(self, memo): return _DimTerm(copy.deepcopy(self._factors, memo)) @@ -404,7 +415,7 @@ class _DimExpr: def __init__(self, sorted_terms: SortedTerms, scope: SymbolicScope): # Do not construct _DimExpr directly, unless you are sure that `terms` is - # normalized; Use _DimExpr.normalize. + # normalized; Use _DimExpr._normalize_sorted_terms. self._sorted_terms = tuple(sorted_terms) or ((_DimTerm_one, 0),) self._scope = scope self._hash = None @@ -426,8 +437,8 @@ def _from_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> DimSize: return _DimExpr._normalize_sorted_terms(((t, t_k),), scope) @staticmethod - def _from_var(v: str, scope: SymbolicScope) -> _DimExpr: - return _DimExpr(((_DimTerm.from_var(v), 1),), scope) + def _from_var(v: str, scope: SymbolicScope) -> DimSize: + return _DimExpr._normalize_sorted_terms(((_DimTerm.from_var(v), 1),), scope) @staticmethod def _from_operation(operation: str, *operands: DimSize, @@ -475,8 +486,9 @@ def _add_coeff(coeffs: dict[_DimTerm, int], t: _DimTerm, coeff: int): def _normalize_term(t: _DimTerm, t_k: int, scope: SymbolicScope) -> Sequence[tuple[_DimTerm, int]]: # If (t, t_k) is among the scope normalization rules, then return - # a list of updates to apply to the expression containing (t, t_k). - # Returns empty sequence if no normalizations are necessary. + # a list of `term * coefficient` to add to the expression containing (t, t_k). + # Returns the empty sequence if no normalizations are necessary. 
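The shape_poly.py changes above store both sides of an explicit symbolic constraint (e1, e2) rather than only their difference, so equality constraints are not collapsed to 0 in error reporting, and term evaluation now receives the scope so that normalization rules can be consulted. A minimal sketch of where such constraints originate, assuming the public jax.export.symbolic_shape entry point with a constraints argument as described in the JAX shape-polymorphism documentation:

    from jax import export

    # 'a >= 2*b' is recorded as a _SymbolicConstraint with e1 = a and e2 = 2*b;
    # error messages can then report 'a - 2*b' evaluated against concrete shapes.
    a, b = export.symbolic_shape('a, b', constraints=('a >= 2*b',))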
+ if not scope._normalization_rules: return [] updates = [] after, t_k_after = scope._normalization_rules.get(t, (None, 0)) if after is not None and t_k % t_k_after == 0: @@ -899,7 +911,7 @@ def _divmod(self, divisor: DimSize) -> tuple[DimSize, int]: def _evaluate(self, env: DimVarEnv): # Evaluates as a value of dtype=core.dim_value_dtype() - terms = [_evaluate_multiply(t.evaluate(env), core.dim_constant(t_k)) + terms = [_evaluate_multiply(t.evaluate(env, self.scope), core.dim_constant(t_k)) for t, t_k in self._sorted_terms] return functools.reduce(_evaluate_add, terms) if len(terms) > 1 else terms[0] @@ -1046,8 +1058,6 @@ def _parse_and_process_explicit_constraint(self, c_str: str): raise ValueError(f"Unsatisfiable explicit constraint: {c_str}") return - constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, diff=diff) # type: ignore[arg-type] - self._explicit_constraints.append(constr) if cmp == Comparator.EQ: if not isinstance(e1, _DimExpr): raise ValueError("Invalid equality constraint: {e1} == {e2}. " @@ -1063,6 +1073,9 @@ def _parse_and_process_explicit_constraint(self, c_str: str): f"Found multiple equality constraints with the same left-hand-side: {before}") self._normalization_rules[before] = (after, before_k) + constr = _SymbolicConstraint(debug_str=c_str, cmp=cmp, e1=e1, e2=e2) + self._explicit_constraints.append(constr) + def _check_same_scope(self, other: _DimExpr, when: str = "", self_descr: str = " ", @@ -2016,7 +2029,7 @@ def _solve_dim_equations( # Returns a shape environment and the shape constraints if it can solve all # dimension variables. Raises an exception if it cannot. shape_env: DimVarEnv = {} - solution_error_message_pieces: list[str | _DimExpr] = [ + solution_error_message_pieces: list[str | DimSize] = [ " Obtained dimension variables: " ] # Error message describing the solution # Prepare error message piece describing the polymorphic shape specs @@ -2050,8 +2063,8 @@ def process_one_eqn(eqn: _DimEquation) -> bool: for term, term_k in eqn.aval_dim_expr._sorted_terms: # Perhaps we can already evaluate this term (all vars solved) try: - term_value = term.evaluate(shape_env) - except KeyError: + term_value = term.evaluate(shape_env, scope) + except UnexpectedDimVar: # `mon` still uses some variables not yet solved. We handle only the # case when `mon` is a single variable. v = term.to_var() @@ -2118,14 +2131,19 @@ def add_explicit_symbolic_constraints(shape_env: DimVarEnv): if not shape_env: return assert scope is not None for constr in scope._explicit_constraints: - c_value = constr.diff._evaluate(shape_env) + # We can't just construct constr.e1 - constr.e2 because for an equality + # constraint it would be reduced to 0. + c_e1 = constr.e1._evaluate(shape_env) if not core.is_constant_dim(constr.e1) else constr.e1 # type: ignore + c_e2 = constr.e2._evaluate(shape_env) if not core.is_constant_dim(constr.e2) else constr.e2 # type: ignore + c_diff = c_e1 - c_e2 shape_constraints.add_constraint( - constr.cmp, c_value, 0, + constr.cmp, c_diff, 0, error_message_pieces=[ f"Input shapes do not match the symbolic shape constraint {constr.debug_str}. " - f"Expected '{constr.diff}' to be " + f"Expected '{constr.e1} - {constr.e2}' to be " f"{'greater or equal' if constr.cmp == Comparator.GEQ else 'equal'} to 0, " - "but found ", c_value, + "but found ", c_diff, + ". 
" + poly_specs_err_msg ] + solution_error_message_pieces + [ solution_err_msg_trailer_errors]) diff --git a/jax/_src/export/shape_poly_decision.py b/jax/_src/export/shape_poly_decision.py index e325722b0c26..4bad8b7be06d 100644 --- a/jax/_src/export/shape_poly_decision.py +++ b/jax/_src/export/shape_poly_decision.py @@ -23,6 +23,7 @@ import numpy as np +from jax._src import core from jax._src.export import shape_poly from jax._src.export.shape_poly import ( _DimExpr, _DimTerm, _DimFactor, @@ -84,7 +85,10 @@ def initialize(self) -> _DecisionByElimination: # the result (albeit, for now, without a good feedback loop to understand # how the order matters for inequalities). for constr in self.scope._explicit_constraints: - self.add_implicit_constraints_expr(constr.diff) + if not core.is_constant_dim(constr.e1): + self.add_implicit_constraints_expr(constr.e1) # type: ignore + if not core.is_constant_dim(constr.e2): + self.add_implicit_constraints_expr(constr.e2) # type: ignore # The equality constraints are not needed for inequality decisions, # because the LHS should always be rewritten in terms of the RHS. # In fact, adding them may break the assumption that if we eliminate @@ -92,7 +96,7 @@ def initialize(self) -> _DecisionByElimination: # may appear in the rest and may be rewritten to something larger. # However, we want to add the implicit constraints within. if constr.cmp == Comparator.GEQ: - self.combine_and_add_constraint(constr.cmp, constr.diff, 0, + self.combine_and_add_constraint(constr.cmp, constr.e1 - constr.e2, 0, constr.debug_str) diff --git a/jax/_src/extend/ffi.py b/jax/_src/extend/ffi.py index df1c09efffc5..833ac4f615a8 100644 --- a/jax/_src/extend/ffi.py +++ b/jax/_src/extend/ffi.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping, Sequence +from collections.abc import Mapping, Sequence import ctypes import functools import os @@ -22,19 +22,19 @@ from jax._src import core from jax._src import dispatch -from jax._src import dtypes from jax._src import util from jax._src.callback import _check_shape_dtype, callback_batching_rule from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir +from jax._src.layout import DeviceLocalLayout from jax._src.lib import jaxlib from jax._src.lib import xla_client from jax._src.lib.mlir import ir -from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray -import numpy as np +from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape map, unsafe_map = util.safe_map, map +FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None def register_ffi_target( @@ -102,11 +102,28 @@ def include_dir() -> str: return os.path.join(jaxlib_dir, "include") +def _aval_shape(aval: core.AbstractValue) -> Shape: + return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error + + +def _convert_layout(aval: core.AbstractValue, + layout: FfiLayoutOptions = None) -> Sequence[int]: + """Convert a layout to the minor-to-major order used by the custom call API.""" + if layout is None: + return list(reversed(range(len(_aval_shape(aval))))) + elif isinstance(layout, DeviceLocalLayout): + if layout._tiling is not None: + raise ValueError("The FFI does not support layouts with tiling") + return layout.major_to_minor[::-1] + else: + return layout + + def ffi_lowering( call_target_name: str, *, - operand_layouts: Sequence[Sequence[DimSize]] | None = None, - result_layouts: Sequence[Sequence[DimSize]] | None = None, 
+ operand_layouts: Sequence[FfiLayoutOptions] | None = None, + result_layouts: Sequence[FfiLayoutOptions] | None = None, backend_config: Mapping[str, ir.Attribute] | None = None, **lowering_args: Any ) -> mlir.LoweringRule: @@ -137,48 +154,47 @@ def _lowering( kwargs = dict(lowering_args) kwargs.setdefault("api_version", 4) kwargs["backend_config"] = dict( - backend_config or {}, **{k: _ir_attribute(v) for k, v in params.items()}) + backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) if "result_types" not in kwargs: kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] if operand_layouts is None: - kwargs["operand_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_in) # pytype: disable=attribute-error + kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in) + else: + kwargs["operand_layouts"] = [ + _convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)] if result_layouts is None: - kwargs["result_layouts"] = _default_layouts(aval.shape for aval in ctx.avals_out) + kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out) + else: + kwargs["result_layouts"] = [ + _convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)] + if "result_shapes" not in kwargs and not all( + core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): + kwargs["result_shapes"] = [ + mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval))) + for aval in ctx.avals_out] return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore return _lowering -def _default_layouts(shapes: Iterable[Sequence[DimSize]]) -> list[list[DimSize]]: - return [list(reversed(range(len(shape)))) for shape in shapes] - - -def _ir_attribute(obj: Any) -> ir.Attribute: - # TODO(dfm): Similar functions exist in Pallas and Mosaic GPU. Perhaps these - # could be consolidated into mlir or similar. - if isinstance(obj, str): - return ir.StringAttr.get(obj) - elif isinstance(obj, bool): - return ir.BoolAttr.get(obj) - elif isinstance(obj, int): - return mlir.i64_attr(obj) - elif isinstance(obj, float): - return ir.FloatAttr.get_f64(obj) - elif hasattr(obj, "dtype"): - if not (dtypes.is_python_scalar(obj) or np.isscalar(obj)): - raise TypeError("Only scalar attributes are supported") - mlir_type = mlir.dtype_to_ir_type(obj.dtype) - if isinstance(mlir_type, ir.IntegerType): - return ir.IntegerAttr.get(mlir_type, obj) - elif isinstance(mlir_type, ir.FloatType): - return ir.FloatAttr.get(mlir_type, obj) - raise TypeError(f"Unsupported attribute type: {type(obj)}") +ResultMetadata = DuckTypedArray | core.AbstractToken + + +def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]: + avals: list[core.AbstractValue] = [] + for result in results: + if isinstance(result, core.AbstractToken): + avals.append(result) + else: + _check_shape_dtype(result) + avals.append(core.ShapedArray(result.shape, result.dtype)) + return tuple(avals) def ffi_call( target_name: str, - result_shape_dtypes: DuckTypedArray | Sequence[DuckTypedArray], + result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata], *args: ArrayLike, vectorized: bool = False, **kwargs: Any, @@ -204,6 +220,7 @@ def ffi_call( ``dtype`` attributes which are expected to match the shape and dtype of the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often used to define the elements of ``result_shape_dtypes``. + ``jax.core.abstract_token`` may be used to represent a token-typed output. 
*args: the arguments passed to the custom call. vectorized: boolean specifying whether the callback function can operate in a vectorized manner, as described above. @@ -216,12 +233,10 @@ def ffi_call( """ if isinstance(result_shape_dtypes, Sequence): multiple_results = True - result_types = result_shape_dtypes + result_avals = _result_avals(result_shape_dtypes) else: multiple_results = False - result_types = (result_shape_dtypes,) - map(_check_shape_dtype, result_types) - result_avals = tuple(core.ShapedArray(x.shape, x.dtype) for x in result_types) + result_avals = _result_avals((result_shape_dtypes,)) results = ffi_call_p.bind( *args, result_avals=result_avals, @@ -237,7 +252,7 @@ def ffi_call( def ffi_call_abstract_eval( *avals_in, - result_avals: tuple[core.ShapedArray, ...], + result_avals: tuple[core.AbstractValue, ...], target_name: str, vectorized: bool, **kwargs: Any, @@ -263,7 +278,7 @@ def ffi_call_transpose(*args, target_name, **kwargs): def ffi_call_lowering( ctx: mlir.LoweringRuleContext, *operands: ir.Value, - result_avals: tuple[core.ShapedArray, ...], + result_avals: tuple[core.AbstractValue, ...], target_name: str, vectorized: bool, **kwargs: Any, diff --git a/jax/_src/flatten_util.py b/jax/_src/flatten_util.py index e18ad1f6e793..11a9dda66e74 100644 --- a/jax/_src/flatten_util.py +++ b/jax/_src/flatten_util.py @@ -61,7 +61,7 @@ def _ravel_list(lst): if all(dt == to_dtype for dt in from_dtypes): # Skip any dtype conversion, resulting in a dtype-polymorphic `unravel`. - # See https://github.com/google/jax/issues/7809. + # See https://github.com/jax-ml/jax/issues/7809. del from_dtypes, to_dtype raveled = jnp.concatenate([jnp.ravel(e) for e in lst]) return raveled, HashablePartial(_unravel_list_single_dtype, indices, shapes) diff --git a/jax/_src/hardware_utils.py b/jax/_src/hardware_utils.py index dd3da5c4f58b..81ef07a71b19 100644 --- a/jax/_src/hardware_utils.py +++ b/jax/_src/hardware_utils.py @@ -20,20 +20,16 @@ _TPU_PCI_DEVICE_IDS = [ # TPU v2, v3 '0x0027', + # No public name (plc) + '0x0056', # TPU v4 '0x005e', + # TPU v5p + '0x0062', # TPU v5e '0x0063', - # Testing only - '0x0056', - '0x0062', -] - -_TPU_ENHANCED_BARRIER_SUPPORTED = [ - # TPU v2, v3 - '0x0027', - # TPU v4 - '0x005e', + # TPU v6e + '0x006f', ] _NVIDIA_GPU_DEVICES = [ @@ -59,12 +55,6 @@ def num_available_tpu_chips_and_device_id(): return num_chips, device_id -def tpu_enhanced_barrier_supported() -> bool: - """Returns if tpu_enhanced_barrier flag is supported on this TPU version.""" - _, device_id = num_available_tpu_chips_and_device_id() - return device_id in _TPU_ENHANCED_BARRIER_SUPPORTED - - def has_visible_nvidia_gpu() -> bool: """True if there's a visible nvidia gpu available on device, False otherwise.""" diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py index 1e4b6428556b..bc28857fa325 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eig_lapack_geev.py @@ -283,3 +283,268 @@ 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xe5\x9b7\x01[\x0f\x13\x0b\x07\x0b\x13\x0f\x0b\x17\x0b\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x13\x0b\x17\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03A\x0fO\x0b\x0b/\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x0f\x1f\x1f\x13\x0b\x0b\x0b\x0b\x1f#\x1f\x0f\x0b\x0bO/O\x01\x03\x0f\x035\x17\x0f\x07\x13\x0b\x07\x0f\x0f\x07\x07\x17\x17\x1b\x13\x07\x07\x13\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02z\x06\x1d9;\x03\x03\t\x8f\x05\x19\x1f\x05\x1b\x03\x03\x05\x95\x11\x01\x05\x05\x1d\x17\x13\xc2\x07\x01\x05\x1f\x03\x03\x05\x7f\x03\x03\t\x99\x03\x07\x1b\r\x1d\r\x0f\x1f\x05!\x05#\x05%\x03\x0b#_%e'g\x0fu)w\x05'\x05)\x05+\x05-\x03\x03-y\x05/\x1d1\x11\x051\x1d5\x11\x053\x03\x03\x05{\x055\x17\x13\xc6\x07\x01\x03\x03\x05}\x03\x11A\x81C\x83E\x85G_I\x87K\x89M_O\x8b\x057\x059\x05;\x05=\x05?\x05A\x05C\x05E\x03\x03\x05\x8d\x03\x05U\x91W\x93\x05G\x05I\x03\x03\t\x97\x1f%\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1dK\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1b\x03\x07imq\r\x03ak\x1dM\r\x03ao\x1dO\r\x03as\x1dQ\x1dS\x1dU\x13\r\x01\x1f\x05\t\x01\x00\x00\x00\x1f\x05\t\x04\x00\x00\x00\x1f\x11\x03V\x0b\x05\x1dW\x1dY\x05\x01\x03\x0b[[[[]\x03\r]cc]][\x1f\x05\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x05\x11\x11\x0b)\x01\x1f\x01)\x03\x11\x0b\x03\x15\x1d)\x01\x0b)\x01!\x13\x0b)\x05\x05\x05\x07)\x05\x11\x11\x07\x11\x01\x07\t\x03\x03)\x03A\x0b\x1b!)\x03!\x15)\x03\x01\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\r)\x01\x07)\x03\x05\x07)\x03\x11\x07)\x03\x05\r)\x03\t\r\x04j\x03\x05\x01\x11\x07\x19\x07\x03\x01\x05\t\x11\x07!\x05\x03=i\x0b\x03/+\x03\x1d\r\x063\x03\x03\x03\x01\x05\x03\x017\x03\x05\x05\x03\x01=\x03\x05\x05\x03\x01\x15\x03\x11\x05\x03\x01\x15\x03\x11\x0f\x07\x01?\r\x03#\t\x03\x03\x05\x0b\x05\x07\t\x0b\x03\x05\x03\x01Q\x03\x05\x03\x07\x01\x03\x03\x05\x03\x19\x11\x07\x01S\x03-\x05\x17\x1b\x03\x07\x01\x03\x03/\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\t\x03!\x03\x07\x01Y\x031\x03\x1f\x07\x06\x01\x03\t\x07%\x11#\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x03+\x03\x07\x01\x17\x03\x19\x03)\x07\x06\x01\x03\x03\x07/\x13-\x03\x07\x01\x03\x03\x17\x03\x1d\x05\x03\x01\x0b\x03\x0f\x03\x07\x01\x03\x03\x03\x035\x03\x07\x01\x17\x03\x19\x033\x07\x06\x01\x03\x03\x079\x157\x13\x04\x07\x07'1;\x06\x03\x01\x05\x01\x00\x02\r[\x1b\x03\x0f\x0b\t\t\t!+\x1b\x1f/!!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)\x97\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev\x00", xla_call_module_version=6, ) # End paste + + +data_2024_08_19 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729794e+00+0.j, + -1.4596915295025735e-15+0.j, 4.7403016698320490e-16+0.j]), array([[ 0.40377749076862324+0.j, 0.8288327563197503 +0.j, + -0.5409014947846461 +0.j, 0.10917005482608667-0.j], + [ 0.4648073711584899 +0.j, 0.43714638836388775-0.j, + 0.7854306338527134 +0.j, -0.5456169434539783 +0.j], + [ 0.5258372515483575 +0.j, 0.04546002040802463-0.j, + 0.05184321664851461-0.j, 0.7637237224296971 +0.j], + [ 0.5868671319382249 +0.j, -0.34622634754783843+0.j, + -0.296372355716581 +0.j, -0.32727683380180517+0.j]]), array([[ 0.11417645138733866+0.j, 0.7327780959803557 +0.j, + -0.5367326141844461 +0.j, -0.08617176416747369+0.j], + [ 0.33000459866554754+0.j, 0.28974835239692603-0.j, + 0.6342729310130916 +0.j, -0.28826848493327445+0.j], + [ 0.5458327459437569 +0.j, -0.15328139118650222+0.j, + 0.34165198052715445-0.j, 0.83505226236897 +0.j], + [ 0.7616608932219664 +0.j, -0.5963111347699301 +0.j, + -0.4391922973557999 +0.j, -0.460612013268222 +0.j]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:4 = stablehlo.custom_call @lapack_zgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %3 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.compare EQ, %2#3, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %8 = stablehlo.select %7, %2#0, %6 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = 
stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %12 = stablehlo.select %11, %2#1, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %13 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %16 = stablehlo.select %15, %2#2, %14 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %8, %12, %16 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xef\xa57\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1b/\x0f\x1f\x0f\x0b\x0bO/O\x01\x05\x0b\x0f\x033\x17\x07\x07\x13\x0b\x0f\x0f\x0f\x07\x17\x17\x1b\x07\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xa6\x06\x1d;=\x03\x03\t\x99\x1f\x05\x19\x05\x1b\x03\x03\x07\x9f\x11\x03\x05\x05\x1d\x17\x13J\x03\x1d\x05\x1f\x03\x03\x07\x7f\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05!\x11\x01\x00\x05#\x05%\x05\'\x03\x0b\'])i+k\x0fy-{\x05)\x05+\x05-\x05/\x03\x031}\x051\x1d5\x11\x053\x1d9\x11\x055\x057\x17\x13N\x03\x1b\x03\x13A\x81C\x83E\x85G]I\x87K\x89M\x8fO]Q\x91\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05K\x05M\x03\x03\t\xa1\x03\x01\x1dO\x1dQ\x1dS\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13#V#\x1b\x03\x07mqu\r\x05_oac\x1dU\r\x05_sac\x1dW\r\x05_wac\x1dY\x1d[\x1d]\x13\x07\x01\x1f\x13\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d_\x1da\x05\x01\r\x05\x8bg\x8dg\x1dc\x1de\x03\x03e\x03\t\x93ee\x95\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1f\x0f\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\r\x1d\x01)\x03\x11\r\x03\x1d)\x01!)\x01\r)\x01\x07\x13)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05\x0b)\x03A\r\x1b!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04"\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x035a\x0b\x033/\x03\x1f\r\x067\x03\x05\x03\x01\x05\x03\x01\x15\x03\x13\x05\x03\x01\x15\x03\x13\x0f\x07\x01?\t\x0b\x05\x05\x0f\x03\x03\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0
f\x03\x11\x11\x07\x01U\x03-\x05\x0f\x13\x03\x07\x01\x03\x03/\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x19\x03\x07\x01[\x031\x03\x17\x07\x06\x01\x03\x0b\x07\x1d\t\x1b\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x17\x03\x19\x03!\x07\x06\x01\x03\x05\x07\'\x0b%\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03-\x03\x07\x01\x17\x03\x19\x03+\x07\x06\x01\x03\x05\x071\r/\x13\x04\x05\x07\x1f)3\x06\x03\x01\x05\x01\x00^\x0eg\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x87\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex128 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgeev_ffi\x00compute_left\x00compute_right\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249e+01+0.j, -2.4642491e+00+0.j, -8.1492220e-07+0.j, + 3.0721142e-07+0.j], dtype=complex64), array([[ 0.40377736 +0.j, 0.8288328 +0.j, -0.53676015 +0.j, + 0.07707452 -0.j], + [ 0.4648074 +0.j, 0.43714643 -0.j, 0.79694915 +0.j, + -0.5069523 +0.j], + [ 0.52583736 +0.j, 0.04545992 -0.j, 0.016383484+0.j, + 0.7826807 +0.j], + [ 0.5868672 +0.j, -0.34622622 +0.j, -0.2765721 +0.j, + -0.35280296 +0.j]], dtype=complex64), array([[ 0.114176415+0.j, 0.73277825 +0.j, -0.54227245 +0.j, + -0.109032825+0.j], + [ 0.3300045 +0.j, 0.2897482 -0.j, 0.6655821 +0.j, + -0.25470036 +0.j], + [ 0.5458329 +0.j, -0.15328139 +0.j, 0.29565343 +0.j, + 0.83649963 +0.j], + [ 0.7616609 +0.j, -0.59631103 +0.j, -0.4189632 +0.j, + -0.47276634 +0.j]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xcomplex> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xcomplex>) -> tensor<4x4xcomplex> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:4 = stablehlo.custom_call @lapack_cgeev_ffi(%1) {mhlo.backend_config = 
{compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xcomplex>) -> (tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %3 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %4 = stablehlo.compare EQ, %2#3, %3, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %5, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %8 = stablehlo.select %7, %2#0, %6 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %9 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %12 = stablehlo.select %11, %2#1, %10 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %13 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %13, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %16 = stablehlo.select %15, %2#2, %14 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %8, %12, %16 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01#\x05\x01\x03\x01\x03\x05\x03\x13\x07\t\x0b\r\x0f\x11\x13\x15\x17\x03\xef\xa57\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1b/\x0f\x1f\x0f\x0b\x0b//O\x01\x05\x0b\x0f\x033\x17\x07\x07\x13\x0b\x0f\x0f\x0f\x07\x17\x17\x1b\x07\x13\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\x86\x06\x1d;=\x03\x03\t\x99\x1f\x05\x19\x05\x1b\x03\x03\x07\x9f\x11\x03\x05\x05\x1d\x17\x13J\x03\x1d\x05\x1f\x03\x03\x07\x7f\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05!\x11\x01\x00\x05#\x05%\x05\'\x03\x0b\'])i+k\x0fy-{\x05)\x05+\x05-\x05/\x03\x031}\x051\x1d5\x11\x053\x1d9\x11\x055\x057\x17\x13N\x03\x1b\x03\x13A\x81C\x83E\x85G]I\x87K\x89M\x8fO]Q\x91\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05K\x05M\x03\x03\t\xa1\x03\x01\x1dO\x1dQ\x1dS\x1f%!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13#V#\x1b\x03\x07mqu\r\x05_oac\x1dU\r\x05_sac\x1dW\r\x05_wac\x1dY\x1d[\x1d]\x13\x07\x01\x1f\x13\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d_\x1da\x05\x01\r\x05\x8bg\x8dg\x1dc\x1de\x03\x03e\x03\t\x93ee\x95\x1f\'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f)\x01\x1f\x0f\t\x00\x00\x00\x00\x1f+\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f5!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\r\x1d\x01)\x03\x11\r\x03\x1d)\x01!)\x01\r)\x01\x07\x13)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05\t)\x03A\r\x1b!)\x03\t\x15)\x03\x05\x15)\x03\x01\x15)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04"\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x035a\x0b\x033/\x03\x1f\r\x067\x03\x05\x03\x01\x05\x03\x01\x15\x03\x13\x05\x03\x01\x15\x03\x13\x0f\x07\x01?\t\x0b\x05\x05\x0f\x03\x03\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x11\x11\x07\x01U\x03-\x05\x0f\x13\x03\x07\x01\x03\x03/\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x19\x03\x07\x01[\x031\x03\x17\x07\x06\x01\x03\x0b\x07\x1d\t\x1b\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x17\x03\x19\x03!\x07\x06\x01\x03\x05\x07\'\x0b%\x03\x07\x01\x03\x03\x17\x03\x15\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03-\x03\x07\x01\x17\x03\x19\x03+\x07\x06\x01\x03\x05\x071\r/\x13\x04\x05\x07\x1f)3\x06\x03\x01\x05\x01\x00Z\x0eg\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x85\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=complex64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgeev_ffi\x00compute_left\x00compute_right\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464241e+01+0.j, -2.4642482e+00+0.j, -4.5555478e-07+0.j, + 2.9215252e-07+0.j], dtype=complex64), array([[-0.40377742+0.j, 0.8288328 +0.j, -0.5253654 +0.j, + -0.11065983+0.j], + [-0.46480736+0.j, 0.43714654+0.j, 0.8159359 +0.j, + 0.547376 +0.j], + [-0.52583736+0.j, 0.04545998+0.j, -0.0557748 +0.j, + -0.7627722 +0.j], + [-0.5868672 +0.j, -0.34622627+0.j, -0.23479532+0.j, + 0.32605612+0.j]], dtype=complex64), array([[-0.114176415+0.j, 0.7327782 +0.j, -0.5364275 +0.j, + 0.15489015 +0.j], + [-0.33000445 +0.j, 0.28974816 +0.j, 0.6327556 +0.j, + 0.18506403 +0.j], + [-0.54583275 +0.j, -0.15328142 +0.j, 0.34377125 +0.j, + -0.83479893 +0.j], + [-0.761661 +0.j, -0.5963111 +0.j, -0.44009918 +0.j, + 0.49484456 +0.j]], dtype=complex64)), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xf32> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xf32>) -> tensor<4x4xf32> loc(#loc4) + %c = stablehlo.constant dense<4> : tensor loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc5) + %2:5 = stablehlo.custom_call @lapack_sgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4x4xcomplex>, tensor<4x4xcomplex>, tensor) loc(#loc5) + %3 = stablehlo.complex %2#0, %2#1 : tensor<4xcomplex> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc5) + %4 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc5) + %5 = stablehlo.compare EQ, %2#4, %4, SIGNED : (tensor, tensor) -> tensor loc(#loc5) + %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<4xcomplex> loc(#loc5) + %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %9 = stablehlo.select %8, %3, %7 : tensor<4xi1>, tensor<4xcomplex> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %11 = stablehlo.broadcast_in_dim 
%cst_2, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %13 = stablehlo.select %12, %2#2, %11 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<4x4xcomplex> loc(#loc5) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %17 = stablehlo.select %16, %2#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex> loc(#loc5) + return %9, %13, %17 : tensor<4xcomplex>, tensor<4x4xcomplex>, tensor<4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xf3\xa5;\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f/\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1f\x0f\x1f\x0f\x0b\x0b//O\x01\x05\x0b\x0f\x037\x17\x07\x07\x13\x07\x0f\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x17\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xae\x06\x1d;=\x03\x03\t\x99\x1f\x05\x1b\x05\x1d\x03\x03\x07\x9f\x11\x03\x05\x05\x1f\x17\x13J\x03\x1d\x05!\x03\x03\x07\x81\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05#\x11\x01\x00\x05%\x05'\x05)\x03\x0b'])k+m\x0f{-}\x05+\x05-\x05/\x051\x03\x031\x7f\x053\x1d5\x11\x055\x1d9\x11\x057\x059\x17\x13N\x03\x1b\x03\x13A\x83C\x85E\x87G]I\x89K\x8bM\x91O]Q\x93\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x05K\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05M\x05O\x03\x03\t\xa1\x03\x01\x1dQ\x1dS\x1dU\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13'V\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07osw\r\x05_qac\x1dW\r\x05_uac\x1dY\r\x05_yac\x1d[\x1d]\x1d_\x13\x07\x01\x1f\x15\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1da\x1dc\x05\x01\r\x05\x8dg\x8fg\x1de\x1dg\x03\x03e\x03\x0biiee\x95\x1f-\x01\x1f\x0f\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f\x11\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x13\x1d\x01)\x03\x11\x13\t)\x01%)\x01\x13\x03\r)\x01\x07\x13)\x03\x11\r)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05)\x03A\r)\x05\x11\x11\r\x1b!)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04F\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x039e\x0b\x033/\x03!\r\x067\x03#\x03\x01\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x0b\x19\x19\x05\x05\x0f\x03\x03\x11\x06\x01\x03\x0b\x05\t\x0b\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x15\x13\x07\x01U\x031\x05\x11\x17\x03\x07\x01\x03\x033\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x
01\x03\x03\x0b\x03\x1d\x03\x07\x01[\x035\x03\x1b\x07\x06\x01\x03\x0b\x07!\x13\x1f\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03'\x03\x07\x01\x17\x03\x1d\x03%\x07\x06\x01\x03\x05\x07+\r)\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x0f3\x15\x04\x05\x07#-7\x06\x03\x01\x05\x01\x00\x82\x0ei\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgeev_ffi\x00compute_left\x00compute_right\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgeev_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([ 3.2464249196572972e+01+0.j, -2.4642491965729789e+00+0.j, + -1.4885423746029788e-15+0.j, 4.7495173217146935e-16+0.j]), array([[-0.40377749076862246 +0.j, -0.8288327563197503 +0.j, + -0.541090767303977 +0.j, 0.10767692008040902 +0.j], + [-0.4648073711584901 +0.j, -0.43714638836388775 +0.j, + 0.7847911174458492 +0.j, -0.5438508504687168 +0.j], + [-0.5258372515483576 +0.j, -0.045460020408024666+0.j, + 0.05369006702023438 +0.j, 0.7646709406962073 +0.j], + [-0.5868671319382248 +0.j, 0.34622634754783854 +0.j, + -0.2973904171621061 +0.j, -0.32849701030789913 +0.j]]), array([[-0.11417645138733848+0.j, -0.7327780959803556 +0.j, + -0.5370341524353898 +0.j, -0.0849751818967924 +0.j], + [-0.33000459866554754+0.j, -0.2897483523969262 +0.j, + 0.6357878989446506 +0.j, -0.29000500336734825+0.j], + [-0.545832745943757 +0.j, 0.15328139118650214+0.j, + 0.33952665941686755+0.j, 0.8349355524250736 +0.j], + [-0.7616608932219664 +0.j, 0.5963111347699303 +0.j, + -0.43828040592612855+0.j, -0.45995536716093305+0.j]])), + mlir_module_text=r""" +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16xf64> loc(#loc3) + %1 = stablehlo.reshape %0 : (tensor<16xf64>) -> tensor<4x4xf64> loc(#loc4) + %c = stablehlo.constant dense<4> : 
tensor<i64> loc(#loc5) + %c_0 = stablehlo.constant dense<4> : tensor<i64> loc(#loc5) + %2:5 = stablehlo.custom_call @lapack_dgeev_ffi(%1) {mhlo.backend_config = {compute_left = 86 : ui8, compute_right = 86 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], result_layouts = [dense<0> : tensor<1xindex>, dense<0> : tensor<1xindex>, dense<[0, 1]> : tensor<2xindex>, dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>]} : (tensor<4x4xf64>) -> (tensor<4xf64>, tensor<4xf64>, tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<i32>) loc(#loc5) + %3 = stablehlo.complex %2#0, %2#1 : tensor<4xcomplex<f64>> loc(#loc5) + %c_1 = stablehlo.constant dense<0> : tensor<i32> loc(#loc5) + %4 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc5) + %5 = stablehlo.compare EQ, %2#4, %4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc5) + %6 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc5) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5) + %7 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<4xcomplex<f64>> loc(#loc5) + %8 = stablehlo.broadcast_in_dim %6, dims = [0] : (tensor<1xi1>) -> tensor<4xi1> loc(#loc5) + %9 = stablehlo.select %8, %3, %7 : tensor<4xi1>, tensor<4xcomplex<f64>> loc(#loc5) + %10 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc5) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %13 = stablehlo.select %12, %2#2, %11 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc5) + %14 = stablehlo.broadcast_in_dim %5, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc5) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc5) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<complex<f64>>) -> tensor<4x4xcomplex<f64>> loc(#loc5) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<4x4xi1> loc(#loc5) + %17 = stablehlo.select %16, %2#3, %15 : tensor<4x4xi1>, tensor<4x4xcomplex<f64>> loc(#loc5) + return %9, %13, %17 : tensor<4xcomplex<f64>>, tensor<4x4xcomplex<f64>>, tensor<4x4xcomplex<f64>> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":210:14) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":211:13) +#loc3 = loc("jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]"(#loc1)) +#loc5 = loc("jit(func)/jit(main)/eig[compute_left_eigenvectors=True compute_right_eigenvectors=True]"(#loc2)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xf3\xa5;\x01]\x0f\x13\x07\x0b\x0b\x13\x0f\x0b\x17\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0f\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03I\x0b\x0b\x0b\x0bO\x0f/\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f/\x0b\x0b\x0b\x0b\x1b\x0b\x0b\x0f\x1f\x0f\x1f\x0f\x0b\x0bO/O\x01\x05\x0b\x0f\x037\x17\x07\x07\x13\x07\x0f\x0f\x0b\x0f\x07\x13\x17\x17\x1b\x13\x17\x07\x07\x13\x13\x13\x13\x0f\x13\x13\x13\x13\x02\xce\x06\x1d;=\x03\x03\t\x99\x1f\x05\x1b\x05\x1d\x03\x03\x07\x9f\x11\x03\x05\x05\x1f\x17\x13J\x03\x1d\x05!\x03\x03\x07\x81\x03\x03\t\xa3\x03\t\x1b\x1d\x1f\r!\r\x0f#\x05#\x11\x01\x00\x05%\x05'\x05)\x03\x0b'])k+m\x0f{-}\x05+\x05-\x05/\x051\x03\x031\x7f\x053\x1d5\x11\x055\x1d9\x11\x057\x059\x17\x13N\x03\x1b\x03\x13A\x83C\x85E\x87G]I\x89K\x8bM\x91O]Q\x93\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x05I\x05K\x03\x03\x07\x97\x03\x05W\x9bY\x9d\x05M\x05O\x03\x03\t\xa1\x03\x01\x1dQ\x1dS\x1dU\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x13'V\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00#\x1f\x03\x07osw\r\x05_qac\x1dW\r\x05_uac\x1dY\r\x05_yac\x1d[\x1d]\x1d_\x13\x07\x01\x1f\x15\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1da\x1dc\x05\x01\r\x05\x8dg\x8fg\x1de\x1dg\x03\x03e\x03\x0biiee\x95\x1f-\x01\x1f\x0f\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f\x11!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05\x11\x11\x13\x1d\x01)\x03\x11\x13\x0b)\x01%)\x01\x13\x03\r)\x01\x07\x13)\x03\x11\r)\x05\x05\x05\t)\x05\x11\x11\t\x11\x01\x07\x0b\x05\x05)\x03A\r)\x05\x11\x11\r\x1b!)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x03\x01\x07)\x01\t)\x03\x05\t)\x03\x11\t)\x03\x05\x07)\x03\t\x07\x04F\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05%\x07\x039e\x0b\x033/\x03!\r\x067\x03#\x03\x01\x05\x03\x01\x15\x03\x15\x05\x03\x01\x15\x03\x15\x0f\x07\x01?\x0b\x19\x19\x05\x05\x0f\x03\x03\x11\x06\x01\x03\x0b\x05\t\x0b\x05\x03\x01S\x03\x0f\x03\x07\x01\x03\x03\x0f\x03\x15\x13\x07\x01U\x031\x05\x11\x17\x03\x07\x01\x03\x033\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x0b\x03\x1d\x03\x07\x01[\x035\x03\x1b\x07\x06\x01\x03\x0b\x07!\x13\x1f\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x03'\x03\x07\x01\x17\x03\x1d\x03%\x07\x06\x01\x03\x05\x07+\r)\x03\x07\x01\x03\x03\x1b\x03\x19\x05\x03\x01\x0b\x03\x11\x03\x07\x01\x03\x03\x05\x031\x03\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x0f3\x15\x04\x05\x07#-7\x06\x03\x01\x05\x01\x00\x82\x0ei\x1d\x1b#\x03\x0f\x0b\t\t\t\x11#!+\x1b\x1f/!)!)#\x1f\x19\xb1}\x81\x1f\x1f\x15\x1d\x15\x13%)9i\x13+\r\x15\x17\x17\x1f\x17\x11\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00complex_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float64 shape=(16,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(4, 4) dimensions=None]\x00jit(func)/jit(main)/eig[compute_left_eigenvectors=True 
compute_right_eigenvectors=True]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00mhlo.layout_mode\x00default\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgeev_ffi\x00compute_left\x00compute_right\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py index fcc32058bbee..f0696db1aeda 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_eigh_lapack_syev.py @@ -383,3 +383,446 @@ xla_call_module_version=4, ), # End paste ) + +data_2024_08_19 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zheevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-6.1857700048412056e-01+0.j, 2.4081403770912022e-01+0.j, + 3.5662489253627483e-01+0.j, -6.3034019033669797e-01+0.j, + 1.0043483479985752e-16+0.j, -2.8842036081919542e-02+0.j, + 7.7164692943283169e-25+0.j, -1.8446994643771725e-01+0.j], + [-4.7070881487314609e-01+0.j, 4.7473787464450828e-01+0.j, + -4.8036836210243361e-01+0.j, 4.3802686872516400e-01+0.j, + 1.7961797619639255e-01+0.j, 8.3080980076741355e-03+0.j, + 2.1415294457221759e-01+0.j, -2.2856669794666584e-01+0.j], + [-3.2284062926217072e-01+0.j, -5.4336490915553370e-01+0.j, + 2.2181041859724987e-01+0.j, 2.9947877954402286e-01+0.j, + -3.6491813600134637e-01+0.j, 3.2867679819727436e-01+0.j, + 3.8223299448843473e-01+0.j, -2.7266344945561438e-01+0.j], + [-1.7497244365119527e-01+0.j, -8.9251550609769331e-02+0.j, + -6.3518515114898352e-02+0.j, 1.9162997359209963e-01+0.j, + -2.2087281326110142e-01+0.j, 5.9957027043505008e-02+0.j, + -8.7632498908241274e-01+0.j, -3.1676020096456303e-01+0.j], + [-2.7104258040220017e-02+0.j, -3.3772873786627688e-01+0.j, + 2.5901386593721754e-01+0.j, 1.7032650752287815e-01+0.j, + 6.7521217612940321e-01+0.j, -4.5036136532965476e-01+0.j, + -1.2279030059078447e-02+0.j, -3.6085695247351163e-01+0.j], + [ 1.2076392757075533e-01+0.j, -3.3834734096469249e-01+0.j, + -6.5506827461665529e-01+0.j, -5.0472498521116760e-01+0.j, + 6.9987430903492132e-02+0.j, 1.0595648906599270e-01+0.j, + 8.3443844143082035e-02+0.j, -4.0495370398246017e-01+0.j], + [ 2.6863211318173102e-01+0.j, 2.2958613191407312e-01+0.j, + 6.3952843755683969e-02+0.j, 1.8776775771084192e-02+0.j, + -5.3523731432241317e-01+0.j, -5.9199531677602002e-01+0.j, + 1.7916671834524250e-01+0.j, -4.4905045549140887e-01+0.j], + [ 4.1650029879270667e-01+0.j, 3.6355449432857068e-01+0.j, + 2.9755313100756148e-01+0.j, 1.6826270392616000e-02+0.j, + 1.9621068035557282e-01+0.j, 5.6830030587314817e-01+0.j, + 2.9607517592514260e-02+0.j, -4.9314720700035747e-01+0.j]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, + -1.9932120610662194e-14, -5.7323356091157378e-15, + -4.5459724251334835e-16, 4.0479851042511616e-14, + 9.2325194924982089e-14, 2.7659880477613365e+02])), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc18 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) 
out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xcomplex<f64>> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xcomplex<f64>> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<64xcomplex<f64>>) -> tensor<8x8xcomplex<f64>> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xcomplex<f64>> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xf64> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xf64> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<8x8xf64> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex<f64>> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex<f64>> loc(#loc16) + %cst = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<8x8xcomplex<f64>> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<8x8xcomplex<f64>> loc(#loc17) + %10 = call @tril(%9) : (tensor<8x8xcomplex<f64>>) -> tensor<8x8xcomplex<f64>> loc(#loc18) + %c = stablehlo.constant dense<8> : tensor<i64> loc(#loc19) + %c_0 = stablehlo.constant dense<8> : tensor<i64> loc(#loc19) + %11:3 = stablehlo.custom_call @lapack_zheevd_ffi(%10) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias<output_tuple_indices = [0], operand_index = 0, operand_tuple_indices = []>], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xcomplex<f64>>) -> (tensor<8x8xcomplex<f64>>, tensor<8xf64>, tensor<i32>) loc(#loc19) + %c_1 = stablehlo.constant dense<0> : tensor<i32> loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor<i32>) -> tensor<i32> loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1> loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1x1xi1> loc(#loc19) + %cst_2 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor<complex<f64>> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor<complex<f64>>) -> tensor<8x8xcomplex<f64>> loc(#loc19) + %16 = stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<8x8xi1>, tensor<8x8xcomplex<f64>> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc19) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor<f64> loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f64>) -> tensor<8xf64> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<8xi1>, tensor<8xf64> loc(#loc19) + return %17, %21 : tensor<8x8xcomplex<f64>>, tensor<8xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xcomplex<f64>> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7))) -> (tensor<8x8xcomplex<f64>> {mhlo.layout_mode = "default"}) { + %0 = 
stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc20) + %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc18) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<8x8xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc22) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc23) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f64>> loc(#loc18) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f64>>) -> tensor<8x8xcomplex<f64>> loc(#loc24) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xcomplex<f64>> loc(#loc25) + return %6 : tensor<8x8xcomplex<f64>> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc9 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc7)) +#loc25 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\'\x03\xda\x02*\x02?\x01\xab\x0f\x0b\x13\x17\x0f\x0b\x07\x17\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03a\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bOOO/\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0bOO//\x01\x0b\x1f\x17\x17\x17\x17\x01\x05\x0b\x0f\x03;\x17\x07\x0f\x0f\x07\x07\x13\x17\x0b\x17\x0f\x07\x07\x17\x13\x07\x0f\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02\xa6\n\x1d\x93\x95\x05)\x03\x03\x13\xd5\x17\x03V\x047\x1d?\x07\x05+\x1f\x17\x03>\x043\x05-\x05/\x11\x03\x05\x051\x053\x055\x057\x03\x03!\xd1\x059\x03\x03\x0b\xd3\x1dE\x07\x05;\x05=\x1d\x8b\x8d\x03\x03\x0b\xe1\x03\t135\x157\x15\x119\x05?\x11\x01\x00\x05A\x05C\x05E\x03\x0b\x17\xaf\x19\xbb\x1b\xbd\x11\xc7\x1d\xc9\x03\x0b\x17\xb3\x19\xcd\x1b\xb3\x11\xb5\x1d\xcf\x05G\x1dC\x07\x05I\x05K\x03\x03!\xd7\x1dK\x07\x05M\x03\x05\'\xb7)\xd9\x1dQ\x07\x05O\x03\x03\x0b\xdb\x1dW\x07\x05Q\x1d[\x07\x05S\x1d_a\x05U\x17\x036\x045\x1deg\x05W\x17\x036\x04\x1d\x03\x03k\xdd\x05Y\x1doq\x05[\x17\x03>\x04E\x1du\x0f\x05]\x1dy\x0f\x05_\x1d}\x0f\x05a\x1d\x81\x0f\x05c\x1d\x85\x87\x05e\x17\x03>\x04\x1f\x03\x03\x0b\xdf\x05g\x17\x03>\x04\x1d\x03\x03\x91\xb5\x05i\x05k\x17\x03V\x04\x17\x03\x13\x99\xe3\x9b\xe5\x9d\xe7\x9f\xaf\xa1\xe9\xa3\xeb\xa5\xf5\xa7\xf7\xa9\xfb\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x05}\x1d\x7f\x1d\x81\x03\x01\x1d\x83\x03\x03\xcb\x1d\x85\t\x07\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\'\x03\x05\xbf\xc3\r\x05\xb1\xc1\xab\xad\x1d\x87\r\x05\xb1\xc5\xab\xad\x1d\x89\x1d\x8b\x1d\x8d\r\x03\xab\xad#)\x1d\x8f\x13\x07\x01\x1f\x0b\t\x00\x00\x00\x00\x1f+\x01\x13\x07\x05\x07\x05\x1f\t!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t!\x00\x00\x00\x00\x00\x00\x00@\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x91\x1d\x93\x05\x01\r\x05\xed\xef\xf1\xf3\x1d\x95\x13#V\x1d\x97\x13#L\x03\x03\xb9\x03\x03\xf9\x15\x03\x01\x01\x01\x03\x07\xb9\xfd\xff\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x01\x07\x01\x1f\t!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f%\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05\'\xb7)\x02\x02\x03\x03\x0b\x06\x02\x03\x03\x13\n\x02\x03\x03\x0b\x0e\x02\x03\x03\x13\x12\x02\x01\t\x01\x02\x02)\x05!!\x15\x1d)\x01\x15)\x01\x1d\x01\x0b)\x03!\x0f)\x05!!\x1d\x03\x0f)\x05!!\x0f)\x01\x07\x13\x1b)\x05!!\r)\x03\t\x07!)\x01\x0f\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x07)\x03\x02\x02\x15)\x03\t\x1b)\x03\x05\x1b)\x03\x01\x1b)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\x07\x04\x06\x05\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x03=u\x07\x03]\x1f\x03-\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\x17\x06s\x03\x17\x03\x05\x19\x06w\x03\x17\x03\x05\x1b\x06{\x03\x17\x03\t\x1d\x06\x7f\x03\x05\x05\x07\x0b\r\x06\x83\x03\x05\x05\x03\r\x05\x03\r\x89\x03\t\x03\x07+\x05\x03\x05\x03\x11\x1f\x06+\x03\x05\x05\x0f\x13!\x07\t\x8f\x03\x05\x03\x15\x05\x03\x01-\x03\x19\x05\x03\x01-\x03\x19#\x07\x01\x97\x07\x05\x11\x
0b\x03\x17\x05\x03\x01#\x03\x0b\x03\x07\x01\x05\x03\x0b\x03#\x0f\x07\x01\x16\x02\x035\x05!%\x03\x07\x01\x05\x037\x03\'\x05\x03\x01\x1a\x02\x03\t\x03\x07\x01\x05\x03\x05\x03+\x03\x07\x01\x1e\x02\x03\x1f\x03)\t\x06\x01\x03\x05\x07/\x1d-\x03\x07\x01\x05\x039\x03\'\x05\x03\x01"\x02\x03%\x03\x07\x01\x05\x03\x11\x035\x03\x07\x01&\x02\x03;\x033\t\x06\x01\x03\x11\x079\x1f7\x11\x04\r\x051;\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1f\x03\x13\x05\x03\t#\x03\x0b\x03\x07%\x05\x03\x13\x03\x05\r\x06%\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1f\x05\t\x0b\x05\x03\tS\x03\t\x03\x07U\x05\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xe2\x1c\x99\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99A9;;m\x19\x85\x8fW\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_zheevd_ffi\x00mode\x00uplo\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cheevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-0.6185769 +0.j, -0.20142993 +0.j, -0.09725195 +0.j, + 0.62983674 +0.j, -0.07926044 +0.j, 0.3605001 -0.j, + -0.019093221 +0.j, -0.18446997 +0.j], + [-0.47070873 +0.j, 0.29325768 +0.j, -0.19454116 +0.j, + -0.6394365 +0.j, 0.06229549 +0.j, 0.33249345 +0.j, + 0.28112718 +0.j, -0.22856665 +0.j], + [-0.32284075 +0.j, -0.12361939 +0.j, 0.20547704 +0.j, + -0.18307868 +0.j, 0.47294614 +0.j, -0.3170349 +0.j, + -0.6373532 +0.j, 
-0.27266347 +0.j], + [-0.17497246 +0.j, -0.079641335 +0.j, 0.15042792 +0.j, + -0.15416273 +0.j, -0.815209 +0.j, -0.38054234 +0.j, + -0.083263926 +0.j, -0.31676024 +0.j], + [-0.027104257 +0.j, -0.26490977 +0.j, 0.32271704 +0.j, + 0.08653544 +0.j, 0.30305928 +0.j, -0.33998996 +0.j, + 0.6926741 +0.j, -0.360857 +0.j], + [ 0.120763965 +0.j, 0.43288827 +0.j, -0.64385164 +0.j, + 0.2652551 +0.j, 0.094823755 +0.j, -0.37435007 +0.j, + 0.00091664493+0.j, -0.40495378 +0.j], + [ 0.26863196 +0.j, 0.51607686 +0.j, 0.53846526 +0.j, + 0.16969058 +0.j, -0.0216703 +0.j, 0.35755336 +0.j, + -0.113144726 +0.j, -0.4490505 +0.j], + [ 0.4165004 +0.j, -0.57262254 +0.j, -0.28144246 +0.j, + -0.17463988 +0.j, -0.016984984 +0.j, 0.3613705 +0.j, + -0.12186296 +0.j, -0.49314725 +0.j]], dtype=complex64), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, + -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], + dtype=float32)), + mlir_module_text=r""" +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc18 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xcomplex> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<64xcomplex>) -> tensor<8x8xcomplex> loc(#loc10) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc11) + %3 = stablehlo.real %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> loc(#loc12) + %4 = stablehlo.imag %2 : (tensor<8x8xcomplex>) -> tensor<8x8xf32> loc(#loc13) + %5 = stablehlo.negate %4 : tensor<8x8xf32> loc(#loc14) + %6 = stablehlo.complex %3, %5 : tensor<8x8xcomplex> loc(#loc15) + %7 = stablehlo.add %1, %6 : tensor<8x8xcomplex> loc(#loc16) + %cst = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor> loc(#loc) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc17) + %9 = stablehlo.divide %7, %8 : tensor<8x8xcomplex> loc(#loc17) + %10 = call @tril(%9) : (tensor<8x8xcomplex>) -> tensor<8x8xcomplex> loc(#loc18) + %c = stablehlo.constant dense<8> : tensor loc(#loc19) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc19) + %11:3 = stablehlo.custom_call @lapack_cheevd_ffi(%10) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xcomplex>) -> (tensor<8x8xcomplex>, tensor<8xf32>, tensor) loc(#loc19) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc19) + %12 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc19) + %13 = stablehlo.compare EQ, %11#2, %12, SIGNED : (tensor, tensor) -> tensor loc(#loc19) + %14 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc19) + %cst_2 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc19) + %15 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor>) -> tensor<8x8xcomplex> loc(#loc19) + %16 
= stablehlo.broadcast_in_dim %14, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc19) + %17 = stablehlo.select %16, %11#0, %15 : tensor<8x8xi1>, tensor<8x8xcomplex<f32>> loc(#loc19) + %18 = stablehlo.broadcast_in_dim %13, dims = [] : (tensor<i1>) -> tensor<1xi1> loc(#loc19) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor<f32> loc(#loc19) + %19 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor<f32>) -> tensor<8xf32> loc(#loc19) + %20 = stablehlo.broadcast_in_dim %18, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc19) + %21 = stablehlo.select %20, %11#1, %19 : tensor<8xi1>, tensor<8xf32> loc(#loc19) + return %17, %21 : tensor<8x8xcomplex<f32>>, tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xcomplex<f32>> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc7))) -> (tensor<8x8xcomplex<f32>> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc20) + %c = stablehlo.constant dense<0> : tensor<i32> loc(#loc18) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor<i32>) -> tensor<8x8xi32> loc(#loc21) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc21) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc22) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc23) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<complex<f32>> loc(#loc18) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor<complex<f32>>) -> tensor<8x8xcomplex<f32>> loc(#loc24) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xcomplex<f32>> loc(#loc25) + return %6 : tensor<8x8xcomplex<f32>> loc(#loc18) + } loc(#loc18) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:25) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc8 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc9 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]"(#loc1)) +#loc10 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc11 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc12 = loc("jit()/jit(main)/real"(#loc4)) +#loc13 = loc("jit()/jit(main)/imag"(#loc4)) +#loc14 = loc("jit()/jit(main)/neg"(#loc4)) +#loc15 = loc("jit()/jit(main)/complex"(#loc4)) +#loc16 = loc("jit()/jit(main)/add"(#loc5)) +#loc17 = loc("jit()/jit(main)/div"(#loc6)) +#loc19 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc8)) +#loc20 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc7)) +#loc21 = loc("jit()/jit(main)/jit(tril)/add"(#loc7)) +#loc22 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc7)) +#loc23 = loc("jit()/jit(main)/jit(tril)/ge"(#loc7)) +#loc24 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc7)) +#loc25 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc7)) +""", + 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x013\x05\x01\x03\x01\x03\x05\x03#\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f!#%\'\x03\xda\x02*\x02?\x01\xab\x0f\x0b\x13\x17\x0f\x0b\x07\x17\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03a\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O//\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b/O\x1f/\x01\x0b\x1f\x17\x17\x17\x17\x01\x05\x0b\x0f\x03;\x17\x07\x0f\x0f\x07\x07\x13\x17\x0b\x17\x0f\x07\x07\x17\x13\x07\x0f\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x026\n\x1d\x93\x95\x05)\x03\x03\x13\xd5\x17\x03V\x047\x1d?\x07\x05+\x1f\x17\x03>\x043\x05-\x05/\x11\x03\x05\x051\x053\x055\x057\x03\x03!\xd1\x059\x03\x03\x0b\xd3\x1dE\x07\x05;\x05=\x1d\x8b\x8d\x03\x03\x0b\xe1\x03\t135\x157\x15\x119\x05?\x11\x01\x00\x05A\x05C\x05E\x03\x0b\x17\xaf\x19\xbb\x1b\xbd\x11\xc7\x1d\xc9\x03\x0b\x17\xb3\x19\xcd\x1b\xb3\x11\xb5\x1d\xcf\x05G\x1dC\x07\x05I\x05K\x03\x03!\xd7\x1dK\x07\x05M\x03\x05\'\xb7)\xd9\x1dQ\x07\x05O\x03\x03\x0b\xdb\x1dW\x07\x05Q\x1d[\x07\x05S\x1d_a\x05U\x17\x036\x045\x1deg\x05W\x17\x036\x04\x1d\x03\x03k\xdd\x05Y\x1doq\x05[\x17\x03>\x04E\x1du\x0f\x05]\x1dy\x0f\x05_\x1d}\x0f\x05a\x1d\x81\x0f\x05c\x1d\x85\x87\x05e\x17\x03>\x04\x1f\x03\x03\x0b\xdf\x05g\x17\x03>\x04\x1d\x03\x03\x91\xb5\x05i\x05k\x17\x03V\x04\x17\x03\x13\x99\xe3\x9b\xe5\x9d\xe7\x9f\xaf\xa1\xe9\xa3\xeb\xa5\xf5\xa7\xf7\xa9\xfb\x05m\x05o\x05q\x05s\x05u\x05w\x05y\x05{\x05}\x1d\x7f\x1d\x81\x03\x01\x1d\x83\x03\x03\xcb\x1d\x85\t\x07\x1f/!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#\'\x03\x05\xbf\xc3\r\x05\xb1\xc1\xab\xad\x1d\x87\r\x05\xb1\xc5\xab\xad\x1d\x89\x1d\x8b\x1d\x8d\r\x03\xab\xad#)\x1d\x8f\x13\x07\x01\x1f\x0b\t\x00\x00\x00\x00\x1f+\x01\x13\x07\x05\x07\x05\x1f\t\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f!!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x00\x00\x00@\x00\x00\x00\x00\x1f\x19\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x91\x1d\x93\x05\x01\r\x05\xed\xef\xf1\xf3\x1d\x95\x13#V\x1d\x97\x13#L\x03\x03\xb9\x03\x03\xf9\x15\x03\x01\x01\x01\x03\x07\xb9\xfd\xff\x1f1\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f3\x01\x07\x01\x1f\t\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f!!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f%\t\x00\x00\xc0\x7f\x1f=\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x05\'\xb7)\x02\x02\x03\x03\x0b\x06\x02\x03\x03\x13\n\x02\x03\x03\x0b\x0e\x02\x03\x03\x13\x12\x02\x01\t\x01\x02\x02)\x05!!\x15\x1d)\x01\x15)\x01\x1d\x01\t)\x03!\x0f)\x05!!\x1d\x03\x0f)\x05!!\x0f)\x01\x07\x13\x1b)\x05!!\r)\x03\t\x07!)\x01\x0f\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\x07)\x03\x02\x02\x15)\x03\t\x1b)\x03\x05\x1b)\x03\x01\x1b)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\x07\x04\x06\x05\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x03=u\x07\x03]\x1f\x03-\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\x17\x06s\x03\x17\x03\x05\x19\x06w\x03\x17\x03\x05\x1b\x06{\x03\x17\x03\t\x1d\x06\x7f\x03\x05\x05\x07\x0b\r\x06\x83\x03\x05\x05\x03\r\x05\x03\r\x89\x03\t\x03\x07+\x05\x03\x05\x03\x11\x1f\x06+\x03\x05\x05\x0f\x13!\x07\t\x8f\x03\x05\x03\x15\x05\x03\x01-\x03\x19\x05\x03\x01-\x03\x19#\x07\x01\x97\x07\x05\x11\x0b\x03\x17\x05\x03\x01#\x03\x0b\x03\x07\x01\x05\x03\x0b\x03#\x0f\x07\x01\x16\x02\x035\x05!%\x03\x07\x01\x05
\x037\x03\'\x05\x03\x01\x1a\x02\x03\t\x03\x07\x01\x05\x03\x05\x03+\x03\x07\x01\x1e\x02\x03\x1f\x03)\t\x06\x01\x03\x05\x07/\x1d-\x03\x07\x01\x05\x039\x03\'\x05\x03\x01"\x02\x03%\x03\x07\x01\x05\x03\x11\x035\x03\x07\x01&\x02\x03;\x033\t\x06\x01\x03\x11\x079\x1f7\x11\x04\r\x051;\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1f\x03\x13\x05\x03\t#\x03\x0b\x03\x07%\x05\x03\x13\x03\x05\r\x06%\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1f\x05\t\x0b\x05\x03\tS\x03\t\x03\x07U\x05\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00\xde\x1c\x99\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99A9;;m\x19\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x17\x15\x11\x11\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00real_v1\x00imag_v1\x00negate_v1\x00complex_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=complex64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/real\x00jit()/jit(main)/imag\x00jit()/jit(main)/neg\x00jit()/jit(main)/complex\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_cheevd_ffi\x00mode\x00uplo\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_ssyevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-0.6185769 , -0.20142993 , -0.09725195 , 0.62983674 , + -0.07926044 , 0.3605001 , -0.019093221 , -0.18446997 ], + [-0.47070873 , 0.29325768 , -0.19454119 , -0.6394365 , + 0.0622955 , 0.33249345 , 0.28112718 , -0.22856665 ], + [-0.32284075 , -0.12361939 , 0.20547704 , -0.18307868 , + 0.47294614 , -0.3170349 , -0.6373532 , -0.27266347 ], + [-0.17497246 , -0.079641335 , 0.15042791 , -0.15416273 , + -0.815209 , -0.38054234 , -0.083263926 , -0.31676024 ], + [-0.027104253 , -0.26490977 , 0.32271704 , 0.08653544 , + 0.30305928 , 
-0.33998996 , 0.6926741 , -0.360857 ], + [ 0.12076397 , 0.43288827 , -0.64385164 , 0.2652551 , + 0.09482376 , -0.37435007 , 0.00091664493, -0.40495378 ], + [ 0.26863196 , 0.51607686 , 0.53846526 , 0.16969058 , + -0.021670295 , 0.35755336 , -0.113144726 , -0.4490505 ], + [ 0.4165004 , -0.57262254 , -0.2814425 , -0.17463988 , + -0.01698498 , 0.3613705 , -0.12186296 , -0.49314725 ]], + dtype=float32), array([-2.4598808e+01, -3.3105560e-05, -3.1002426e-05, -1.0103593e-05, + -1.0022322e-05, 4.0141886e-06, 9.5510331e-06, 2.7659882e+02], + dtype=float32)), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc13 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xf32> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<64xf32>) -> tensor<8x8xf32> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf32>) -> tensor<8x8xf32> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<8x8xf32> loc(#loc11) + %cst = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<8x8xf32> loc(#loc12) + %6 = call @tril(%5) : (tensor<8x8xf32>) -> tensor<8x8xf32> loc(#loc13) + %c = stablehlo.constant dense<8> : tensor loc(#loc14) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc14) + %7:3 = stablehlo.custom_call @lapack_ssyevd_ffi(%6) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xf32>) -> (tensor<8x8xf32>, tensor<8xf32>, tensor) loc(#loc14) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %cst_2 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<8x8xi1>, tensor<8x8xf32> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf32> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc14) + %17 = stablehlo.select %16, %7#1, %15 : tensor<8xi1>, tensor<8xf32> loc(#loc14) + return %13, %17 : tensor<8x8xf32>, tensor<8xf32> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xf32> {mhlo.layout_mode = 
"default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6))) -> (tensor<8x8xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc15) + %c = stablehlo.constant dense<0> : tensor loc(#loc13) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc17) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc18) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc13) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf32> loc(#loc19) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xf32> loc(#loc20) + return %6 : tensor<8x8xf32> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc8 = loc("jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc6)) +#loc20 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff9\x01\xa1\x0f\x13\x17\x0b\x0f\x0b\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x03_\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1fO\x1f/\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b\x1fO/\x01\x05\x0b\x0f\x035\x17\x0f\x07\x0f\x07\x07\x13\x17\x0f\x07\x07\x17\x13\x07\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02:\t\x1d\x83\x85\x03\x03\x11\xcb\x17\x07V\x047\x05!\x1d?\x05\x05#\x1f\x05%\x05'\x11\x03\x05\x05)\x05+\x05-\x05/\x03\x03\x1f\xc7\x051\x03\x03\x0b\xc9\x1dE\x05\x053\x055\x1d{}\x03\x03\x0b\xd7\x03\x03\x0b\xf9\x03\t135\x137\x13\x0f9\x057\x11\x01\x00\x059\x05;\x05=\x03\x0b\x15\xa5\x17\xb1\x19\xb3\x0f\xbd\x1b\xbf\x03\x0b\x15\xa9\x17\xc3\x19\xa9\x0f\xab\x1b\xc5\x05?\x1dC\x05\x05A\x05C\x03\x03\x1f\xcd\x1dK\x05\x05E\x03\x05%\xad'\xcf\x1dQ\x05\x05G\x03\x03\x0b\xd1\x1dW\x05\x05I\x1d[\x05\x05K\x1d_a\x05M\x17\x076\x045\x1deg\x05O\x17\x076\x04\x1d\x03\x03k\xd3\x05Q\x1doq\x05S\x17\x07>\x04E\x1duw\x05U\x17\x07>\x04\x1f\x03\x03\x0b\xd5\x05W\x17\x07>\x04\x1d\x03\x03\x81\xab\x05Y\x05[\x17\x07V\x04\x17\x03\x13\x89\xd9\x8b\xdb\x8d\xdd\x8f\xa5\x91\xdf\x93\xe1\x95\xeb\x97\xed\x99\xf1\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x05k\x05m\x03\x05%\xad'\xf7\x03\x03\x11\xfb\x03\x03\x11\xfd\x1do\x1dq\x03\x01\x1ds\x03\x03\xc1\x1du\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x05\xb5\xb9\r\x05\xa7\xb7\xa1\xa3\x1dw\r\x05\xa7\xbb\xa1\xa3\x1dy\x1d{\x1d}\r\x03\xa1\xa3##\x1d\x7f\x13\t\x01\x1f\x0b\t\x00\x00\x00\x00\x1f%\x01\x13\t\x05\x07\x05\x1f\x07\t\x00\x00\x00\x00\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\t\x00\x00\x00@\x1f\x15\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x81\x1d\x83\x05\x01\r\x05\xe3\xe5\xe7\xe9\x1d\x85\x13\x1fV\x1d\x87\x13\x1fL\x03\x03\xaf\x03\x03\xef\x15\x03\x01\x01\x01\x03\x07\xaf\xf3\xf5\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x07\x01\x1f\x07\t\x00\x00\xc0\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05!!\x0f)\x01\x0f\x1d)\x01\x19\x01\t)\x03!\x0f)\x05!!\x19)\x01\t\x13\x1b)\x05!!\r)\x03\t\t!\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\t)\x03\x02\x02\x0f)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\t\x04~\x04\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x035e\x07\x03]\x1d\x03'\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\r\x06s\x03\x05\x05\x03\x05\x05\x03\ry\x03\x07\x03\x07)\x03\x03\x05\x03\t\x17\x06)\x03\x05\x05\x07\x0b\x19\x07\t\x7f\x03\x05\x03\r\x05\x03\x01+\x03\x15\x05\x03\x01+\x03\x15\x1b\x07\x01\x87\x07\x05\x11\x0b\x03\x0f\x05\x03\x01!\x03\x0b\x03\x07\x01\x03\x03\x0b\x03\x1b\x0f\x07\x01\x9b\x03/\x05\x19\x1d\x03\x07\x01\x03\x031\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x9d\x03\x1b\x03!\t\x06\x01\x03\x05\x07'\x15%\x03\x07\x01\x03\x033\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x11\x03-\x03\x07\x01\x9f\x035\x03+\t\x06\x01\x03\x11\x071\x17/\x11\x04\r\x05)3\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1d\x03\x13\x05\x03\t!\x03\x0b\x03\x07#\x03\x03\x13\x03\x05\r\x06#\x03\x13\x05\x03\x07\
x07\x03IG\x03\x13\x0f\x07OM\x03\x1b\x05\t\x0b\x05\x03\tS\x03\x07\x03\x07U\x03\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00J\x1a\x89\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99m\x19\x85\x89W\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float32 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_ssyevd_ffi\x00mode\x00uplo\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dsyevd_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[-6.1857700048412056e-01, 2.4081403770912022e-01, + 3.5662489253627483e-01, -6.3034019033669797e-01, + 1.0043483479985752e-16, -2.8842036081919542e-02, + 7.7164692943283169e-25, -1.8446994643771725e-01], + [-4.7070881487314614e-01, 4.7473787464450845e-01, + -4.8036836210243367e-01, 4.3802686872516400e-01, + 1.7961797619639258e-01, 8.3080980076741355e-03, + 2.1415294457221756e-01, -2.2856669794666584e-01], + [-3.2284062926217072e-01, -5.4336490915553370e-01, + 2.2181041859724990e-01, 2.9947877954402297e-01, + -3.6491813600134632e-01, 3.2867679819727436e-01, + 3.8223299448843473e-01, -2.7266344945561438e-01], + [-1.7497244365119530e-01, -8.9251550609769414e-02, + -6.3518515114898394e-02, 1.9162997359209971e-01, + -2.2087281326110139e-01, 5.9957027043505064e-02, + -8.7632498908241274e-01, -3.1676020096456303e-01], + [-2.7104258040220038e-02, -3.3772873786627672e-01, + 2.5901386593721748e-01, 1.7032650752287815e-01, + 6.7521217612940332e-01, -4.5036136532965476e-01, + -1.2279030059078447e-02, -3.6085695247351163e-01], + [ 1.2076392757075530e-01, -3.3834734096469254e-01, + -6.5506827461665540e-01, 
-5.0472498521116749e-01, + 6.9987430903492118e-02, 1.0595648906599275e-01, + 8.3443844143082022e-02, -4.0495370398246017e-01], + [ 2.6863211318173097e-01, 2.2958613191407318e-01, + 6.3952843755683941e-02, 1.8776775771084137e-02, + -5.3523731432241317e-01, -5.9199531677602002e-01, + 1.7916671834524248e-01, -4.4905045549140887e-01], + [ 4.1650029879270661e-01, 3.6355449432857079e-01, + 2.9755313100756142e-01, 1.6826270392615944e-02, + 1.9621068035557282e-01, 5.6830030587314817e-01, + 2.9607517592514246e-02, -4.9314720700035747e-01]]), array([-2.4598804776133626e+01, -4.6567755957874661e-14, + -1.9932120610662194e-14, -5.7323356091157378e-15, + -4.5459724251334835e-16, 4.0479851042511616e-14, + 9.2325194924982089e-14, 2.7659880477613365e+02])), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:27) +#loc13 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<8x8xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<8xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<64xf64> loc(#loc8) + %1 = stablehlo.reshape %0 : (tensor<64xf64>) -> tensor<8x8xf64> loc(#loc9) + %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<8x8xf64>) -> tensor<8x8xf64> loc(#loc10) + %3 = stablehlo.add %1, %2 : tensor<8x8xf64> loc(#loc11) + %cst = stablehlo.constant dense<2.000000e+00> : tensor loc(#loc) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc12) + %5 = stablehlo.divide %3, %4 : tensor<8x8xf64> loc(#loc12) + %6 = call @tril(%5) : (tensor<8x8xf64>) -> tensor<8x8xf64> loc(#loc13) + %c = stablehlo.constant dense<8> : tensor loc(#loc14) + %c_0 = stablehlo.constant dense<8> : tensor loc(#loc14) + %7:3 = stablehlo.custom_call @lapack_dsyevd_ffi(%6) {mhlo.backend_config = {mode = 86 : ui8, uplo = 76 : ui8}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<8x8xf64>) -> (tensor<8x8xf64>, tensor<8xf64>, tensor) loc(#loc14) + %c_1 = stablehlo.constant dense<0> : tensor loc(#loc14) + %8 = stablehlo.broadcast_in_dim %c_1, dims = [] : (tensor) -> tensor loc(#loc14) + %9 = stablehlo.compare EQ, %7#2, %8, SIGNED : (tensor, tensor) -> tensor loc(#loc14) + %10 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc14) + %cst_2 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc14) + %11 = stablehlo.broadcast_in_dim %cst_2, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc14) + %12 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<8x8xi1> loc(#loc14) + %13 = stablehlo.select %12, %7#0, %11 : tensor<8x8xi1>, tensor<8x8xf64> loc(#loc14) + %14 = stablehlo.broadcast_in_dim %9, dims = [] : (tensor) -> tensor<1xi1> loc(#loc14) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc14) + %15 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<8xf64> loc(#loc14) + %16 = stablehlo.broadcast_in_dim %14, dims = [0] : (tensor<1xi1>) -> tensor<8xi1> loc(#loc14) + %17 = 
stablehlo.select %16, %7#1, %15 : tensor<8xi1>, tensor<8xf64> loc(#loc14) + return %13, %17 : tensor<8x8xf64>, tensor<8xf64> loc(#loc) + } loc(#loc) + func.func private @tril(%arg0: tensor<8x8xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]"(#loc6))) -> (tensor<8x8xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<8x8xi32> loc(#loc15) + %c = stablehlo.constant dense<0> : tensor loc(#loc13) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<8x8xi32> loc(#loc16) + %2 = stablehlo.add %0, %1 : tensor<8x8xi32> loc(#loc16) + %3 = stablehlo.iota dim = 1 : tensor<8x8xi32> loc(#loc17) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<8x8xi32>, tensor<8x8xi32>) -> tensor<8x8xi1> loc(#loc18) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc13) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<8x8xf64> loc(#loc19) + %6 = stablehlo.select %4, %arg0, %5 : tensor<8x8xi1>, tensor<8x8xf64> loc(#loc20) + return %6 : tensor<8x8xf64> loc(#loc13) + } loc(#loc13) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":269:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:34) +#loc4 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:15) +#loc5 = loc("third_party/py/jax/tests/export_back_compat_test.py":271:14) +#loc7 = loc("third_party/py/jax/tests/export_back_compat_test.py":277:11) +#loc8 = loc("jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]"(#loc1)) +#loc9 = loc("jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]"(#loc2)) +#loc10 = loc("jit()/jit(main)/transpose[permutation=(1, 0)]"(#loc3)) +#loc11 = loc("jit()/jit(main)/add"(#loc4)) +#loc12 = loc("jit()/jit(main)/div"(#loc5)) +#loc14 = loc("jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]"(#loc7)) +#loc15 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]"(#loc6)) +#loc16 = loc("jit()/jit(main)/jit(tril)/add"(#loc6)) +#loc17 = loc("jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]"(#loc6)) +#loc18 = loc("jit()/jit(main)/jit(tril)/ge"(#loc6)) +#loc19 = loc("jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]"(#loc6)) +#loc20 = loc("jit()/jit(main)/jit(tril)/select_n"(#loc6)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01+\x05\x01\x03\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\x96\x02\xff9\x01\xa1\x0f\x13\x17\x0b\x0f\x0b\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x13\x0b\x13\x0f\x0b\x0b\x0f\x13\x13+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x0b\x13\x0f\x0b\x1b\x0f\x0b\x13\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x0b\x0b\x17S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x1b\x13\x13\x03_\x0b\x0b\x0b\x0b\x0f\x0b\x0bO\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b/O//\x0b\x0b\x0b\x0b\x1b\x0b\x0f\x0b\x0f\x0f\x0f\x17\x17/\x0f\x0b/O/\x01\x05\x0b\x0f\x035\x17\x0f\x07\x0f\x07\x07\x13\x17\x0f\x07\x07\x17\x13\x07\x17\x17\x13\x17\x13\x13\x13\x0f\x17\x13\x13\x13\x02j\t\x1d\x83\x85\x03\x03\x11\xcb\x17\x07V\x047\x05!\x1d?\x05\x05#\x1f\x05%\x05'\x11\x03\x05\x05)\x05+\x05-\x05/\x03\x03\x1f\xc7\x051\x03\x03\x0b\xc9\x1dE\x05\x053\x055\x1d{}\x03\x03\x0b\xd7\x03\x03\x0b\xf9\x03\t135\x137\x13\x0f9\x057\x11\x01\x00\x059\x05;\x05=\x03\x0b\x15\xa5\x17\xb1\x19\xb3\x0f\xbd\x1b\xbf\x03\x0b\x15\xa9\x17\xc3\x19\xa9\x0f\xab\x1b\xc5\x05?\x1dC\x05\x05A\x05C\x03\x03\x1f\xcd\x1dK\x05\x05E\x03\x05%\xad'\xcf\x1dQ\x05\x05G\x03\x03\x0b\xd1\x1dW\x05\x05I\x1d[\x05\x05K\x1d_a\x05M\x17\x076\x045\x1deg\x05O\x17\x076\x04\x1d\x03\x03k\xd3\x05Q\x1doq\x05S\x17\x07>\x04E\x1duw\x05U\x17\x07>\x04\x1f\x03\x03\x0b\xd5\x05W\x17\x07>\x04\x1d\x03\x03\x81\xab\x05Y\x05[\x17\x07V\x04\x17\x03\x13\x89\xd9\x8b\xdb\x8d\xdd\x8f\xa5\x91\xdf\x93\xe1\x95\xeb\x97\xed\x99\xf1\x05]\x05_\x05a\x05c\x05e\x05g\x05i\x05k\x05m\x03\x05%\xad'\xf7\x03\x03\x11\xfb\x03\x03\x11\xfd\x1do\x1dq\x03\x01\x1ds\x03\x03\xc1\x1du\t\x07\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x05\xb5\xb9\r\x05\xa7\xb7\xa1\xa3\x1dw\r\x05\xa7\xbb\xa1\xa3\x1dy\x1d{\x1d}\r\x03\xa1\xa3##\x1d\x7f\x13\t\x01\x1f\x0b\t\x00\x00\x00\x00\x1f%\x01\x13\t\x05\x07\x05\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x00\x00\x00\x00\x00\x00\x00@\x1f\x15\x11\x08\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x81\x1d\x83\x05\x01\r\x05\xe3\xe5\xe7\xe9\x1d\x85\x13\x1fV\x1d\x87\x13\x1fL\x03\x03\xaf\x03\x03\xef\x15\x03\x01\x01\x01\x03\x07\xaf\xf3\xf5\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x01\x07\x01\x1f\x07\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f7\x11\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x05!!\x0f)\x01\x0f\x1d)\x01\x19\x01\x0b)\x03!\x0f)\x05!!\x19)\x01\t\x13\x1b)\x05!!\r)\x03\t\t!\x11\x01\x05\x05\x11\x11\x03\x05\x03\x05)\x03\x01\t)\x03\x02\x02\x0f)\x03\t\x17)\x03\x05\x17)\x03\x01\x17)\x01\r)\x05\x05\x05\r)\x03\x05\r)\x03!\r)\x03\x05\t\x04~\x04\x05\x01\x11\r/\x07\x03\x01\t\x0b\x11\r;\x07\x035e\x07\x03]\x1d\x03'\x13\x06c\x03\x05\x03\x01\x15\x07mi\x03\x05\x03\x03\r\x06s\x03\x05\x05\x03\x05\x05\x03\ry\x03\x07\x03\x07)\x03\x03\x05\x03\t\x17\x06)\x03\x05\x05\x07\x0b\x19\x07\t\x7f\x03\x05\x03\r\x05\x03\x01+\x03\x15\x05\x03\x01+\x03\x15\x1b\x07\x01\x87\x07\x05\x11\x0b\x03\x0f\x05\x03\x01!\x03\x0b\x03\x07\x01\x03\x03\x0b\x03\x1b\x0f\x07\x01\x9b\x03/\x05\x19\x1d\x03\x07\x01\x03\x031\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x05\x03#\x03\x07\x01\x9d\x03\x1b\x03!\t\x06\x01\x03\x05\x07'\x15%\x03\x07\x01\x03\x033\x03\x1f\x05\x03\x01-\x03\x07\x03\x07\x01\x03\x03\x11\x03-\x03\x07\x01\x9f\x035\x03+\t\x06\x01\x03\x11\x071\x17/\x11\x04\r\x05)3\x0b\x11\t=\x07\x03\x15+\x03\x05\t\x07\x03A\x1d\x03\x13\x05\x03\t!\x03\x0b\x03\x07#\
x03\x03\x13\x03\x05\r\x06#\x03\x13\x05\x03\x07\x07\x03IG\x03\x13\x0f\x07OM\x03\x1b\x05\t\x0b\x05\x03\tS\x03\x07\x03\x07U\x03\x03\x05\x03\x0f\t\x06Y\x03\x05\x07\r\x01\x11\x11\x04\t\x03\x13\x06\x03\x01\x05\x01\x00J\x1a\x89\x0b\x0b%\x03\x11\x0f\x0b\t\t\x0b!\x11#\x1f/!)!)#\x1f\x19\xa9\x0f99m\x19\x85\x89W\xb3K\x9bM\x9bn\x03\x1b%)9+\x1b\x1f\x1f\x15\x1d\x15+\x13\ri\x1f\x11\x15\x1b\x17\x15\x17\x0f\x11\x15\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00select_v1\x00func_v1\x00add_v1\x00compare_v1\x00return_v1\x00reshape_v1\x00transpose_v1\x00divide_v1\x00call_v1\x00custom_call_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00compare_type\x00comparison_direction\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=tril keep_unused=False inline=False]\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=0]\x00jit()/jit(main)/jit(tril)/add\x00jit()/jit(main)/jit(tril)/iota[dtype=int32 shape=(8, 8) dimension=1]\x00jit()/jit(main)/jit(tril)/ge\x00jit()/jit(main)/jit(tril)/broadcast_in_dim[shape=(8, 8) broadcast_dimensions=()]\x00jit()/jit(main)/jit(tril)/select_n\x00jit()/jit(main)/iota[dtype=float64 shape=(64,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(8, 8) dimensions=None]\x00permutation\x00jit()/jit(main)/transpose[permutation=(1, 0)]\x00jit()/jit(main)/add\x00jit()/jit(main)/div\x00callee\x00jit()/jit(main)/eigh[lower=True sort_eigenvalues=True subset_by_index=None]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00mhlo.layout_mode\x00default\x00jax.result_info\x00tril\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00lapack_dsyevd_ffi\x00mode\x00uplo\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py index 21448430ead6..94314a7ae518 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_qr_lapack_geqrf.py @@ -98,6 +98,7 @@ xla_call_module_version=4, ) # End paste + # Pasted from the test output (see back_compat_test.py module docstring) data_2023_03_17["f64"] = dict( testdata_version=1, @@ -180,6 +181,7 @@ xla_call_module_version=4, ) # End paste + # Pasted from the test output (see back_compat_test.py module docstring) data_2023_03_17["c64"] = dict( testdata_version=1, @@ -346,3 +348,466 @@ 
mlir_module_serialized=b"ML\xefR\x03MLIRxxx-trunk\x00\x01+\x05\x01\x05\x01\x03\x05\x03\x1b\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x1f\x03\xa6\x02\n\x029\x01\x9b\x0f\x0f\x17\x13\x0b\x0f\x13\x07\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x13\x17\x13\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x13\x13\x1b\x13\x13\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x0bK\x13\x13\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0bK\x03g\x0fO/\x0b/\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x13\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x1f\x0f\x0f\x0bO\x1f\x1f\x1f\x0b\x1f\x0f\x17\x1b\x0f\x0f\x0f\x0f\x1f\x0bOO/\x0b'\x0f\x17\x17\x01\x05\x17\x0b\x039\x0f\x17\x0f\x07\x0b\x07\x07\x17\x13\x17\x17\x07\x0f\x17\x13\x17\x07\x17\x13\x13\x1b\x13\x13\x13\x13\x13\x13\x17\x02\x16\n\x1d\x7f\x05\x1d\x97\x05\x17!\xee\x05\x01\x03\x03\x15\xcd\x05!\x1dY\x05\x03\x03\t\xd7\x1f\x05#\x05%\x05'\x03\x03\t\xf1\x05)\x05+\x05-\x05/\x051\x03\x03%\xc9\x053\x1da\x05\x055\x057\x03\x03\t\xd3\x17!\xea\x05\x01\x03\x03\t\xd5\x03\x03\t\xd9\x059\x05;\x05=\x05?\x05A\x05C\x05E\x05G\x03\x03\x11\xe5\x03\x03\x11\xe7\x03\x03\x11\xe9\x03\x03\t\xed\x03\x05)\xab+\xef\x03\x03\x15\xf3\x03\x03\x13S\x05I\x03\x0b\x19\xa1\x1b\xb3\x1d\xb5\x13\xbf\x1f\xc1\x03\x0b\x19\xa7\x1b\xc5\x1d\xa7\x13\xa9\x1f\xc7\x05K\x1d]\x05\x05M\x03\x03\t\xcb\x05O\x03\x03%\xcf\x1dg\x05\x05Q\x03\x05)\xab+\xd1\x1dm\x05\x05S\x1dq\x05\x05U\x1du\x05\x05W\x1dy/\x05Y\x1d}/\x05[\x05]\x03\x115\xad7\xaf9\xdb;\xa1=\xb1?\xddA\xdfC\xe3\x03\x03\x11\xeb\x03\x03\x15\xf5\x1d\x89\x05\x05_\x03\x07\x8d\xa3\x8f\xa3\x91\xa3\x05a\x05c\x05e\x1d\x95\x05\x05g\x05i\x03\x115\xad7\xaf9\xf7;\xa1=\xb1?\xf9A\xfbC\xff\x1f+\x01\x1f-!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dk\x03\x03\xc3\x1dm\t\x07\x0b\x05\x1do\x05\x01#\x1f\x03\x05\xb7\xbb\r\x03\xa5\xb9\x1dq\r\x03\xa5\xbd\x1ds\x1du\x1dw\r\x01##\x1dy\x13\x0b\x01\x1f\x01\t\xff\xff\xff\xff\x1f%\x01\x13\x0b\x05\x07\x05\x1f\x05!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x01\t\x01\x00\x00\x00\x1f\x01\t\x03\x00\x00\x00\x1f\x01\t`\x00\x00\x00\x1d{\x03\x0b\x9b\x9b\x9b\x9b\x9d\x03\x03\xe1\x15\x03\x01\x11\x01\x03\t\x9d\x9f\x9b\x9f\x13\x07\x01\x13\x07\x05\x13\x07\t\x13\x07\r\x1f\x01\t\x00\x00\x00\x00\x07\x01\x1f\x05!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1d}\x03\x0f\x9b\x9b\x9b\x9b\x9b\x9d\x9f\x03\x03\xfd\x15\x03\x01\x15\x01\x03\x07\x9d\x9b\x9f\x03\x03\x06\x02\xa9\x05\x7f)\x01\x07)\x05\r\r\t)\x01\t\x1b\x03!\x1d\x01)\x05\r\r\x07)\x03\r\t)\x03\x02\x03\t)\x05\r\r\r\x13)\x01\r)\x05\x05\x05\r)\x03\t\x0b\x11\x01\x05\x03\x03\x0b\x11\x03\x03\x03\x03)\x03\x01\x0b)\x03%\t/\t\x03\x11\x01\x13)\x03\x01\x17)\x03\t\x17)\x03\x05\x17)\x03\x05\r)\x03\r\r)\x03\x05\x0b/\x07\x03\x01\x13\x04\xe6\x06\x05\x01\x11\x0fQ\x07\x03\x01\t\x0f\x11\x0fU\x05\x03Y\xb5\x0b\x03w#\x03'\x17\x06{\x03\x03\x03\x01\x03\x03\x011\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x01\r\x03\x01\x03\x03\x013\x03\x01\x13\x07\x01\x81\x03)\x0b\x05\x07\t\x0b\x03\x07\x07\x01E\x03\x03\x03\r\x07\x07\x01G\x03\x11\x03\r\x07\x07\x01I\x03\x01\x03\r\x07\x07\x01\x83\x03\x13\x03\r\x03\x03\x01K\x03\x01\x05\x07\x01\x07\x03\x01\x03\x17\r\x07\x01M\x03\x19\x05\x13\x19\x05\x07\x01\x07\x03\x1b\x03\x1b\x03\x03\x01\x17\x03\x05\x05\x07\x01\x07\x03\x03\x03\x1f\x05\x07\x01O\x03\x15\x03\x1d\t\x06\x01\x03\x03\x07#\x0f!\x05\x07\x01\x07\x031\x03\x1b\x03\x03\x01\x17
\x03\x05\x05\x07\x01\x07\x03\x11\x03)\x05\x07\x01\x85\x033\x03'\t\x06\x01\x03\x11\x07-\x11+\x03\x03\x87-\x03\x05\x19\x07\x93\x8b\x03\x03\x05%1\x03\x03\x031\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x03\r\x03\x01\x03\x03\x033\x03\x01\x13\x07\x03\x99\x037\x0f579;=3/\x07\x07\x03E\x03\x03\x03?\x07\x07\x03G\x03\x01\x03?\x07\x07\x03I\x03\x13\x03?\x03\x03\x03K\x03\x01\x05\x07\x03\x07\x03\x01\x03G\r\x07\x03M\x03\x19\x05CI\x05\x07\x03\x07\x03\x1b\x03K\x03\x03\x03\x17\x03\x05\x05\x07\x03\x07\x03\x03\x03O\x05\x07\x03O\x03\x15\x03M\t\x06\x03\x03\x03\x07SAQ\x1b\x07\x0b\x02\x02\x03\x03\x03%\x11\x04\x0f\x05UW\x0f\x11\x0bW\x05\x03\x15+\x03\x03\x0f\x0b\x03[#\x03\x0f\x03\x03\x0b_\x03\x01\x05\x07'\x07\x03\x0f\x03\x05\x15\x06'\x03\x0f\x05\x03\x07\x0b\x03ec\x03\x0f\r\x07ki\x03\x15\x05\t\x0b\x03\x03\x0b-\x03\x05\x05\x07o\x07\x03\x03\x03\x0f\t\x06s\x03\x03\x07\r\x11\x01\x11\x04\x0b\x03\x13\x06\x03\x01\x05\x01\x00\xd2\x18\x81\x0f\x1d\x1d\x11\x0f\x0b\t\t\x03\x0b!Y\x87##%_=\x85\x8dW\xb3K\x9bM\x9b\xd2\x02\x1b\x1f/!!)#\x1f\x19+\x1b\x1f\x83\x1f\x15\x1d\x15+\x13\r\r\x11\x0f\x17\x0f\x1f\x15\x11\x17\x11\x15+)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00get_tuple_element_v1\x00select_v1\x00iota_v1\x00compare_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00index\x00sym_name\x00broadcast_dimensions\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00iota_dimension\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]\x00jit()/jit(main)/geqrf\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]\x00jit()/jit(main)/householder_product\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf\x00lapack_zungqr\x00callee\x00", xla_call_module_version=4, ) # End paste + + +data_2024_08_22 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['c128'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgeqrf_ffi', 'lapack_zungqr_ffi'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array([ + [0.0 + 0.0j, 0.9128709291752773 + 0.0j, 0.40824829046386235 + 0.0j], + [ + -0.447213595499958 - 0.0j, + 0.3651483716701102 + 0.0j, + -0.8164965809277263 + 0.0j, + ], + [ + -0.894427190999916 - 0.0j, + -0.1825741858350548 + 0.0j, + 0.40824829046386324 + 0.0j, + ], + ]), + array([ + [ + -6.7082039324993694e00 + 0.0j, + -8.0498447189992444e00 + 0.0j, + -9.3914855054991175e00 + 0.0j, + ], + [ + 0.0000000000000000e00 + 0.0j, + 
1.0954451150103341e00 + 0.0j, + 2.1908902300206665e00 + 0.0j, + ], + [ + 0.0000000000000000e00 + 0.0j, + 0.0000000000000000e00 + 0.0j, + -8.8817841970012523e-16 + 0.0j, + ], + ]), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_zgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_zungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> 
loc(#loc10) + return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) + return %6 : tensor<3x3xcomplex> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0bO/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0bOO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xf2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f!\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x0b\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19
\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xaa\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8dW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex128' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_zgeqrf_ffi\x00lapack_zungqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['c64'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgeqrf_ffi', 'lapack_cungqr_ffi'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array( + [ + [0.0 + 0.0j, 0.91287076 + 0.0j, 0.4082487 + 0.0j], + [-0.44721356 - 0.0j, 0.36514866 + 0.0j, -0.8164965 + 0.0j], + [-0.8944271 - 0.0j, -0.18257445 + 0.0j, 0.40824816 + 0.0j], + ], + dtype=complex64, + ), + array( + [ + [ + -6.7082043e00 + 0.0j, + -8.0498438e00 + 0.0j, + -9.3914852e00 + 0.0j, + ], + [0.0000000e00 + 0.0j, 1.0954441e00 + 0.0j, 2.1908894e00 + 0.0j], + [ + 0.0000000e00 + 0.0j, + 0.0000000e00 + 0.0j, + 7.1525574e-07 + 0.0j, + ], + ], + dtype=complex64, + ), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) 
name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xcomplex> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xcomplex>) -> tensor<3x3xcomplex> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_cgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xcomplex>) -> (tensor<3x3xcomplex>, tensor<3xcomplex>) loc(#loc6) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xcomplex>, tensor>) -> tensor<3x3xcomplex> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_cungqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xcomplex>, tensor<3xcomplex>) -> (tensor<3x3xcomplex>, tensor, tensor<96xcomplex>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xcomplex>) -> tensor<3x3xcomplex> loc(#loc10) + return %10, %11 : tensor<3x3xcomplex>, tensor<3x3xcomplex> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xcomplex> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xcomplex> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim 
= 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor> loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x3xcomplex> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xcomplex> loc(#loc16) + return %6 : tensor<3x3xcomplex> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xae\x02\x12\x025\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x031\x17\x0f\x0f\x0b\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x07\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xb2\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f/\x01\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#!\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f#\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x03\x1f\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\t\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03%\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05'\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07-\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x031\x05#)\x05\x07\x01\t\x033\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03
\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\xa6\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x8bW\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=complex64' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_cgeqrf_ffi\x00lapack_cungqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['f32'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgeqrf_ffi', 'lapack_sorgqr_ffi'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array( + [ + [0.0, 0.91287076, 0.4082487], + [-0.44721356, 0.36514866, -0.8164965], + [-0.8944271, -0.18257445, 0.40824816], + ], + dtype=float32, + ), + array( + [ + [-6.7082043e00, -8.0498438e00, -9.3914852e00], + [0.0000000e00, 1.0954441e00, 2.1908894e00], + [0.0000000e00, 0.0000000e00, 7.1525574e-07], + ], + dtype=float32, + ), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> 
(tensor<3x3xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xf32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xf32>) -> tensor<3x3xf32> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_sgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor<3xf32>) loc(#loc6) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf32>, tensor) -> tensor<3x3xf32> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_sorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf32>, tensor<3xf32>) -> (tensor<3x3xf32>, tensor, tensor<96xf32>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xf32>) -> tensor<3x3xf32> loc(#loc10) + return %10, %11 : tensor<3x3xf32>, tensor<3x3xf32> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xf32> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, 
%3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf32> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf32> loc(#loc16) + return %6 : tensor<3x3xf32> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b\x1f/\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b\x1fO\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\x8a\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\t\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x
00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\t\x00\x00\xc0\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\t\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float32' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), 
(0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_sgeqrf_ffi\x00lapack_sorgqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_22['f64'] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgeqrf_ffi', 'lapack_dorgqr_ffi'], + serialized_date=datetime.date(2024, 8, 22), + inputs=(), + expected_outputs=( + array([ + [0.0, 0.9128709291752773, 0.40824829046386235], + [-0.447213595499958, 0.3651483716701102, -0.8164965809277263], + [-0.894427190999916, -0.1825741858350548, 0.40824829046386324], + ]), + array([ + [ + -6.7082039324993694e00, + -8.0498447189992444e00, + -9.3914855054991175e00, + ], + [ + 0.0000000000000000e00, + 1.0954451150103341e00, + 2.1908902300206665e00, + ], + [ + 0.0000000000000000e00, + 0.0000000000000000e00, + -8.8817841970012523e-16, + ], + ]), + ), + mlir_module_text=r""" +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":364:11) +#loc10 = loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3)) +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x3xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3x3xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<9xf64> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<9xf64>) -> tensor<3x3xf64> loc(#loc5) + %c = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %2:2 = stablehlo.custom_call @lapack_dgeqrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>]} : (tensor<3x3xf64>) -> (tensor<3x3xf64>, tensor<3xf64>) loc(#loc6) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc7) + %3 = stablehlo.pad %2#0, %cst, low = [0, 0], high = [0, 0], interior = [0, 0] : (tensor<3x3xf64>, tensor) -> tensor<3x3xf64> loc(#loc8) + %c_1 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_2 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_3 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_4 = stablehlo.constant dense<1> : tensor loc(#loc9) + %c_5 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_6 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_7 = stablehlo.constant dense<3> : tensor loc(#loc9) + %c_8 = stablehlo.constant dense<96> : tensor loc(#loc9) + %4:3 = stablehlo.custom_call @lapack_dorgqr(%c_4, %c_5, %c_6, %c_7, %c_8, %3, %2#1) {api_version = 2 : i32, operand_layouts = [dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<> : tensor<0xindex>, dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<> : tensor<0xindex>, dense<0> : tensor<1xindex>]} : (tensor, tensor, tensor, tensor, tensor, tensor<3x3xf64>, tensor<3xf64>) -> (tensor<3x3xf64>, tensor, 
tensor<96xf64>) loc(#loc9) + %c_9 = stablehlo.constant dense<0> : tensor loc(#loc9) + %5 = stablehlo.broadcast_in_dim %c_9, dims = [] : (tensor) -> tensor loc(#loc9) + %6 = stablehlo.compare EQ, %4#1, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc9) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc9) + %cst_10 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc9) + %8 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc9) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x3xi1> loc(#loc9) + %10 = stablehlo.select %9, %4#0, %8 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc9) + %11 = call @triu(%2#0) : (tensor<3x3xf64>) -> tensor<3x3xf64> loc(#loc10) + return %10, %11 : tensor<3x3xf64>, tensor<3x3xf64> loc(#loc) + } loc(#loc) + func.func private @triu(%arg0: tensor<3x3xf64> {mhlo.layout_mode = "default"} loc("jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=triu keep_unused=False inline=False]"(#loc3))) -> (tensor<3x3xf64> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<3x3xi32> loc(#loc11) + %c = stablehlo.constant dense<-1> : tensor loc(#loc10) + %1 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3x3xi32> loc(#loc12) + %2 = stablehlo.add %0, %1 : tensor<3x3xi32> loc(#loc12) + %3 = stablehlo.iota dim = 1 : tensor<3x3xi32> loc(#loc13) + %4 = stablehlo.compare GE, %2, %3, SIGNED : (tensor<3x3xi32>, tensor<3x3xi32>) -> tensor<3x3xi1> loc(#loc14) + %cst = stablehlo.constant dense<0.000000e+00> : tensor loc(#loc10) + %5 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x3xf64> loc(#loc15) + %6 = stablehlo.select %4, %5, %arg0 : tensor<3x3xi1>, tensor<3x3xf64> loc(#loc16) + return %6 : tensor<3x3xf64> loc(#loc10) + } loc(#loc10) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":363:14) +#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(9,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 3) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/geqrf"(#loc3)) +#loc7 = loc("jit()/jit(main)/qr[full_matrices=True]"(#loc3)) +#loc8 = loc("jit()/jit(main)/pad[padding_config=((0, 0, 0), (0, 0, 0))]"(#loc3)) +#loc9 = loc("jit()/jit(main)/householder_product"(#loc3)) +#loc11 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=0]"(#loc3)) +#loc12 = loc("jit()/jit(main)/jit(triu)/add"(#loc3)) +#loc13 = loc("jit()/jit(main)/jit(triu)/iota[dtype=int32 shape=(3, 3) dimension=1]"(#loc3)) +#loc14 = loc("jit()/jit(main)/jit(triu)/ge"(#loc3)) +#loc15 = loc("jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3, 3) broadcast_dimensions=()]"(#loc3)) +#loc16 = loc("jit()/jit(main)/jit(triu)/select_n"(#loc3)) +""", + mlir_module_serialized=( + 
b"ML\xefR\x01StableHLO_v0.9.0\x00\x01)\x05\x01\x03\x01\x03\x05\x03\x19\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x1d\x03\xaa\x02\x12\x023\x01\x9d\x0f\x17\x0b\x0f\x13\x13\x0b\x07\x0b\x0f\x13\x0f\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x0b\x13\x0b\x13\x0f\x0b\x1b\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0b\x17\x0f\x0b\x17\x0bS\x0b\x0f\x0b#\x0b\x0b\x0b\x0f\x0b\x0b\x13\x13K\x13\x1b\x13\x03g\x0fO\x0b\x0b\x0b//\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x1b\x0b\x0b\x0b\x13\x0b\x0b\x0f\x1f\x0f\x0f\x0b//\x0b\x0b\x0b\x0f\x0f\x17\x13\x1f\x1f\x1f\x0b\x0b'\x0f\x17\x17\x1f\x0b/O\x01\x07\x17\x17\x0b\x01\x05\x0b\x0f\x03/\x17\x0f\x0f\x07\x07\x0f\x17\x07\x07\x07\x17\x13\x17\x17\x13\x13\x13\x13\x13\x17\x13\x0f\x17\x02\xaa\t\x1d\x8f\x03\x17\x11\xb2\x05\x17\x05\x1f\x1dO\x03\x03\x03%\xd1\x03\x03\x05\xd9\x05!\x1f\x05#\x1dy\x03\x03\x03\x05\xeb\x11\x03\x05\x05%\x05'\x05)\x05+\x03\x03#\xcd\x05-\x05/\x1dW\x03\x051\x053\x03\x03\x05\xd7\x055\x057\x059\x05;\x05=\x05?\x05A\x05C\x03\tACE\x17G\x17\rI\x05E\x11\x01\x00\x05G\x05I\x05K\x03\x0b\x19\xa1\x1b\xb7\x1d\xb9\r\xc3\x1f\xc5\x03\x0b\x19\xad\x1b\xc9\x1d\xad\r\xaf\x1f\xcb\x05M\x1dS\x03\x05O\x03\x03\x05\xcf\x05Q\x03\x03#\xd3\x1d]\x03\x05S\x03\x05)\xb1+\xd5\x1dc\x03\x05U\x1dg\x03\x05W\x1dk\x03\x05Y\x1doq\x05[\x17\x11\xae\x055\x1duw\x05]\x17\x11\xae\x05\x1d\x05_\x03\x13/\xdb1\xb33\xdd5\xa17\xb5}\xdf9\xe1;\xe3=\xe7\x05a\x1d\x81\x03\x05c\x03\x07\x85\xa9\x87\xa9\x89\xa9\x05e\x05g\x05i\x1d\x8d\x03\x05k\x05m\x03\x03\x05\xe9\x03\x03\x05\xed\x03\x11/\xef1\xb33\xf15\xa17\xb59\xf3;\xf5=\xf9\x03\x03\x05\xfb\x03\x05)\xb1+\xfd\x03\x03\x05\xff\x1f-\x01\x1f'!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x01\x1do\x1dq\x1f)\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1ds\x03\x03\xc7\x1du\t\x07\x1dw\x05\x01#\x1d\x03\x05\xbb\xbf\r\x05\xab\xbd\xa3\xa5\x1dy\r\x05\xab\xc1\xa3\xa5\x1d{\x1d}\x1d\x7f\r\x03\xa3\xa5#\x1f\x1d\x81\x13\r\x01\x1f\x07\t\xff\xff\xff\xff\x1f!\x01\x13\r\x05\x07\x05\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\t\x11\x03\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1d\x83\r\x01\x03\x03\x9f\x03\x03\xe5\x15\x03\x01\x01\x01\x03\x05\x9f\xa7\x1f\x07\t\x01\x00\x00\x00\x1f\x07\t\x03\x00\x00\x00\x1f\x07\t`\x00\x00\x00\x0b\x05\x1d\x85\x03\x0f\x9d\x9d\x9d\x9d\x9d\x9f\xa7\x03\x03\xf7\x15\x03\x01\x15\x01\x03\x07\x9f\x9d\xa7\x1f\x07\t\x00\x00\x00\x00\x07\x01\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x03\x03%\x02\x02\x03\x03\x0e\x02\xaf\x05\x87\x01\t\x01\x02\x02)\x05\r\r\x0b)\x01\x17)\x01\r\x0b\x1d)\x01\x0b)\x05\r\r\x17\x01\x13\x1b)\x05\r\r\x13)\x03\t\r\x11\x01\x05\x05\x05\x11\x03\x05\x03\x05)\x03\x01\r)\x03%\x0b)\x03\r\x0b)\x03\t\x15)\x03\x05\x15)\x03\x02\x03\x0b)\x03\x01\x15)\x01\x13)\x05\x05\x05\x13\x04\x8a\x04\x05\x01\x11\x0f?\x07\x03\x01\t\t\x11\x0fK\x07\x039i\x07\x03m!\x03#\x15\x06s\x03\x05\x03\x01\x03\x03\x13\x0b\x03\t\x03\x03\x13\x0b\x03\t\x11\x07\x13{\x05\x05%\x03\x03\x03\x03\x7f-\x03\x0f\x17\x07\x8b\x83\x03\x05\x05\t\r\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x0b\x03\t\x03\x03\x01\x91\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x15\x03\x07\x03\x03\x01\x93\x03\x07\x11\x07\x01\x95\x07\x05\x07+\x0f\x17\x19\x1b\x1d\x1f\x0f\x0b\x03\x03\x01\x97\x03\x07\x05\x07\x01\t\x03\x07\x03'\x0b\x07\x01\x99\x03/\x05#)\x05\x07\x01\t\x031\x03+\x03\x03\x01\x9b\x03\x0f\x05\x07\x01\t\x03\x05\x03/\x05\x07\x01\x06\x02\x03\x19\x03-\r\x06\x01\x03\x05\x073!1\x19\x07\x07\n\x02\x03\x05\x03\t\x0f\
x04\x0f\x0557\t\x11\x07M\x07\x03\x15+\x03\x05\x07\x07\x03Q!\x03\x11\x03\x03\x07U\x03\x07\x05\x07'\t\x03\x11\x03\x05\x13\x06'\x03\x11\x05\x03\x07\x07\x03[Y\x03\x11\x0b\x07a_\x03\x19\x05\t\x0b\x03\x03\x07-\x03\x0f\x05\x07e\t\x03\x05\x03\x0f\r\x06i\x03\x05\x07\r\x11\x01\x0f\x04\x07\x03\x13\x06\x03\x01\x05\x01\x00\x9e\x1a\x89\x0f\x1d%\x11\x0f\x0b\t\t\x03\x0b!\x11#Y\x87##%_)=\x85\x87W\xb3K\x9bM\x9bn\x03\x1b%)9\x1f/!!)#\x1f\x19+\x1b+\x1f\x1f\x15\x1d\x15i\x13\r\x11\x0f\x17\x0f\x1f\x15\x15\x17\x11\x11)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00iota_v1\x00func_v1\x00compare_v1\x00select_v1\x00return_v1\x00custom_call_v1\x00add_v1\x00reshape_v1\x00pad_v1\x00call_v1\x00value\x00sym_name\x00third_party/py/jax/tests/export_back_compat_test.py\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00broadcast_dimensions\x00compare_type\x00comparison_direction\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00jit()/jit(main)/pjit[in_shardings=(UnspecifiedValue,)" + b' out_shardings=(UnspecifiedValue,) in_layouts=(None,)' + b' out_layouts=(None,) resource_env=None donated_invars=(False,)' + b' name=triu keep_unused=False' + b' inline=False]\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=0]\x00jit()/jit(main)/jit(triu)/add\x00jit()/jit(main)/jit(triu)/iota[dtype=int32' + b' shape=(3, 3)' + b' dimension=1]\x00jit()/jit(main)/jit(triu)/ge\x00jit()/jit(main)/jit(triu)/broadcast_in_dim[shape=(3,' + b' 3) broadcast_dimensions=()]\x00jit()/jit(main)/jit(triu)/select_n\x00jit()/jit(main)/iota[dtype=float64' + b' shape=(9,)' + b' dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 3)' + b' dimensions=None]\x00jit()/jit(main)/geqrf\x00mhlo.backend_config\x00jit()/jit(main)/qr[full_matrices=True]\x00edge_padding_high\x00edge_padding_low\x00interior_padding\x00jit()/jit(main)/pad[padding_config=((0,' + b' 0, 0), (0, 0,' + b' 0))]\x00jit()/jit(main)/householder_product\x00mhlo.layout_mode\x00default\x00jax.result_info\x00triu\x00\x00[0]\x00[1]\x00main\x00public\x00private\x00lapack_dgeqrf_ffi\x00lapack_dorgqr\x00callee\x00' + ), + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py index 192309f2a54c..2d71308caeda 100644 --- a/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cpu_svd_lapack_gesdd.py @@ -442,3 +442,424 @@ 
mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xa9=\x01S\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13\x13#\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0b\x17\x0b\x13\x13K\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03W\x0fo/\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x17\x13\x0b\x13\x0b\x13\x0b\x0b\x0b\x1f\x1f\x1f\x1f\x0b\x0b\x0b\x0b\x0b\'\x0f\x17+O\x1f\x0f\x0b\x0b//OOo\x01\x03\x0f\x03;\x0f\x1b\x07\x07\x17\x07\x07\x0b\x07\x0f\x13\x0f\x1b\x1b\x1f\x13\x17\x17\x13\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02N\x08\x1d+-\x05\x15\x1f\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x11\x01\x05\x05\x19\x03\x03\x03{\x03\x03\x03\x7f\x03\x03\x03\xa5\x03\x03\t\xa7\x03\x07\x1b\r\x1d\r\x0f\x1f\x05\x1b\x05\x1d\x05\x1f\x03\x0b#[%g\'i\x0fw)y\x05!\x05#\x05%\x05\'\x05)\x17/\x8e\x05\x01\x05+\x03\x03\x03}\x03\x03\x03\x81\x03\x117\x839\x85;\x87=\x89?\x8bA\x8dC\x8fE\x93\x05-\x05/\x051\x053\x055\x057\x059\x05;\x03\x03\x03\x97\x03\x05K\x9bM\x9d\x05=\x05?\x03\x03\x03\xa1\x03\x03\t\xa3\x1f\'\x01\x1f)1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f-\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1dA\x03\x03]\r\x05_ace\x1dC\x1dE\x1dG\x1dI#\x1f\x03\x07kos\r\x03Ym\x1dK\r\x03Yq\x1dM\r\x03Yu\x1dO\x1dQ\x1dS\x1f\x03\t\x01\x00\x00\x00\x1f\x03\t\x02\x00\x00\x00\x1f\x03\t\x04\x00\x00\x00\x1f\x03\t\x08\x01\x00\x00\x0b\x05\x1dU\x1dW\x03\x01\x05\x01\x03\x0fSSSSSSU\x03\x03\x91\x15\x03\x01\x19\x01\x03\x11U\x95UUWWWW\x1f+!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x03\t\x00\x00\x00\x00\x1f/\x01\t\x07\x07\x01\x1f5\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f9!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f;1\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x02\x02)\x01\x13)\x07\t\x11\x11\x11\x01\x0b)\x05\t\x11\t\x13\x1d\x03\t\x1b)\x01\x11)\x03\t\x13)\x01\t)\x07\t\x05\x05\x07)\x07\t\x11\x11\x07\x11\x03\x05\x07\x05\x0b\x05)\x03\x81\x13)\x03"\x03\t)\x03B\x08\x11)\x03\x01\r)\x03\r\r)\x03\t\r)\x03\x05\r)\x03\x01\x0f)\x03\t\x07)\x05\t\x05\x07)\x03\x05\x0f)\x05\t\x11\x07)\x03\t\x0f)\x03\r\x0f\x04\x82\x03\x05\x01\x11\x05\x19\x07\x03\x01\x05\t\x11\x05!\x05\x03Ck\x03\x05\x05\x03\x03\x01\x11\x03\x03\x03\x03\x01\x11\x03\x03\x03\x03\x011\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x01\x13\x03\x03\x03\x03\x013\x03\x03\x0b\x07\x015\x11\x05\x0b\x05\x05\x17!#%\x0f\x03\x05\x07\t\x0b\r\x01\x03\x03\x01G\x03\x03\x05\x07\x01\x07\x03\x17\x03\x1f\r\x07\x01I\x031\x05\x17!\x05\x07\x01\x0b\x033\x03#\x03\x03\x01O\x03\x19\x05\x07\x01\x07\x03\x0b\x03\'\x05\x07\x01Q\x037\x03%\x07\x06\x01\x03\x0b\x07+\x11)\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x031\x05\x07\x01\x17\x03\x1d\x03/\x07\x06\x01\x03\x05\x075\x133\x05\x07\x01\x0b\x03\x1b\x03#\x03\x03\x01\x15\x03\x15\x05\x07\x01\x07\x03\x05\x03;\x05\x07\x01\x17\x03\x1d\x039\x07\x06\x01\x03\x05\x07?\x15=\x0f\x04\x05\x077-A\x06\x03\x01\x05\x01\x00\xbe\nY\x1d\x03\x0f\x0b\t\t\t\x1b\x1d\r\x1b!+\x1b\x1f/!!)#\x1f\x19\x97y\x1f\x15\x1d\x15\x13%)\x13+\r\x15\x17\x1f\x11\x15)\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00broadcast_in_dim_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00jit(func)/jit(ma
in)/svd[full_matrices=True compute_uv=True]\x00/Users/necula/Source/jax/jax/experimental/jax2tf/tests/back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00jax.result_info\x00jax.arg_info\x00input\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd\x00', xla_call_module_version=6, ) # End paste + +data_2024_08_13 = {} + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["c128"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_zgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[-0.9247611722912019-1.3615157109291343j , + -1.0663457975211892+4.73170030936092j , + -1.4918732811689488-2.880861991859318j , + -1.111356346434667 -2.869701609083459j ], + [-4.71291623424314 -1.5444012898828912j , + -5.232967549101415 -0.41287816948482003j, + 0.8905737109262459+9.50245186328329j , + 4.397722119094926 -6.842005210371916j ], + [ 1.9369405063276903+2.3496014107398917j , + -1.5609345742256133+4.2102103739897805j , + 0.6596030248996742+5.195353435247212j , + 0.6315014498240328-1.2778849649354402j ], + [ 5.115159214503849 -0.8856276268773485j , + 1.3719934567460779-2.236070491368575j , + 0.4974504006612811-3.0462081956756637j , + -0.2620346712025989+4.424682727912594j ]], + + [[-1.8242711798401063-0.8543252170262536j , + -2.724527211360488 +2.256038331706666j , + -1.2777487543905157+0.976556823566376j , + 3.7438974536713223-0.4994301527847589j ], + [-0.6359051102028691+2.730662301129662j , + -1.2877728943263032+3.9124921723649053j , + -3.4618573226579894+1.7835551986994034j , + -1.4710491660152465+2.144967500163963j ], + [-3.6013691182532828+2.8182351980619034j , + 2.0045935428878803+1.1146211993017152j , + -2.332213857689336 -0.874915651404938j , + -1.5393862406530452+0.6852883119580928j ], + [-2.674897392856801 +2.0724239502976984j , + -3.349108041292141 -1.0215359152295307j , + 0.2603515088197114-1.9093411474619364j , + 5.41252457188561 +8.634368042893094j ]]]),), + expected_outputs=(array([[[-0.0417367825863334 +0.10796693731538422j , + 0.6813428383170979 +0.3432797958929331j , + -0.4177022900286576 +0.20028957850808846j , + -0.4344351366508529 +0.034743251442636236j], + [-0.8408468609573512 -0.13260646044648036j , + -0.21674151028481226 +0.015170556885426567j, + 0.17147327711152344 +0.15310416152982537j , + -0.3568765623609291 +0.2190438430670875j ], + [-0.26736181440441353 +0.1379833616281102j , + -0.1753427835255798 -0.3789926157696272j , + -0.8179957069096053 -0.037506032257391686j, + 0.25392637883428515 -0.009771014463849592j], + [ 0.4056923996806594 -0.08297706578106906j , + -0.4321527034953763 +0.097915456635744j , + -0.23439193826962634 -0.0842713053222817j , + -0.423482961456089 +0.625144811494929j ]], + + [[ 0.027268437398665468+0.3631205555033544j , + 0.2702977135592881 +0.13046165871625626j , + 0.042868670139236786-0.47658594176021335j , + 0.7242702256119966 +0.15420620503522459j ], + [-0.08593436615104452 +0.11899901833255505j , + 0.370502861093553 -0.6240865462984537j , + 0.46902056878805953 -0.3474794992077024j , + -0.31667671459632085 -0.1034006436993295j ], + [-0.07914843440873574 -0.033487314943774216j, + 0.4110353453489126 -0.4550908055665629j , + -0.43113180393027273 +0.40910871949631994j , + 0.137827301024203 +0.49428280062680047j ], + [-0.7478497242333215 
+0.5283836938016965j , + -0.08345894989956637 +0.011807690067190318j, + -0.27178304569905287 +0.05652627940674812j , + -0.0991195491344199 -0.25988596540006825j ]]]), array([[16.80132997488892 , 7.74475561455812 , 5.831221808032042 , + 1.1195288361137763], + [12.395375946948931 , 8.218551160453815 , 4.68363485027408 , + 1.882091536383919 ]]), array([[[ 0.3579625104055671 +0.j , + 0.40179383774178024 -0.12693597167020743j , + -0.0751486661300563 -0.6109813931761134j , + -0.23049271148274275 +0.51209309438597j ], + [-0.46828614153085474 +0.j , + -0.013958972669495653+0.4210606476774212j , + -0.6006888466394118 -0.3766516564723723j , + -0.24264518623236989 -0.20408557153193463j ], + [-0.6392945524816099 +0.j , + 0.24323886076029005 -0.6679928485374246j , + 0.18168178910997027 -0.08126854868489738j , + -0.2030612067046727 -0.07124733621915219j ], + [-0.49383540371426055 +0.j , + -0.010402968929686451+0.37346249914107377j , + 0.2799428270410499 +0.019494062167627474j, + 0.32588905219319264 +0.6569569657140542j ]], + + [[ 0.26669203705168437 +0.j , + 0.24929033811571388 +0.27271089049933883j , + -0.012922512768026959+0.16383354123801502j , + 0.07388201893235019 -0.8717175469187742j ], + [-0.6156140469162427 +0.j , + -0.33787077397020177 +0.3779715465092333j , + -0.39160430587261197 -0.2839601305776179j , + -0.27148886041576736 -0.23729034093304668j ], + [ 0.5618758038857614 +0.j , + -0.5788776267734558 -0.13833058883452376j , + -0.48995086206819655 +0.19259594116096806j , + -0.22967101640965004 -0.012926826751577636j], + [-0.48393210641613604 +0.j , + -0.10492296054284367 -0.4911419972025976j , + -0.07782239226461207 +0.6751317817750168j , + 0.11941657609231512 -0.19354808489959857j ]]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_zgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 
= stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xcomplex>, tensor<2x4xf64>, tensor<2x4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xab;\x01Y\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03S\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b//OOo\x01\x05\x0b\x0f\x037\x1b\x0f\x07\x07\x17\x07\x07\x0f\x0b\x13\x07\x0f\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x1a\x08\x1d35\x05\x15\x1f\x03\x03\t\x9b\x05\x17\x03\x03\t\xa1\x11\x03\x05\x05\x19\x03\x03\x03{\x03\x03\x03\xa7\x03\x03\t\xa9\x03\t\x19\x1b\x1d\r\x1f\r\x0f!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%a'e)g\x0fu+w\x05#\x05%\x05'\x05)\x1d/\x05\x05+\x03\x03\x03y\x05-\x177\xfa\x08\x1b\x05/\x03\x13;}=\x7f?\x81A\x83C\x85E\x87G\x8dI\x8fK\x93\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x03\x99\x03\x05Q\x9dS\x9f\x05C\x05E\x03\x03\x03\xa3\x03\x03\t\xa5\x1dG\x1dI\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03c\r\x03Y[##\x03\x07imq\r\x05_kY[\x1dM\r\x05_oY[\x1dO\r\x05_sY[\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x89\x8b\x1d[\x13%A\x03\x03]\x03\x03\x91\x15\x03\x01\x01\x01\x03\x0b]\x95]]\x97\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x13!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f91\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x15)\x01\t\x1d\x01)\x05\t\x11\x0f\x0b\x13)\x01\x15\x03\x0f)\x03\t\x19\x1b)\x01\x19)\x01\x0f)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x11)\x03\t\x11)\x03\
x05\x11)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x05\x17\x07\x03\x01\x05\t\x11\x05#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x11\x03\x07\x05\x03\x01\x11\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x17\x03\x01\x05\x03\x01M\x03\x1b\x03\x07\x01\x07\x03\x17\x03\x13\r\x07\x01O\x03/\x05\x11\x15\x03\x07\x01\x0b\x031\x03\x17\x05\x03\x01U\x03\x1d\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01W\x035\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03!\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03!\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x05\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_zgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["c64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_cgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[ 1.6052934 +0.45878917j, 4.587192 -4.5177283j , + 0.4177733 -1.9419309j , -2.2248359 -4.5042715j ], + [-7.083374 -8.127356j , 2.7596245 -4.991001j , + -0.52622825+5.033981j , -0.35441273-1.8215327j ], + [-0.7996552 -2.4052901j , -0.8506142 -3.164714j , + -0.3090829 +2.2020447j , 1.2367196 +2.8830793j ], + [ 1.4633094 -0.5451007j , -3.7833478 +6.6770763j , + -3.1279542 -2.2322626j , -2.1099617 -2.9661314j ]], + + [[ 1.2560439 -5.4743752j , -2.0085676 +2.0063214j , + -0.8132642 -3.4407883j , -0.17360081+0.6419895j ], + [ 2.3756726 +6.3315964j , -0.31447247-1.9387872j , + 4.6732006 -4.286903j , 1.7702469 -1.4957623j ], + [ 1.6918924 -0.52161306j, 0.49963537+4.7751374j , + -1.9243752 -4.5870543j , 2.8829405 +1.7382988j ], + [ 1.4884951 -0.44194785j, -1.3645276 -2.8733373j , + -0.39430943+2.4366508j , -0.76268387+5.2014065j ]]], + dtype=complex64),), + expected_outputs=(array([[[ 0.016725361+0.19210356j , 0.545269 +0.5572638j , + 0.41363978 +0.18964852j , -0.26152337 -0.28195122j ], + [ 0.53678614 +0.6405725j , -0.21783227 -0.21288806j , + 0.28426635 +0.30535886j , 0.15201291 +0.1076857j ], + [ 0.21286921 +0.15473497j , 0.06647172 -0.25652882j , + -0.4074609 -0.10356678j , -0.11794218 -0.8184482j ], + [-0.39079374 -0.20583557j , -0.18335938 -0.44217706j , + 0.63489586 +0.19758745j , 0.038679928-0.363512j ]], + + [[-0.31785947 +0.39032045j , -0.12733367 -0.30841753j , + 0.2639419 +0.26815215j , -0.21332225 -0.6694792j ], + 
[-0.39241248 -0.60790956j , -0.14006217 +0.4104069j , + -0.08306134 -0.101844534j, -0.45091915 -0.26039878j ], + [-0.36103737 +0.28761536j , -0.49654633 +0.100843735j, + -0.13752809 -0.6203827j , 0.35439843 -0.028546259j], + [ 0.062335134-0.07821423j , 0.35014486 -0.5668197j , + -0.42214072 -0.5090834j , -0.2889286 -0.15894136j ]]], + dtype=complex64), array([[15.135656 , 9.3730345, 7.44493 , 0.4152342], + [12.316968 , 8.661011 , 5.005059 , 2.1159043]], dtype=float32), array([[[-0.65378654 +0.j , -0.20306695 -0.6166746j , + 0.29948464 +0.24257994j , -0.00760437 +0.049453575j], + [ 0.5271269 +0.j , -0.112915546-0.7116953j , + -0.08921899 -0.36348897j , -0.23654734 -0.08269382j ], + [-0.31538552 +0.j , -0.014410704+0.15958196j , + -0.17958632 -0.136909j , -0.6930434 -0.58613425j ], + [-0.44185144 +0.j , 0.17604697 -0.05049205j , + -0.42138547 -0.6948516j , 0.22373372 +0.24654455j ]], + + [[-0.64551586 +0.j , 0.3293224 -0.1167212j , + -0.09352748 +0.6710144j , -0.038554132+0.02716675j ], + [ 0.4241116 +0.j , 0.031135 -0.539813j , + -0.26271757 +0.22760022j , -0.6360964 -0.04817466j ], + [-0.45774835 +0.j , -0.15202752 +0.2734652j , + 0.18930997 -0.32975054j , -0.73310995 -0.10269694j ], + [ 0.4403465 +0.j , 0.29474002 +0.6330784j , + 0.31271845 +0.42166728j , -0.20595443 -0.02053237j ]]], + dtype=complex64)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xcomplex> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xcomplex> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_cgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xcomplex>) -> (tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex>, tensor<2x4x4xcomplex>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %9 = stablehlo.broadcast_in_dim 
%7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor>) -> tensor<2x4x4xcomplex> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xcomplex> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xcomplex>, tensor<2x4xf32>, tensor<2x4x4xcomplex> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf9\xab;\x01Y\x0f\x0b\x07\x13\x0b\x13\x0f\x0b\x13\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x13\x03S\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b/\x1fO/o\x01\x05\x0b\x0f\x037\x1b\x0f\x07\x07\x17\x07\x07\x0f\x0b\x13\x07\x0f\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xea\x07\x1d35\x05\x15\x1f\x03\x03\t\x9b\x05\x17\x03\x03\t\xa1\x11\x03\x05\x05\x19\x03\x03\x03{\x03\x03\x03\xa7\x03\x03\t\xa9\x03\t\x19\x1b\x1d\r\x1f\r\x0f!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%a'e)g\x0fu+w\x05#\x05%\x05'\x05)\x1d/\x05\x05+\x03\x03\x03y\x05-\x177\xfa\x08\x1b\x05/\x03\x13;}=\x7f?\x81A\x83C\x85E\x87G\x8dI\x8fK\x93\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x03\x99\x03\x05Q\x9dS\x9f\x05C\x05E\x03\x03\x03\xa3\x03\x03\t\xa5\x1dG\x1dI\x1f'1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03c\r\x03Y[##\x03\x07imq\r\x05_kY[\x1dM\r\x05_oY[\x1dO\r\x05_sY[\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x89\x8b\x1d[\x13%A\x03\x03]\x03\x03\x91\x15\x03\x01\x01\x01\x03\x0b]\x95]]\x97\x1f)!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f+\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1b\t\x00\x00\x00\x00\x1f-\x01\t\x07\x07\x01\x1f3\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\t\x00\x00\xc0\x7f\x1f7!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x13\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f91\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x15)\x01\t\x1d\x01)\x05\t\x11\x0f\t\x13)\x01\x15\x03\x0f)\x03\t\x19\x1b)\x01\x19)\x01\x0f)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x11)\x03\t\x11)\x03\x05\x11)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x05\x17\x07\x03\x01\x05\t\x11\x05#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x11\x03\x07\x05\x03\x01\x11\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x17\x03\x01\x05\x03\x01M\x03\x1b\x03\x07\x01\x07\x03\x17\x03\x13\r\x07\x01O\x03/\x05\x11\x15\x03\x07\x01\x0b\x031\x03\x17\x05\x03\x01U\x03\x1d\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01W\x035\x03\x19\x0
7\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03!\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1f\x03\x17\x05\x03\x01\x13\x03\x13\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03!\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x05\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_cgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["f32"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_sgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[ 1.5410905 , -2.775912 , -2.374003 , 4.028736 ], + [-0.56933475, 1.6115232 , 0.9041465 , -0.8321383 ], + [-5.382895 , 4.734856 , 2.1972926 , 1.5553856 ], + [ 0.5109847 , -1.1969309 , 3.3766198 , -1.3678027 ]], + + [[ 2.2637439 , 3.406768 , 4.809871 , 2.8010902 ], + [-1.9981416 , -0.6599986 , 0.5138156 , 4.5982494 ], + [-2.335944 , -9.151717 , -1.0481138 , 2.272443 ], + [-8.257684 , 1.8223318 , 0.38403794, 5.0769973 ]]], + dtype=float32),), + expected_outputs=(array([[[-0.48540133 , 0.6682398 , -0.48819908 , -0.28196266 ], + [ 0.21800542 , -0.13631387 , 0.14819776 , -0.9549501 ], + [ 0.84570533 , 0.44643924 , -0.27943408 , 0.08597416 ], + [ 0.04052323 , -0.57928103 , -0.8133976 , -0.034290295]], + + [[-0.21146727 , 0.46376404 , 0.7863092 , 0.34917426 ], + [ 0.3461469 , 0.21883708 , 0.3399651 , -0.846591 ], + [ 0.6526193 , -0.58340365 , 0.39724028 , 0.27555162 ], + [ 0.6399629 , 0.6298205 , -0.32915345 , 0.29228795 ]]], + dtype=float32), array([[ 8.551605 , 5.3574076 , 2.8073733 , 0.52260846], + [11.457574 , 10.041604 , 5.671653 , 1.4754113 ]], + dtype=float32), array([[[-0.6319044 , 0.66122514, 0.39110142, -0.10255312], + [-0.29710513, 0.13673344, -0.50112027, 0.8011937 ], + [ 0.08969161, 0.4433049 , -0.736473 , -0.5030347 ], + [-0.7101976 , -0.5895469 , -0.23135659, -0.30745378]], + + [[-0.69643414, -0.50230867, -0.11150038, 0.50023323], + [-0.32121184, 0.7889567 , 0.31831914, 0.4159848 ], + [ 0.5096959 , -0.31399366, 0.60193473, 0.5284817 ], + [-0.3898877 , -0.16322279, 0.72382 , -0.5453722 ]]], + dtype=float32)), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xf32> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<2x4xf32> 
{jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_sgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf32>) -> (tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32>, tensor<2x4x4xf32>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf32> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf32> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor) -> tensor<2x4x4xf32> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf32> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xf32>, tensor<2x4xf32>, tensor<2x4x4xf32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf1\xa77\x01W\x0f\x07\x0b\x13\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03Q\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b/\x1fOo\x01\x05\x0b\x0f\x033\x1b\x0f\x07\x07\x17\x0f\x07\x07\x13\x07\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\x9a\x07\x1d35\x1f\x05\x15\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x03\x03\x05\xa1\x11\x03\x05\x05\x19\x03\x03\x05y\x03\x03\t\xa5\x03\t\x19\x1b\x1d\x0f\x1f\x0f\x11!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%_'c)e\x11s+u\x05#\x05%\x05'\x05)\x1d/\x03\x05+\x03\x03\x05w\x05-\x177\xfa\x08\x1b\x05/\x03\x13;{=}?\x7fA\x81C\x83E\x85G\x8bI\x8dK\x91\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x05\x97\x03\x05Q\x9bS\x9d\x05C\x05E\x03\x03\t\xa3\x1dG\x1dI\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03a\r\x03WY#\x1f\x03\x07gko\r\x05]iWY\x1dM\r\x05]mWY\x1dO\r\x05]qWY\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x87\x89\x1d[\x13!A\x03\x03[\x03\x03\x8f\x15\x03\x01\x01\x01\x03\x0b[\x93[[\x95\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f\t\x00\x00\xc0\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11)\x01\t\x1d\x01)\x05\t\x11\x11)\x01\x11\t\x13)\x03\t\x17\x1b)\x01\x17)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x03\x17\x07\x03\x01\x05\t\x11\x03#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x13\x03\x07\x05\x03\x01\x13\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x15\x03\x01\x05\x03\x01M\x03\x19\x03\x07\x01\x07\x03\x15\x03\x13\r\x07\x01O\x03+\x05\x11\x15\x03\x07\x01\x0b\x03-\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01U\x031\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03\x1d\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03\x1d\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x03\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True 
subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_sgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_13["f64"] = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['lapack_dgesdd_ffi'], + serialized_date=datetime.date(2024, 8, 13), + inputs=(array([[[ 0.3445689867809981 , 3.5114993759427104 , + 4.702602090972179 , -0.2702264758497052 ], + [ 2.209901632583705 , -2.6286702510632773 , + 4.591276599385847 , 3.4465035398844828 ], + [-1.5083742421154478 , 3.3225165204269635 , + 1.2596205557926703 , 3.524804355848018 ], + [ 1.5118969169108838 , 1.838885943509677 , + 2.818520751293422 , 3.06002540493494 ]], + + [[-2.4045510943950843 , -1.5657555633438576 , + -0.6061472334580296 , -0.23926156407779164], + [ 4.087879920053448 , -3.2507640936811715 , + -2.2556577657517476 , 6.090369998330348 ], + [ 1.1165401344486945 , 2.2134726894037247 , + 5.225178515435584 , 1.9794693474107725 ], + [-4.127878192684534 , -0.37313660200336163, + 0.7893465897510026 , -2.0315217791342848 ]]]),), + expected_outputs=(array([[[-0.5109626909166218 , -0.41744996156105796 , + -0.731253241567692 , 0.17297790257908272 ], + [-0.5623501368035175 , 0.7608931604238581 , + 0.03470920608540995 , 0.32186828528169453 ], + [-0.39585755254587396 , -0.49547702914054115 , + 0.6561880513437817 , 0.4089212062978682 ], + [-0.5157288533916832 , -0.035772078593888285, + 0.18297871183094855 , -0.8362194085221047 ]], + + [[-0.12124821978030864 , -0.30260506534356224 , + -0.5817463045715605 , -0.7451847292758066 ], + [ 0.8877417367326683 , -0.1579400123987918 , + -0.37611807392676866 , 0.21331843758089156 ], + [ 0.030552216758649886, 0.9244545314395404 , + -0.36861075330670934 , -0.09260936183071362 ], + [-0.443035032603635 , -0.1699086407831784 , + -0.6198649402326368 , 0.624994775612963 ]]]), array([[8.951386926411187 , 5.762891699811625 , 3.8391040088894437, + 1.269646897103325 ], + [9.215006888576916 , 6.4772976708832255, 3.246269458558178 , + 0.0511210199435459]]), array([[[-0.1789027692424481 , -0.28818125207050604, + -0.7749616998111009 , -0.5332726590950896 ], + [ 0.3871215938703837 , -0.8985113987184387 , + 0.13976186700464233, 0.1525803344591491 ], + [-0.2314069792404015 , -0.03708202130554682, + -0.5045854966104311 , 0.8309447696839618 ], + [-0.8744034999217863 , -0.32901938548360005, + 0.35396957633060844, -0.04324699218274111]], + + [[ 0.6276106632546885 , -0.267287353478729 , + -0.2299525871877408 , 0.69410671635204 ], + [ 0.28029316975925644, 0.47811378046591546, + 0.8083625695047307 , 0.1984764674680803 ], + [ 0.6187014005224261 , 0.4771409534394446 , + -0.37406866975606345, -0.4996175715979325 ], + [-0.38045915857935025, 0.6872417290515942 , + -0.3921025301835002 , 0.4787538410571401 ]]])), + mlir_module_text=r""" +#loc1 = loc("input") +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<2x4x4xf64> {mhlo.layout_mode = "default"} loc("input")) -> (tensor<2x4x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = 
"default"}, tensor<2x4xf64> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<2x4x4xf64> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<2> : tensor loc(#loc3) + %c_0 = stablehlo.constant dense<4> : tensor loc(#loc3) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc3) + %0:5 = stablehlo.custom_call @lapack_dgesdd_ffi(%arg0) {mhlo.backend_config = {mode = 65 : ui8}, operand_layouts = [dense<[1, 2, 0]> : tensor<3xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 0]> : tensor<2xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<[1, 2, 0]> : tensor<3xindex>, dense<0> : tensor<1xindex>]} : (tensor<2x4x4xf64>) -> (tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64>, tensor<2x4x4xf64>, tensor<2xi32>) loc(#loc3) + %c_2 = stablehlo.constant dense<0> : tensor loc(#loc3) + %1 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<2xi32> loc(#loc3) + %2 = stablehlo.compare EQ, %0#4, %1, SIGNED : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> loc(#loc3) + %3 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1xi1> loc(#loc3) + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x4xf64> loc(#loc3) + %5 = stablehlo.broadcast_in_dim %3, dims = [0, 1] : (tensor<2x1xi1>) -> tensor<2x4xi1> loc(#loc3) + %6 = stablehlo.select %5, %0#1, %4 : tensor<2x4xi1>, tensor<2x4xf64> loc(#loc3) + %7 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_3 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %8 = stablehlo.broadcast_in_dim %cst_3, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc3) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %10 = stablehlo.select %9, %0#2, %8 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3) + %11 = stablehlo.broadcast_in_dim %2, dims = [0] : (tensor<2xi1>) -> tensor<2x1x1xi1> loc(#loc3) + %cst_4 = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc3) + %12 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor) -> tensor<2x4x4xf64> loc(#loc3) + %13 = stablehlo.broadcast_in_dim %11, dims = [0, 1, 2] : (tensor<2x1x1xi1>) -> tensor<2x4x4xi1> loc(#loc3) + %14 = stablehlo.select %13, %0#3, %12 : tensor<2x4x4xi1>, tensor<2x4x4xf64> loc(#loc3) + return %10, %6, %14 : tensor<2x4x4xf64>, tensor<2x4xf64>, tensor<2x4x4xf64> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":574:13) +#loc3 = loc("jit(func)/jit(main)/svd[full_matrices=True compute_uv=True subset_by_index=None]"(#loc2)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1f\x05\x01\x03\x01\x03\x05\x03\x0f\x07\t\x0b\r\x0f\x11\x13\x03\xf1\xa77\x01W\x0f\x07\x0b\x13\x0b\x13\x13\x0f\x0b\x13\x13+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x0f\x0b\x13\x0b\x17\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x1b\x0b\x0b\x13\x03Q\x0b\x0bo\x0b\x0f\x13\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b//\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0f\x0f\x17\x1fO/\x1f\x0f\x0b\x0b//Oo\x01\x05\x0b\x0f\x033\x1b\x0f\x07\x07\x17\x0f\x07\x07\x13\x07\x0f\x1b\x1b\x1f\x07\x13\x13\x13\x13\x13\x17\x13\x17\x13\x13\x02\xaa\x07\x1d35\x1f\x05\x15\x03\x03\t\x99\x05\x17\x03\x03\t\x9f\x03\x03\x05\xa1\x11\x03\x05\x05\x19\x03\x03\x05y\x03\x03\t\xa5\x03\t\x19\x1b\x1d\x0f\x1f\x0f\x11!\x05\x1b\x11\x01\x00\x05\x1d\x05\x1f\x05!\x03\x0b%_'c)e\x11s+u\x05#\x05%\x05'\x05)\x1d/\x03\x05+\x03\x03\x05w\x05-\x177\xfa\x08\x1b\x05/\x03\x13;{=}?\x7fA\x81C\x83E\x85G\x8bI\x8dK\x91\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x03\x05\x97\x03\x05Q\x9bS\x9d\x05C\x05E\x03\x03\t\xa3\x1dG\x1dI\x1f#1\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1dK\x03\x03a\r\x03WY#\x1f\x03\x07gko\r\x05]iWY\x1dM\r\x05]mWY\x1dO\r\x05]qWY\x1dQ\x1dS\x1dU\x1f\x07\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x07\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dW\x1dY\x03\x01\x05\x01\r\x03\x87\x89\x1d[\x13!A\x03\x03[\x03\x03\x8f\x15\x03\x01\x01\x01\x03\x0b[\x93[[\x95\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x1f'\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x19\t\x00\x00\x00\x00\x1f)\x01\t\x07\x07\x01\x1f/\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x0f\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f3!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f51\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x07\t\x11\x11\x11)\x01\t\x1d\x01)\x05\t\x11\x11)\x01\x11\x0b\x13)\x03\t\x17\x1b)\x01\x17)\x07\t\x05\x05\x0b)\x07\t\x11\x11\x0b\x11\x03\x05\x07\x05\r\x05!)\x03\r\x13)\x03\t\x13)\x03\x05\x13)\x03\x01\t)\x03\t\x0b)\x05\t\x05\x0b)\x03\x05\t)\x05\t\x11\x0b)\x03\t\t)\x03\r\t\x04\x16\x03\x05\x01\x11\x03\x17\x07\x03\x01\x05\t\x11\x03#\x07\x037_\x03\x05-\x05\x03\x011\x03\x07\x05\x03\x01\x13\x03\x07\x05\x03\x01\x13\x03\x07\x0b\x07\x019\x0b\x05\r\x05\x05\x15\x03\x01\x05\x03\x01M\x03\x19\x03\x07\x01\x07\x03\x15\x03\x13\r\x07\x01O\x03+\x05\x11\x15\x03\x07\x01\x0b\x03-\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\r\x03\x1b\x03\x07\x01U\x031\x03\x19\x07\x06\x01\x03\r\x07\x1f\x0b\x1d\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03%\x03\x07\x01\x15\x03\x1d\x03#\x07\x06\x01\x03\x05\x07)\r'\x03\x07\x01\x0b\x03\x1b\x03\x17\x05\x03\x01\r\x03\x0f\x03\x07\x01\x07\x03\x05\x03/\x03\x07\x01\x15\x03\x1d\x03-\x07\x06\x01\x03\x05\x073\x0f1\x0f\x04\x03\x07+!5\x06\x03\x01\x05\x01\x00f\x0b]\x0b%\x03\x0f\x0b\t\t\t!\x11#+\x1b\x1f/!)!)#\x1f\x19i\xa3\r\x1f\x15\x1d\x15\x13%)9\x13+\r\x15\x17\x1f\x11\x15\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00select_v1\x00func_v1\x00custom_call_v1\x00compare_v1\x00return_v1\x00value\x00broadcast_dimensions\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00input\x00jit(func)/jit(main)/svd[full_matrices=True compute_uv=True 
subset_by_index=None]\x00third_party/py/jax/tests/export_back_compat_test.py\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00compare_type\x00comparison_direction\x00mhlo.layout_mode\x00default\x00jax.result_info\x00[0]\x00[1]\x00[2]\x00main\x00public\x00\x00lapack_dgesdd_ffi\x00mode\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py new file mode 100644 index 000000000000..47da841aec0a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_cusolver_getrf.py @@ -0,0 +1,196 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ruff: noqa + +import datetime +from numpy import array, int32, float32, complex64 + +data_2024_08_19 = {} + +data_2024_08_19["f32"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. , 9. , 10. , 11. ], + [ 0. , 1. , 2. , 3. ], + [ 0.5, 0.5, 0. , 0. 
]], dtype=float32), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xf32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xf32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xf32>) -> tensor<3x4xf32> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xf32>) -> (tensor<3x4xf32>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<0x7FC00000> : tensor loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x4xf32> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xf32> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xf32>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=float32 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe9\xab+\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0b\x1fO\x0b\x01\x05\x0b\x0f\x03'\x13\x0f\x17\x07\x07\x07\x07\x07\x0f\x1b\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02^\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f\x1f\x01\x1f\x07\t\x01\x00\x00\x00\x1f!\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15\t\x00\x00\xc0\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\t\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x19\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03#\x05\t\x13\x03\x07\x01\x03\x03%\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03'\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xe2\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x89\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float32 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + 
xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_19["f64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. , 9. , 10. , 11. ], + [ 0. , 1. , 2. , 3. ], + [ 0.5, 0.5, 0. , 0. ]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xf64> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xf64> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xf64>) -> tensor<3x4xf64> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xf64>) -> (tensor<3x4xf64>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<0x7FF8000000000000> : tensor loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<3x4xf64> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xf64> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xf64>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=float64 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xe9\xab+\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0b/O\x0b\x01\x05\x0b\x0f\x03'\x13\x0f\x17\x07\x07\x07\x07\x07\x0f\x1b\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02n\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1b!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1d\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f\x1f\x01\x1f\x07\t\x01\x00\x00\x00\x1f!\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15\x11\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f)!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\x0b\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x19\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03#\x05\t\x13\x03\x07\x01\x03\x03%\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03'\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xe2\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x89\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=float64 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + 
xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_19["c64"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. +0.j, 9. +0.j, 10. +0.j, 11. +0.j], + [ 0. +0.j, 1. +0.j, 2. +0.j, 3. +0.j], + [ 0.5+0.j, 0.5+0.j, 0. +0.j, 0. +0.j]], dtype=complex64), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xcomplex>) -> tensor<3x4xcomplex> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xcomplex>) -> (tensor<3x4xcomplex>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<(0x7FC00000,0x7FC00000)> : tensor> loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x4xcomplex> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xcomplex> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex64 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xeb\xab-\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0b/O\x0b\x01\x05\x0b\x0f\x03)\x13\x0f\x17\x0b\x07\x07\x07\x07\x0f\x1b\x07\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02v\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f!\x01\x1f\x07\t\x01\x00\x00\x00\x1f#\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15\x11\x00\x00\xc0\x7f\x00\x00\xc0\x7f\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\x03\x19\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05\t)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x1b\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03%\x05\t\x13\x03\x07\x01\x03\x03'\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03)\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xea\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x8d\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex64 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", 
+ xla_call_module_version=9, + nr_devices=1, +) # End paste + +data_2024_08_19["c128"] = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation', 'cusolver_getrf_ffi'], + serialized_date=datetime.date(2024, 8, 19), + inputs=(), + expected_outputs=(array([[ 8. +0.j, 9. +0.j, 10. +0.j, 11. +0.j], + [ 0. +0.j, 1. +0.j, 2. +0.j, 3. +0.j], + [ 0.5+0.j, 0.5+0.j, 0. +0.j, 0. +0.j]]), array([2, 2, 2], dtype=int32), array([2, 0, 1], dtype=int32)), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<3x4xcomplex> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<12xcomplex> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<12xcomplex>) -> tensor<3x4xcomplex> loc(#loc5) + %2:3 = stablehlo.custom_call @cusolver_getrf_ffi(%1) {mhlo.backend_config = {}, operand_layouts = [dense<[0, 1]> : tensor<2xindex>], output_operand_aliases = [#stablehlo.output_operand_alias], result_layouts = [dense<[0, 1]> : tensor<2xindex>, dense<0> : tensor<1xindex>, dense<> : tensor<0xindex>]} : (tensor<3x4xcomplex>) -> (tensor<3x4xcomplex>, tensor<3xi32>, tensor) loc(#loc6) + %c = stablehlo.constant dense<1> : tensor loc(#loc6) + %3 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<3xi32> loc(#loc6) + %4 = stablehlo.subtract %2#1, %3 : tensor<3xi32> loc(#loc6) + %c_0 = stablehlo.constant dense<0> : tensor loc(#loc6) + %5 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor loc(#loc6) + %6 = stablehlo.compare GE, %2#2, %5, SIGNED : (tensor, tensor) -> tensor loc(#loc6) + %7 = stablehlo.broadcast_in_dim %6, dims = [] : (tensor) -> tensor<1x1xi1> loc(#loc6) + %cst = stablehlo.constant dense<(0x7FF8000000000000,0x7FF8000000000000)> : tensor> loc(#loc6) + %8 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor>) -> tensor<3x4xcomplex> loc(#loc6) + %9 = stablehlo.broadcast_in_dim %7, dims = [0, 1] : (tensor<1x1xi1>) -> tensor<3x4xi1> loc(#loc6) + %10 = stablehlo.select %9, %2#0, %8 : tensor<3x4xi1>, tensor<3x4xcomplex> loc(#loc6) + %11 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%4) {mhlo.backend_config = {}, operand_layouts = [dense<0> : tensor<1xindex>], result_layouts = [dense<0> : tensor<1xindex>]} : (tensor<3xi32>) -> tensor<3xi32> loc(#loc7) + return %10, %4, %11 : tensor<3x4xcomplex>, tensor<3xi32>, tensor<3xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":441:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":442:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=complex128 shape=(12,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu"(#loc3)) +#loc7 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]"(#loc3)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01%\x05\x01\x03\x01\x03\x05\x03\x15\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x03\xeb\xab-\x01c\x0f\x13\x07\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x17\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17S\x0b\x13\x13\x1b\x0b\x0b\x13\x13S\x0f\x0b\x03I\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0bO/\x0f\x0b\x17\x1b\x0b\x1b\x0b\x1b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x17\x17\x0f\x1f\x0f\x1f\x0b\x0bOO\x0b\x01\x05\x0b\x0f\x03)\x13\x0f\x17\x0b\x07\x07\x07\x07\x0f\x1b\x07\x13\x13\x13\x13\x13\x0f\x17\x17\x13\x02\x96\x06\x1dM!\x03\x03#\x9d\x1f\x05\x1b\x05\x1d\x11\x03\x05\x05\x1f\x05!\x05#\x05%\x05'\x05)\x05+\x05-\x05/\x051\x17\x07\xea\x06\x17\x053\x03\t')+\x0b-\x0b\r/\x055\x11\x01\x00\x057\x059\x05;\x03\x0b3c5y7{\r\x899\x8b\x05=\x05?\x05A\x05C\x03\x03=\x8d\x05E\x1dAC\x05G\x17\x07\xe6\x065\x1dGI\x05I\x17\x07\xe6\x06\x1d\x03\x13\x0fk\x11m\x13\x8f\x15c\x17o\x19q\x1b\x91\x1d\x93\x1f\x97\x05K\x03\x03\t\x9b\x03\x03\t\x9f\x03\x05U\xa1W\xa3\x05M\x05O\x03\x03\t\xa5\x03\x03#\xa7\x03\x13\x0fk\x11m\x13\xa9\x15c\x17o\x19q\x1bw\x1dc\x1fw\x1da!\x05Q\x03\x01\x1dS\x1dU\x1dW\x0b\x03\x1dY\x05\x01\r\x01\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x03\x03u#\x17\x03\x07}\x81\x85\r\x05e\x7fgi\x1d[\r\x05e\x83gi\x1d]\r\x05e\x87gi\x1d_\x1da\x1dc\x13\r\x01\x1de\x03\x03s\x03\x03\x95\x15\x03\x01\x01\x01\x03\x07su\x99\x1f!\x01\x1f\x07\t\x01\x00\x00\x00\x1f#\x01\x1f\x07\t\x00\x00\x00\x00\t\x07\x07\x05\x1f\x15!\x00\x00\x00\x00\x00\x00\xf8\x7f\x00\x00\x00\x00\x00\x00\xf8\x7f\x1f+!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1dg\x01\t\x01\x02\x02)\x03\r\x13)\x01\x13)\x05\r\x11\x0b\x03\x19\x1d\x13\x01\x1b)\x01\x0b\x11\x01\x07\t\x05\x05\x0b)\x031\x0b)\x03\t\x0f)\x03\x05\x0f)\x03\x01\x0f)\x03\x01\r)\x01\x11)\x05\x05\x05\x11)\x05\r\x11\x11)\x03\t\r\x04.\x02\x05\x01\x11\x05%\x07\x03\x01\x05\t\x11\x051\x07\x03#A\x0b\x03?;\x03\x1b\r\x06E\x03\t\x03\x01\x07\x07\x01K\x07\t\x05\x07\x03\x03\x05\x03\x01O\x03\x07\x03\x07\x01\x03\x03\x05\x03\x0b\x0f\x06\x01\x03\x05\x05\x07\r\x05\x03\x01Q\x03\x07\x03\x07\x01\x03\x03\x07\x03\x11\x11\x07\x01S\x03%\x05\t\x13\x03\x07\x01\x03\x03'\x03\x15\x05\x03\x01Y\x03\x15\x03\x07\x01\x03\x03\t\x03\x19\x03\x07\x01[\x03)\x03\x17\x13\x06\x01\x03\t\x07\x1d\x05\x1b\x07\x07_]\x03\x05\x03\x0f\x15\x04\x05\x07\x1f\x0f!\x06\x03\x01\x05\x01\x00\xee\x0ei9'\x0f\x0b\t\t\t\x03\x11#!\x8b+\x1b7\x85\x8f\x1f\x1f\x15\x1d\x15\x1b%)9+\x1f/!)!)#\x1f\x19\x13\ri\x15\x15\x17\x19\x17\x11\x11\x1f\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00custom_call_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00subtract_v1\x00compare_v1\x00select_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00broadcast_dimensions\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=complex128 shape=(12,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(3, 4) 
dimensions=None]\x00jit()/jit(main)/lu\x00compare_type\x00comparison_direction\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=3]\x00jax.result_info\x00mhlo.layout_mode\x00default\x00\x00[0]\x00[1]\x00[2]\x00main\x00public\x00cusolver_getrf_ffi\x00cu_lu_pivots_to_permutation\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py new file mode 100644 index 000000000000..12285a45b77a --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/cuda_lu_pivots_to_permutation.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from numpy import array, int32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_08_08 = dict( + testdata_version=1, + platform='cuda', + custom_call_targets=['cu_lu_pivots_to_permutation'], + serialized_date=datetime.date(2024, 8, 8), + inputs=(), + expected_outputs=(array([[[0, 1, 2, 3, 4, 5, 6, 7], + [4, 5, 6, 7, 0, 1, 2, 3], + [0, 1, 2, 3, 4, 5, 6, 7]], + + [[0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 1, 2, 3, 4, 5, 6, 7]]], dtype=int32),), + mlir_module_text=r""" +module @jit__lambda_ attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<2x3x8xi32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<24xi32> loc(#loc4) + %1 = stablehlo.reshape %0 : (tensor<24xi32>) -> tensor<2x3x4xi32> loc(#loc5) + %c = stablehlo.constant dense<2> : tensor loc(#loc6) + %c_0 = stablehlo.constant dense<3> : tensor loc(#loc6) + %c_1 = stablehlo.constant dense<4> : tensor loc(#loc6) + %2 = stablehlo.custom_call @cu_lu_pivots_to_permutation(%1) {mhlo.backend_config = {permutation_size = 8 : i32}, operand_layouts = [dense<[2, 1, 0]> : tensor<3xindex>], result_layouts = [dense<[2, 1, 0]> : tensor<3xindex>]} : (tensor<2x3x4xi32>) -> tensor<2x3x8xi32> loc(#loc6) + return %2 : tensor<2x3x8xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:26) +#loc2 = loc("third_party/py/jax/tests/export_back_compat_test.py":347:14) +#loc3 = loc("third_party/py/jax/tests/export_back_compat_test.py":348:11) +#loc4 = loc("jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]"(#loc1)) +#loc5 = loc("jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]"(#loc2)) +#loc6 = loc("jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]"(#loc3)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01\x1d\x05\x01\x03\x01\x03\x05\x03\r\x07\t\x0b\r\x0f\x11\x03\xa7}\x17\x01Q\x0f\x07\x0b\x0b\x0f\x0b+\x0b\x0f\x0b\x0b\x0b3\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x17\x0f\x0b\x17\x13\x0b\x17\x13\x13S\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x03-\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x0b\x0f///\x0b\x0b\x0b\x13\x0b\x0fo\x01\x05\x0b\x0f\x03\x13\x0f\x07\x1b\x07\x13\x13\x1b\x13\x07\x02Z\x04\x1d57\x1f\x05\x13\x05\x15\x11\x03\x05\x05\x17\x03\t\x0f\x11\x13\t\x15\t\x0b\x17\x05\x19\x11\x01\x00\x05\x1b\x05\x1d\x05\x1f\x03\x0b\x1bQ\x1dW\x1fY\x0bc!e\x05!\x05#\x05%\x05'\x03\x03%g\x05)\x1d)+\x05+\x17\x05n\x055\x1d/1\x05-\x17\x05n\x05\x1d\x03\x03\x07i\x05/\x17\x05r\x05\x17\x03\x03\x07k\x03\x03\x07m\x03\x13?oASCqEQGsIuKUMQOU\x051\x053\x055\x057\x059\x05;\x05=\x05?\x05A\x03\x01\x1dC\x03\x03{#\r\x03\x03[\r\x05]S_a\x1dE\x1dG\x1dI\x1dK\x1dM\x13\x0b\x01\x1f\x05\x11\x02\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x03\x00\x00\x00\x00\x00\x00\x00\x1f\x05\x11\x04\x00\x00\x00\x00\x00\x00\x00\x0b\x03\x1dO\x05\x01\r\x03wy\x1dQ\x13\x07!\x1f\x131\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\t\x01\x02\x02)\x01\x0b\x1b)\x07\t\r!\x07\x1d\x11\x01\x03\t)\x03a\x07)\x07\t\r\x11\x07)\x03\r\x15\x13\x04{\x05\x01\x11\x03\r\x07\x03\x01\x05\x05\x11\x03\x19\x07\x03\r\x1d\x07\x03'#\x03\x0f\t\x06-\x03\x11\x03\x01\x03\x03\x013\x03\x05\x03\x03\x019\x03\x05\x03\x03\x01;\x03\x05\x0b\x07\x01=\x03\t\x03\x03\r\x04\x03\x03\x0b\x06\x03\x01\x05\x01\x00f\x0cS#9\x0f\x0b\x11#!\x03\x1f/!)!)#\x1f\x19\x8b\x8b\x85\x1f\x1f\x15\x1d\x15\x1b%)9\x13\ri\x15\x1f\x17\x11\x11\x19\x0f\x0b\x11builtin\x00vhlo\x00module\x00constant_v1\x00func_v1\x00iota_v1\x00reshape_v1\x00custom_call_v1\x00return_v1\x00third_party/py/jax/tests/export_back_compat_test.py\x00value\x00sym_name\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit__lambda_\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00iota_dimension\x00jit()/jit(main)/iota[dtype=int32 shape=(24,) dimension=0]\x00jit()/jit(main)/reshape[new_sizes=(2, 3, 4) dimensions=None]\x00jit()/jit(main)/lu_pivots_to_permutation[permutation_size=8]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00mhlo.backend_config\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00\x00jax.result_info\x00mhlo.layout_mode\x00default\x00main\x00public\x00cu_lu_pivots_to_permutation\x00permutation_size\x00", + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py new file mode 100644 index 000000000000..2c94cb777b46 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_matmul.py @@ -0,0 +1,345 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +from numpy import array, float32 + + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +data_2024_09_24 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['tpu_custom_call'], + serialized_date=datetime.date(2024, 9, 24), + inputs=(), + expected_outputs=(array([[ 90458.2 , 90470.875, 90480.85 , 90491.1 , + 90500.945, 90510.945, 90521.19 , 90530.95 , + 90540.78 , 90551.16 , 90560.67 , 90570.734, + 90580.73 , 90590.586, 90600.66 , 90610.61 ], + [ 643341.75 , 643434.25 , 643509.75 , 643587.1 , + 643660.1 , 643735.94 , 643813.5 , 643886. , + 643960.6 , 644039.5 , 644110.25 , 644186.75 , + 644262.5 , 644336.06 , 644412.9 , 644488.3 ], + [ 1196323.2 , 1196495.6 , 1196636.8 , 1196781. , + 1196917.5 , 1197059. , 1197203.9 , 1197339.2 , + 1197478.5 , 1197625.8 , 1197757.8 , 1197900.5 , + 1198042. , 1198179.4 , 1198323.1 , 1198464. ], + [ 1749075.5 , 1749327.8 , 1749534.4 , 1749746. , + 1749945.5 , 1750152.8 , 1750365.1 , 1750563. , + 1750767.2 , 1750983.1 , 1751176.2 , 1751385.5 , + 1751592.8 , 1751793.8 , 1752004.2 , 1752210.8 ], + [ 2302500.5 , 2302832.5 , 2303104.8 , 2303383.5 , + 2303646.2 , 2303919.5 , 2304199. , 2304459.8 , + 2304728.5 , 2305013. , 2305267.2 , 2305542.8 , + 2305816.2 , 2306081. , 2306358.5 , 2306630.5 ], + [ 2855440.2 , 2855852.2 , 2856190.2 , 2856535.5 , + 2856861.5 , 2857200.5 , 2857547.2 , 2857870.8 , + 2858204.5 , 2858557. , 2858872.5 , 2859214.5 , + 2859553.2 , 2859882. , 2860226. , 2860563.5 ], + [ 3407472. , 3407964.2 , 3408367.5 , 3408780.2 , + 3409169.5 , 3409574.2 , 3409988.5 , 3410374.5 , + 3410772.8 , 3411194. , 3411570.5 , 3411978.8 , + 3412383.5 , 3412776. , 3413186.8 , 3413590. ], + [ 3959847.5 , 3960419. , 3960888. , 3961367.8 , + 3961820.2 , 3962291. , 3962772.5 , 3963221.2 , + 3963684.8 , 3964174.5 , 3964612. , 3965086.8 , + 3965557.2 , 3966013.2 , 3966491. , 3966959.5 ], + [ 4515869.5 , 4516521.5 , 4517056. , 4517602.5 , + 4518118. , 4518654.5 , 4519203. , 4519715. , + 4520243. , 4520801. , 4521300. , 4521841. , + 4522377.5 , 4522897. , 4523441.5 , 4523975.5 ], + [ 5061659. , 5062390. , 5062990. , 5063603.5 , + 5064182. , 5064784.5 , 5065401. , 5065975. , + 5066567.5 , 5067194. , 5067754. , 5068362. , + 5068964. , 5069547. , 5070159. , 5070758.5 ], + [ 5621329. , 5622141. , 5622806.5 , 5623487.5 , + 5624129.5 , 5624797. , 5625481. , 5626118. , + 5626775. , 5627470.5 , 5628092. , 5628765. , + 5629433. , 5630080.5 , 5630758.5 , 5631424. ], + [ 6172820.5 , 6173712. , 6174443. , 6175191. , + 6175896.5 , 6176630. , 6177381. , 6178080.5 , + 6178803. , 6179566. , 6180248.5 , 6180988.5 , + 6181722. , 6182432.5 , 6183177.5 , 6183908. ], + [ 6723343.5 , 6724315. , 6725111.5 , 6725927. , + 6726696. , 6727495.5 , 6728313.5 , 6729076.5 , + 6729864. , 6730696. , 6731440. , 6732246. , + 6733046. , 6733820.5 , 6734632. , 6735428.5 ], + [ 7280537. , 7281587.5 , 7282449.5 , 7283331.5 , + 7284163.5 , 7285029. , 7285914. , 7286739. , + 7287591. , 7288492. , 7289297. , 7290169.5 , + 7291035. , 7291873.5 , 7292752. , 7293614. ], + [ 7828292. , 7829423. , 7830350. , 7831299.5 , + 7832194.5 , 7833125.5 , 7834078.5 , 7834966. , + 7835883. , 7836852. , 7837718. , 7838657. , + 7839588. , 7840490. , 7841436. , 7842363.5 ], + [ 8384808.5 , 8386019.5 , 8387012.5 , 8388029.5 , + 8388988. , 8389985. , 8391005. , 8391956. , + 8392937. , 8393974. , 8394902. , 8395907. , + 8396904. , 8397870. , 8398882. , 8399875. ], + [ 8928697. , 8929987. , 8931044. , 8932126. , + 8933146. , 8934208. , 8935294. , 8936306. 
, + 8937351. , 8938455. , 8939443. , 8940514. , + 8941574. , 8942604. , 8943682. , 8944738. ], + [ 9501496. , 9502866. , 9503990. , 9505141. , + 9506226. , 9507354. , 9508508. , 9509584. , + 9510695. , 9511870. , 9512919. , 9514058. , + 9515186. , 9516279. , 9517425. , 9518549. ], + [10055416. , 10056868. , 10058060. , 10059279. , + 10060428. , 10061624. , 10062848. , 10063988. , + 10065166. , 10066410. , 10067522. , 10068729. , + 10069925. , 10071083. , 10072298. , 10073489. ], + [10595886. , 10597416. , 10598672. , 10599958. , + 10601170. , 10602431. , 10603721. , 10604923. , + 10606164. , 10607477. , 10608650. , 10609922. , + 10611182. , 10612404. , 10613684. , 10614940. ], + [11135804. , 11137412. , 11138732. , 11140083. , + 11141357. , 11142682. , 11144038. , 11145302. , + 11146606. , 11147985. , 11149218. , 11150554. , + 11151880. , 11153164. , 11154509. , 11155829. ], + [11686791. , 11688480. , 11689864. , 11691282. , + 11692618. , 11694007. , 11695430. , 11696756. , + 11698124. , 11699570. , 11700864. , 11702265. , + 11703656. , 11705003. , 11706414. , 11707799. ], + [12263420. , 12265190. , 12266642. , 12268128. , + 12269529. , 12270986. , 12272478. , 12273868. , + 12275303. , 12276820. , 12278176. , 12279646. , + 12281104. , 12282516. , 12283996. , 12285446. ], + [12821178. , 12823029. , 12824548. , 12826102. , + 12827567. , 12829092. , 12830652. , 12832106. , + 12833606. , 12835192. , 12836610. , 12838148. , + 12839673. , 12841150. , 12842699. , 12844217. ], + [13362964. , 13364895. , 13366479. , 13368100. , + 13369628. , 13371218. , 13372846. , 13374362. , + 13375927. , 13377582. , 13379061. , 13380665. , + 13382256. , 13383796. , 13385411. , 13386995. ], + [13902882. , 13904890. , 13906538. , 13908225. , + 13909815. , 13911470. , 13913163. , 13914740. , + 13916368. , 13918090. , 13919629. , 13921298. , + 13922952. , 13924556. , 13926236. , 13927884. ], + [14443848. , 14445934. , 14447646. , 14449398. , + 14451050. , 14452769. , 14454528. , 14456166. , + 14457858. , 14459647. , 14461246. , 14462979. , + 14464698. , 14466363. , 14468108. , 14469820. ], + [15024406. , 15026576. , 15028355. , 15030176. , + 15031893. , 15033679. , 15035507. , 15037210. , + 15038968. , 15040828. , 15042490. , 15044291. , + 15046077. , 15047808. , 15049621. , 15051400. ], + [15586096. , 15588347. , 15590193. , 15592082. , + 15593863. , 15595716. , 15597613. , 15599380. , + 15601204. , 15603133. , 15604856. , 15606726. , + 15608579. , 15610375. , 15612257. , 15614103. ], + [16130043. , 16132373. , 16134285. , 16136242. , + 16138087. , 16140006. , 16141970. , 16143800. , + 16145690. , 16147688. , 16149473. , 16151409. , + 16153328. , 16155188. , 16157138. , 16159050. ], + [16669960. , 16672369. , 16674345. , 16676367. , + 16678274. , 16680258. , 16682287. , 16684178. , + 16686132. , 16688196. , 16690041. , 16692042. , + 16694026. , 16695948. , 16697962. , 16699938. ], + [17209878. , 17212364. , 17214404. , 17216492. , + 17218460. , 17220508. , 17222604. , 17224556. , + 17226572. , 17228704. , 17230608. , 17232676. , + 17234724. , 17236708. , 17238788. , 17240828. ], + [17817286. , 17819860. , 17821972. , 17824132. , + 17826172. , 17828292. , 17830460. , 17832482. , + 17834570. , 17836776. , 17838748. , 17840888. , + 17843008. , 17845062. , 17847216. , 17849328. ], + [18357204. , 18359856. , 18362032. , 18364258. , + 18366358. , 18368542. , 18370778. , 18372860. , + 18375012. , 18377284. , 18379316. , 18381520. , + 18383704. , 18385820. , 18388040. , 18390216. ], + [18897120. , 18899852. , 18902092. 
, 18904384. , + 18906544. , 18908794. , 18911096. , 18913240. , + 18915452. , 18917792. , 18919884. , 18922152. , + 18924402. , 18926580. , 18928864. , 18931104. ], + [19437040. , 19439848. , 19442152. , 19444508. , + 19446732. , 19449044. , 19451412. , 19453616. , + 19455894. , 19458302. , 19460452. , 19462786. , + 19465100. , 19467340. , 19469688. , 19471992. ], + [19976956. , 19979844. , 19982212. , 19984634. , + 19986920. , 19989296. , 19991728. , 19993996. , + 19996336. , 19998810. , 20001020. , 20003420. , + 20005796. , 20008100. , 20010514. , 20012882. ], + [20516874. , 20519838. , 20522270. , 20524760. , + 20527106. , 20529548. , 20532046. , 20534374. , + 20536776. , 20539318. , 20541588. , 20544052. , + 20546492. , 20548858. , 20551338. , 20553770. ], + [21056792. , 21059834. , 21062330. , 21064884. , + 21067292. , 21069800. , 21072364. , 21074752. , + 21077218. , 21079826. , 21082156. , 21084684. , + 21087190. , 21089618. , 21092164. , 21094660. ], + [21596710. , 21599830. , 21602390. , 21605010. , + 21607480. , 21610050. , 21612680. , 21615130. , + 21617660. , 21620336. , 21622724. , 21625318. , + 21627888. , 21630378. , 21632988. , 21635548. ], + [22218698. , 22221906. , 22224536. , 22227228. , + 22229768. , 22232408. , 22235108. , 22237628. , + 22240228. , 22242976. , 22245432. , 22248094. , + 22250736. , 22253292. , 22255972. , 22258602. ], + [22802946. , 22806238. , 22808938. , 22811700. , + 22814306. , 22817016. , 22819790. , 22822374. , + 22825044. , 22827864. , 22830384. , 22833120. , + 22835830. , 22838456. , 22841208. , 22843908. ], + [23351442. , 23354816. , 23357584. , 23360416. , + 23363088. , 23365866. , 23368710. , 23371360. , + 23374096. , 23376988. , 23379572. , 23382374. , + 23385154. , 23387846. , 23390668. , 23393436. ], + [23891360. , 23894812. , 23897644. , 23900542. , + 23903276. , 23906118. , 23909028. , 23911738. , + 23914536. , 23917496. , 23920140. , 23923008. , + 23925850. , 23928604. , 23931492. , 23934324. ], + [24431278. , 24434808. , 24437704. , 24440668. , + 24443462. , 24446368. , 24449344. , 24452116. , + 24454978. , 24458004. , 24460708. , 24463640. , + 24466548. , 24469364. , 24472316. , 24475212. ], + [24971196. , 24974804. , 24977764. , 24980792. , + 24983648. , 24986620. , 24989662. , 24992494. , + 24995420. , 24998512. , 25001276. , 25004274. , + 25007244. , 25010124. , 25013142. , 25016102. ], + [25511114. , 25514800. , 25517824. , 25520918. , + 25523836. , 25526872. , 25529978. , 25532872. , + 25535860. , 25539020. , 25541844. , 25544906. , + 25547942. , 25550884. , 25553966. , 25556990. ], + [26051032. , 26054796. , 26057884. , 26061044. , + 26064024. , 26067124. , 26070296. , 26073250. , + 26076302. , 26079528. , 26082412. , 26085540. , + 26088640. , 26091644. , 26094792. , 26097880. ], + [26590950. , 26594792. , 26597944. , 26601168. , + 26604210. , 26607374. , 26610612. , 26613628. , + 26616744. , 26620038. , 26622980. , 26626172. , + 26629336. , 26632402. , 26635616. , 26638768. ], + [27130868. , 27134786. , 27138002. , 27141294. , + 27144396. , 27147624. , 27150930. , 27154008. , + 27157184. , 27160546. , 27163548. , 27166804. , + 27170034. , 27173162. , 27176440. , 27179656. ], + [27723244. , 27727248. , 27730532. , 27733892. , + 27737062. , 27740358. , 27743732. , 27746876. , + 27750120. , 27753552. , 27756618. , 27759944. , + 27763240. , 27766436. , 27769780. , 27773064. ], + [28323220. , 28327310. , 28330664. , 28334094. , + 28337332. , 28340696. , 28344142. , 28347352. , + 28350664. , 28354168. , 28357300. , 28360696. 
, + 28364062. , 28367324. , 28370744. , 28374096. ], + [28885444. , 28889616. , 28893040. , 28896544. , + 28899848. , 28903284. , 28906802. , 28910078. , + 28913460. , 28917038. , 28920236. , 28923702. , + 28927138. , 28930468. , 28933960. , 28937382. ], + [29425518. , 29429768. , 29433256. , 29436826. , + 29440192. , 29443692. , 29447276. , 29450614. , + 29454062. , 29457706. , 29460964. , 29464496. , + 29467996. , 29471390. , 29474946. , 29478434. ], + [29965436. , 29969764. , 29973316. , 29976952. , + 29980378. , 29983944. , 29987594. , 29990992. , + 29994504. , 29998216. , 30001532. , 30005128. , + 30008694. , 30012148. , 30015770. , 30019322. ], + [30505352. , 30509760. , 30513376. , 30517076. , + 30520566. , 30524196. , 30527910. , 30531372. , + 30534944. , 30538724. , 30542100. , 30545760. , + 30549392. , 30552908. , 30556596. , 30560212. ], + [31045270. , 31049756. , 31053436. , 31057202. , + 31060752. , 31064448. , 31068228. , 31071750. , + 31075386. , 31079232. , 31082668. , 31086394. , + 31090088. , 31093668. , 31097420. , 31101100. ], + [31585188. , 31589752. , 31593496. , 31597328. , + 31600940. , 31604698. , 31608544. , 31612128. , + 31615828. , 31619740. , 31623236. , 31627028. , + 31630786. , 31634428. , 31638244. , 31641988. ], + [32125106. , 32129748. , 32133556. , 32137452. , + 32141126. , 32144950. , 32148862. , 32152506. , + 32156268. , 32160248. , 32163804. , 32167660. , + 32171482. , 32175186. , 32179068. , 32182876. ], + [32665024. , 32669744. , 32673616. , 32677578. , + 32681314. , 32685200. , 32689178. , 32692884. , + 32696712. , 32700756. , 32704372. , 32708292. , + 32712180. , 32715946. , 32719894. , 32723766. ], + [33221238. , 33226038. , 33229974. , 33234004. , + 33237804. , 33241756. , 33245802. , 33249570. , + 33253460. , 33257576. , 33261252. , 33265238. , + 33269192. , 33273022. , 33277034. , 33280972. ], + [33836944. , 33841824. , 33845832. , 33849936. , + 33853804. , 33857824. , 33861940. , 33865776. , + 33869736. , 33873920. , 33877664. , 33881720. , + 33885744. , 33889640. , 33893724. , 33897732. ], + [34414896. , 34419864. , 34423944. , 34428112. , + 34432048. , 34436140. , 34440328. , 34444232. , + 34448260. , 34452520. , 34456324. , 34460456. , + 34464548. , 34468512. , 34472672. , 34476748. ], + [34824696. , 34829728. , 34833856. , 34838080. , + 34842064. , 34846208. , 34850448. , 34854396. , + 34858476. , 34862792. , 34866644. , 34870824. , + 34874968. , 34878984. , 34883192. , 34887320. 
]], + dtype=float32),), + mlir_module_text=r""" +#loc6 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":74:12) +#loc14 = loc("jit(func)/jit(main)/pjit"(#loc6)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor<64x16xf32> {jax.result_info = "", mhlo.layout_mode = "default"}) { + %c = stablehlo.constant dense<0> : tensor loc(#loc) + %c_0 = stablehlo.constant dense<16> : tensor loc(#loc) + %cst = stablehlo.constant dense<1.000000e+00> : tensor loc(#loc) + %cst_1 = stablehlo.constant dense<1.000000e-03> : tensor loc(#loc) + %0 = stablehlo.iota dim = 0 : tensor<524288xf32> loc(#loc9) + %1 = stablehlo.reshape %0 : (tensor<524288xf32>) -> tensor<1024x512xf32> loc(#loc10) + %2 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc11) + %3 = stablehlo.multiply %2, %1 : tensor<1024x512xf32> loc(#loc11) + %4 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<1024x512xf32> loc(#loc12) + %5 = stablehlo.add %4, %3 : tensor<1024x512xf32> loc(#loc12) + %6 = stablehlo.slice %5 [0:512, 0:256] : (tensor<1024x512xf32>) -> tensor<512x256xf32> loc(#loc13) + %7 = call @matmul(%5, %6) : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc14) + %8 = stablehlo.iota dim = 0 : tensor<64xi32> loc(#loc15) + %9 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor<64xi32> loc(#loc16) + %10 = stablehlo.multiply %9, %8 : tensor<64xi32> loc(#loc16) + %11 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<64xi32> loc(#loc17) + %12 = stablehlo.add %11, %10 : tensor<64xi32> loc(#loc17) + %13 = stablehlo.iota dim = 0 : tensor<16xi32> loc(#loc15) + %14 = stablehlo.broadcast_in_dim %c_0, dims = [] : (tensor) -> tensor<16xi32> loc(#loc16) + %15 = stablehlo.multiply %14, %13 : tensor<16xi32> loc(#loc16) + %16 = stablehlo.broadcast_in_dim %c, dims = [] : (tensor) -> tensor<16xi32> loc(#loc17) + %17 = stablehlo.add %16, %15 : tensor<16xi32> loc(#loc17) + %18 = stablehlo.broadcast_in_dim %12, dims = [0] : (tensor<64xi32>) -> tensor<64x16x1xi32> loc(#loc18) + %19 = stablehlo.broadcast_in_dim %17, dims = [1] : (tensor<16xi32>) -> tensor<64x16x1xi32> loc(#loc18) + %20 = stablehlo.concatenate %18, %19, dim = 2 : (tensor<64x16x1xi32>, tensor<64x16x1xi32>) -> tensor<64x16x2xi32> loc(#loc19) + %21 = "stablehlo.gather"(%7, %20) <{dimension_numbers = #stablehlo.gather, indices_are_sorted = true, slice_sizes = array}> : (tensor<1024x256xf32>, tensor<64x16x2xi32>) -> tensor<64x16xf32> loc(#loc20) + return %21 : tensor<64x16xf32> loc(#loc) + } loc(#loc) + func.func private @matmul(%arg0: tensor<1024x512xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc6)), %arg1: tensor<512x256xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit"(#loc6))) -> (tensor<1024x256xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @tpu_custom_call(%arg0, %arg1) {backend_config = "{\22custom_call_config\22: {\22body\22: 
\22TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLQkBAwUHAQMJAxkLDQ8RExUXGRsdHyED58ETAbkHEwsTCwsLDwsPDw8LC1MLDw8PDwsPDw8LCwsLExMPDxMPGwsPC0MLFwuFC3MLCwsLFxsLGwsbCxsbGw8LExMPEw8LCxMPExMTHwsTGwsLEwsPCxMTEwsTDwsTEwUHjZFhBwNZARMPBx8nDwcLKyMCZggfAwMLiwUjAwMLdwUlBScFKR15ewUrHSmnHSmrHSm3BS0FLyMJBSEAAQAAAAAAAAABAAAAAAAADREdhzkdETsdEY0dEY8FMR0RqREJAREJBQUzBTUFNwU5FwU7BxcFQyMdlZcRDQAXrRcLHbO1AwVHSQlLBTsRCQ0FPQMPT1ENU1dZWy1dLwlfYWMFPwEHubm7DQ9hZmZpbmVfbWFwPChkMCwgZDEpIC0+IChkMCwgZDEpPgAFQSMJBzEEAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAAFQwVFBUcFSQEHZWltAwUZZxsdCTEDBRlrGx0JMwMFGW8bHQk1AwUNHwkxAwUNHwkzAwUNHwk1EQEBBUsXBTsXAwMLfxEBBQMDNy0dhTkFTQVPAwM3LxEDARcFRQ0XBUcNAwMLkyUFCQAAAAAFURcFQ0EDBZs/nT8FUwVVAwOhvwVXHaU7BVkXBUMFFwVRKRcFUQUFWwMDC7ETCwEFXRcFPycXBT8JI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPHBhcmFsbGVsPgAjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGguZmFzdG1hdGg8bm9uZT4AAQICAycFAggCCAsXvQUCCAIIC1UBAgQLAQkFDwEBAQcHBwcBBQcBAQEFAQEEIgYFAREBRQcDAREHEQFNBwNHgw8BAQEBAQEHAQcBBwEHAQMDDwcDAQMDD30DAQMDDwcDAQ0HD4EDDQUFDxEGgwMBAxUDAyEHAwENByGJAw0FFxkTFCEDGwkDCx0DA0OvAwsZBkMDBQNHAwMXAwMDAwMXAwMDBQYXAwUHDUtNCwQXCUkNS00PAEEDAQUPAEEDAyMDAwMDAyMDAwMFBiMDBQcNHR8DAyUDAwMDAyUDAwMFBiUDBQcHIyUDAycDAwMDAycDAwMFBicDBQcJKSsDAz2RAwUVBz2ZAwUHJy0vFwejnwMFBSExAwMTAwMDAwMTAwMDBQYTAwUHDTU3CwQTCTMNNTcDAysDAwMDAysDAwMFBisDBQcNOz0DAxUDAwMDAxUDAwMFBhUDBQcLQUMLBBUJPwtBQwkAAQcRAXEHAwkLBwEBAQEBAQMDAQcDAQkEAQUBBQcRAXMHAwkLBwEBAQEBAQMDAQcDAQkEAQUFAwcRAXUHAwkLBwEBAQEBAQMDAQcDAQkEAQUBAwYDAQUBAO4JXyUFCxMdHRsNLQkdCyMhIykdLRUZGRkNHSULHQ0TcyMXFw8ZFRcbGRUZHw8NCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBhcml0aABtb2R1bGUAYXJpdGguY29uc3RhbnQAdmVjdG9yLmxvYWQAZnVuYy5mdW5jAGZ1bmMucmV0dXJuAHZlY3Rvci5zdG9yZQBhcml0aC5jbXBpAHNjZi55aWVsZABhcml0aC5leHR1aQBzY2YuaWYAdHB1Lm1hdG11bABhcml0aC5hZGRmAHZlY3Rvci5icm9hZGNhc3QAdGhpcmRfcGFydHkvcHkvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9tYXRtdWwucHkAc3ltX25hbWUAdmFsdWUAZnVuY3Rpb25fdHlwZQAvZ2V0AHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL3N3YXAAdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAcHJlZGljYXRlAHN0YWJsZV9tb3NhaWMudmVyc2lvbgBtYXRtdWxfa2VybmVsAGRpbWVuc2lvbl9zZW1hbnRpY3MAaXRlcmF0aW9uX2JvdW5kcwBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBtYWluAHdpbmRvd19wYXJhbXMAL2VxAC9jb252ZXJ0X2VsZW1lbnRfdHlwZQAvY29uZAAvZG90X2dlbmVyYWwAdHJhbnNwb3NlX2xocwB0cmFuc3Bvc2VfcmhzAGZhc3RtYXRoAC9hZGQALQAvYnJvYWRjYXN0X2luX2RpbQA=\22, \22serialization_format\22: 1, \22needs_layout_passes\22: true}, \22implicit_sharding\22: {\22type\22: \22MANUAL\22}}", kernel_name = "matmul_kernel", operand_layouts = [dense<[1, 0]> : tensor<2xindex>, dense<[1, 0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<1024x512xf32>, tensor<512x256xf32>) -> tensor<1024x256xf32> loc(#loc21) + return %0 : tensor<1024x256xf32> loc(#loc14) + } loc(#loc14) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":71:25) +#loc2 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":72:43) +#loc3 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":71:17) +#loc4 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":71:10) +#loc5 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":73:10) +#loc7 = loc("third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py":76:13) +#loc8 = loc("third_party/py/jax/experimental/pallas/ops/tpu/matmul.py":68:9) +#loc9 = loc("jit(func)/jit(main)/iota"(#loc1)) +#loc10 = loc("jit(func)/jit(main)/reshape"(#loc2)) +#loc11 = loc("jit(func)/jit(main)/mul"(#loc3)) 
+#loc12 = loc("jit(func)/jit(main)/add"(#loc4)) +#loc13 = loc("jit(func)/jit(main)/slice"(#loc5)) +#loc15 = loc("jit(func)/jit(main)/iota"(#loc7)) +#loc16 = loc("jit(func)/jit(main)/mul"(#loc7)) +#loc17 = loc("jit(func)/jit(main)/add"(#loc7)) +#loc18 = loc("jit(func)/jit(main)/broadcast_in_dim"(#loc7)) +#loc19 = loc("jit(func)/jit(main)/concatenate"(#loc7)) +#loc20 = loc("jit(func)/jit(main)/gather"(#loc7)) +#loc21 = loc("jit(func)/jit(main)/jit(matmul)/pallas_call"(#loc8)) +""", + mlir_module_serialized=b'ML\xefR\rStableHLO_v1.5.0\x00\x01-\x05\x01\x05\x1d\x01\x03\x0b\x03\x1b\x0f\x13\x17\x1b\x1f#\'+/37;?\x03\xe5\xa3/\x01W\x07\x0b\x13\x0f\x0f\x0f\x0f\x0b\x0f\x0b\x0f\x0b\x0f\x0f#\x0b\x0f\x0b\x0b\x0b\x0b\x13\x13\x0b\x0f\x0b\x13\x0b\x0f\x13\x0f\x0b\x13\x13\x13\x0f\x0b\x13\x0b\x0f\x0b\x0f\x0b\x03M\x0f\x0b\x13O\x0f\x0b\x0b\x0b/\x0fO\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x13\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0f\x1f\x1f\x1f\x1fO///\x0b\x01\x05\x0b\x0f\x03+\x1f\x07\x07\x07\x17\x13\x0f\x0f\x13\x1f\x1f\x1b\x1f\x13\x13\x1b\x13\x07\x1b\x13\x1f\x02\x92\x06\x1f\x05!\x17\x03\x99\x1b\x1d)+\x1d\x13\x05\x1d\x17\x05\x11\x03\x05\x05#\x1d\x13C\x05%\x1d\x17E\x05\'\x1d\x0f\x05\x1dM\x05\x03\x07\x1f!#\r%\r\x05)\x11\x01\x00\x05+\x05-\x05/\x051\x17\x03\x95\x19\x03\x03/\x83\x053\x1d35\x055\x177\x89\x13\x057\x1d\x0f;\x17\x03\x8f3\x1d?A\x059\x17\x03\x91W\x17\x03\x8f#\x17\x03\x8f\x15\x1dIK\x05;\x17\x03\x93\x15\x05=\x1dQ\x05\x05?\x1dU\x05\x05A\x1f+\x01\x03\x01\r\x03ac\x1f%!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x13\x0b\x01\x1dC\x1dE\x1dG\x1f\x15\x11\x01\x00\x00\x00\x00\x00\x00\x00\x13\x0b\t\x1f\x15!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00#!\x03\x03q\r\x05suac\x1dI\x1dK\x1dM\x1dO\x03\x05[[##\x03\x03[\x1dQ\x1dS\x0b\x03\x1dU\x1dW\x05\x01\x03\x05]]\x03\x03]\x1f\x11\t\x00\x00\x00\x00\x1f\x11\t\x10\x00\x00\x00\x1f\x13\t\x00\x00\x80?\x1f\x13\to\x12\x83:\x1f\x15!\x00\x02\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x1f\x15\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x00\x00\x00\x00\x00\x00\x00\x00\x1f\x1f\x11\x01\x00\x00\x00\x00\x00\x00\x00\x05\x03\x01\t\x01\x02\x02)\x05\x02 \x02\x10\x07\t\x1b\x1d)\x03\x02\x02\t)\x03A\t)\x01\t)\x01\x07)\x03\t\x0b)\x05\x02\x10\x02\x08\x07)\x05\x02 
\x02\x08\x07)\x05\x02\x02A\x07)\x07\x02\x02A\x05\t)\x03\x05\x0b\x11\x01\x03\x1b\x11\x05\x05\x17\x03\x19)\x03\t\'\x13)\x03\x04\x00\x80\x07)\x03\x01\x0b)\x07\x02\x02A\t\t\x04\x02\x04\x05\x01Q\x01\x1d\x01\x07\x04\xda\x03\x03\x01\t\rP\x01\x03\x07\x042\x03\x035m\x05B\x01\x05\x03\x11\x05B\x01\x07\x03\x11\x05B\x01\t\x03\x13\x05B\x01\x0b\x03\x13\x07B9\r\x03)\x13\x06=\x03\x05\x03\t\x03F\x11\x0f\x03\x05\x03\x07\t\x06\x11\x03\x05\x05\r\x0b\x03F\x15\x0f\x03\x05\x03\x05\x0b\x06\x15\x03\x05\x05\x11\x0f\x15FG\x11\x03\x17\x03\x13\x17F\x07\x13\x03\x19\x05\x13\x15\x07B\x19\r\x03\r\x03F\t\x0f\x03\r\x03\x03\t\x06\t\x03\r\x05\x1b\x19\x03F\x0b\x0f\x03\r\x03\x01\x0b\x06\x0b\x03\r\x05\x1f\x1d\x07B\x19\r\x03\x0f\x03F\t\x0f\x03\x0f\x03\x03\t\x06\t\x03\x0f\x05%#\x03F\x0b\x0f\x03\x0f\x03\x01\x0b\x06\x0b\x03\x0f\x05)\'\x03F\x1b\x15\x03\x1d\x03!\x03F\x1b\x17\x03\x1d\x03+\x19FO\x19\x03-\x05-/\x1bFS\x1b\x03\x1b\x05\x171\x0f\x04\x01\x033\rP\x07\x1d\x07\x041\x03\x07\x0b\x05\x0b\x07/\x07\x00\x11G1-\x1f\x03\x19\x05\x01\x03\x0f\x04\x07\x03\x05\x06\x03\x01\x05\x01\x00\x1a3Y!j&\x1d\x11\x0f\x0b\x03!\x0f\x11#7AK59sY\x193\x13%)9113\x85\x15\x1f\x11\x13\x17\x1f\x15\x11\x0f\x19\x11\x19)\x0f\x0b\x11builtin\x00vhlo\x00module\x00broadcast_in_dim_v1\x00constant_v1\x00iota_v1\x00multiply_v1\x00add_v1\x00func_v1\x00return_v1\x00custom_call_v1\x00reshape_v1\x00slice_v1\x00call_v1\x00concatenate_v1\x00gather_v2\x00third_party/py/jax/tests/pallas/export_back_compat_pallas_test.py\x00jit(func)/jit(main)/iota\x00jit(func)/jit(main)/mul\x00jit(func)/jit(main)/add\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit\x00kernel_name\x00jit(func)/jit(main)/jit(matmul)/pallas_call\x00third_party/py/jax/experimental/pallas/ops/tpu/matmul.py\x00jit(func)/jit(main)/reshape\x00jit(func)/jit(main)/slice\x00jit(func)/jit(main)/broadcast_in_dim\x00jit(func)/jit(main)/concatenate\x00jit(func)/jit(main)/gather\x00mhlo.layout_mode\x00default\x00matmul\x00jax.result_info\x00\x00main\x00public\x00private\x00matmul_kernel\x00{"custom_call_config": {"body": 
"TUzvUgFNTElSZ29vZ2xlMy10cnVuawABLQkBAwUHAQMJAxkLDQ8RExUXGRsdHyED58ETAbkHEwsTCwsLDwsPDw8LC1MLDw8PDwsPDw8LCwsLExMPDxMPGwsPC0MLFwuFC3MLCwsLFxsLGwsbCxsbGw8LExMPEw8LCxMPExMTHwsTGwsLEwsPCxMTEwsTDwsTEwUHjZFhBwNZARMPBx8nDwcLKyMCZggfAwMLiwUjAwMLdwUlBScFKR15ewUrHSmnHSmrHSm3BS0FLyMJBSEAAQAAAAAAAAABAAAAAAAADREdhzkdETsdEY0dEY8FMR0RqREJAREJBQUzBTUFNwU5FwU7BxcFQyMdlZcRDQAXrRcLHbO1AwVHSQlLBTsRCQ0FPQMPT1ENU1dZWy1dLwlfYWMFPwEHubm7DQ9hZmZpbmVfbWFwPChkMCwgZDEpIC0+IChkMCwgZDEpPgAFQSMJBzEEAAAAAAAAAAEAAAAAAAAAAgAAAAAAAAAFQwVFBUcFSQEHZWltAwUZZxsdCTEDBRlrGx0JMwMFGW8bHQk1AwUNHwkxAwUNHwkzAwUNHwk1EQEBBUsXBTsXAwMLfxEBBQMDNy0dhTkFTQVPAwM3LxEDARcFRQ0XBUcNAwMLkyUFCQAAAAAFURcFQ0EDBZs/nT8FUwVVAwOhvwVXHaU7BVkXBUMFFwVRKRcFUQUFWwMDC7ETCwEFXRcFPycXBT8JI3RwdS5kaW1lbnNpb25fc2VtYW50aWNzPHBhcmFsbGVsPgAjdHB1LmRpbWVuc2lvbl9zZW1hbnRpY3M8YXJiaXRyYXJ5PgAjdHB1Lm1lbW9yeV9zcGFjZTx2bWVtPgAjYXJpdGguZmFzdG1hdGg8bm9uZT4AAQICAycFAggCCAsXvQUCCAIIC1UBAgQLAQkFDwEBAQcHBwcBBQcBAQEFAQEEIgYFAREBRQcDAREHEQFNBwNHgw8BAQEBAQEHAQcBBwEHAQMDDwcDAQMDD30DAQMDDwcDAQ0HD4EDDQUFDxEGgwMBAxUDAyEHAwENByGJAw0FFxkTFCEDGwkDCx0DA0OvAwsZBkMDBQNHAwMXAwMDAwMXAwMDBQYXAwUHDUtNCwQXCUkNS00PAEEDAQUPAEEDAyMDAwMDAyMDAwMFBiMDBQcNHR8DAyUDAwMDAyUDAwMFBiUDBQcHIyUDAycDAwMDAycDAwMFBicDBQcJKSsDAz2RAwUVBz2ZAwUHJy0vFwejnwMFBSExAwMTAwMDAwMTAwMDBQYTAwUHDTU3CwQTCTMNNTcDAysDAwMDAysDAwMFBisDBQcNOz0DAxUDAwMDAxUDAwMFBhUDBQcLQUMLBBUJPwtBQwkAAQcRAXEHAwkLBwEBAQEBAQMDAQcDAQkEAQUBBQcRAXMHAwkLBwEBAQEBAQMDAQcDAQkEAQUFAwcRAXUHAwkLBwEBAQEBAQMDAQcDAQkEAQUBAwYDAQUBAO4JXyUFCxMdHRsNLQkdCyMhIykdLRUZGRkNHSULHQ0TcyMXFw8ZFRcbGRUZHw8NCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBhcml0aABtb2R1bGUAYXJpdGguY29uc3RhbnQAdmVjdG9yLmxvYWQAZnVuYy5mdW5jAGZ1bmMucmV0dXJuAHZlY3Rvci5zdG9yZQBhcml0aC5jbXBpAHNjZi55aWVsZABhcml0aC5leHR1aQBzY2YuaWYAdHB1Lm1hdG11bABhcml0aC5hZGRmAHZlY3Rvci5icm9hZGNhc3QAdGhpcmRfcGFydHkvcHkvamF4L2V4cGVyaW1lbnRhbC9wYWxsYXMvb3BzL3RwdS9tYXRtdWwucHkAc3ltX25hbWUAdmFsdWUAZnVuY3Rpb25fdHlwZQAvZ2V0AHRyYW5zZm9ybV9pbmRpY2VzAHdpbmRvd19ib3VuZHMAL3N3YXAAdHJhbnNmb3JtXzAAdHJhbnNmb3JtXzEAdHJhbnNmb3JtXzIAcHJlZGljYXRlAHN0YWJsZV9tb3NhaWMudmVyc2lvbgBtYXRtdWxfa2VybmVsAGRpbWVuc2lvbl9zZW1hbnRpY3MAaXRlcmF0aW9uX2JvdW5kcwBzY2FsYXJfcHJlZmV0Y2gAc2NyYXRjaF9vcGVyYW5kcwBtYWluAHdpbmRvd19wYXJhbXMAL2VxAC9jb252ZXJ0X2VsZW1lbnRfdHlwZQAvY29uZAAvZG90X2dlbmVyYWwAdHJhbnNwb3NlX2xocwB0cmFuc3Bvc2VfcmhzAGZhc3RtYXRoAC9hZGQALQAvYnJvYWRjYXN0X2luX2RpbQA=", "serialization_format": 1, "needs_layout_passes": true}, "implicit_sharding": {"type": "MANUAL"}}\x00tpu_custom_call\x00\x08u!\x05O\x01\x0bYmowy\x03\x91\x03\x93\x03\x95\x03\x97\x03_\x03W\x07\x99\x9bg\x03e\x03\x9d\x03\x9f\x03i\x11ki\xa1WWgkW\x0b{}\x7fe\x81\x11\x85\x87\x89Y\x8b\x8dY\x8f', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py new file mode 100644 index 000000000000..a44e92846b98 --- /dev/null +++ b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/mosaic_semaphore_dma.py @@ -0,0 +1,95 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +from numpy import array, float32 + +# Pasted from the test output (see export_back_compat_test_util.py module docstring) +semaphore_and_dma_2024_04_22 = dict( + testdata_version=1, + platform='tpu', + custom_call_targets=['tpu_custom_call'], + serialized_date=datetime.date(2024, 4, 22), + inputs=(), + expected_outputs=(array(1., dtype=float32),), + mlir_module_text=r""" +#loc2 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":60:4) +#loc3 = loc("third_party/py/absl/testing/absltest.py":2718:19) +#loc4 = loc("third_party/py/absl/testing/absltest.py":2754:35) +#loc5 = loc("third_party/py/absl/testing/absltest.py":2298:6) +#loc6 = loc("third_party/py/absl/app.py":395:13) +#loc7 = loc("third_party/py/absl/app.py":473:6) +#loc8 = loc("third_party/py/absl/testing/absltest.py":2300:4) +#loc9 = loc("third_party/py/absl/testing/absltest.py":2182:2) +#loc10 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":64:2) +#loc11 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":57:10) +#loc14 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024"(#loc2)) +#loc15 = loc("_run_and_get_tests_result"(#loc3)) +#loc16 = loc("run_tests"(#loc4)) +#loc17 = loc("_run_in_app..main_function"(#loc5)) +#loc18 = loc("_run_main"(#loc6)) +#loc19 = loc("run"(#loc7)) +#loc20 = loc("_run_in_app"(#loc8)) +#loc21 = loc("main"(#loc9)) +#loc22 = loc(""(#loc10)) +#loc23 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024..func"(#loc11)) +#loc25 = loc(callsite(#loc21 at #loc22)) +#loc26 = loc(callsite(#loc20 at #loc25)) +#loc27 = loc(callsite(#loc19 at #loc26)) +#loc28 = loc(callsite(#loc18 at #loc27)) +#loc29 = loc(callsite(#loc17 at #loc28)) +#loc30 = loc(callsite(#loc16 at #loc29)) +#loc31 = loc(callsite(#loc15 at #loc30)) +#loc32 = loc(callsite(#loc14 at #loc31)) +#loc34 = loc(callsite(#loc23 at #loc32)) +#loc38 = loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc34)) +#loc42 = loc("jit(func)/jit(main)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=apply_kernel keep_unused=False inline=False]"(#loc34)) +module @jit_func attributes {jax.uses_shape_polymorphism = false, mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main() -> (tensor {jax.result_info = "", mhlo.layout_mode = "default"}) { + %0 = stablehlo.iota dim = 0 : tensor<16384xf32> loc(#loc36) + %1 = stablehlo.reshape %0 : (tensor<16384xf32>) -> tensor<128x128xf32> loc(#loc37) + %2 = call @wrapped(%1) : (tensor<128x128xf32>) -> tensor<128x128xf32> loc(#loc38) + %3 = stablehlo.compare EQ, %1, %2, FLOAT : (tensor<128x128xf32>, tensor<128x128xf32>) -> tensor<128x128xi1> loc(#loc39) + %c = stablehlo.constant dense : tensor loc(#loc40) + %4 = stablehlo.reduce(%3 init: %c) applies stablehlo.and across dimensions = [0, 1] : (tensor<128x128xi1>, tensor) -> tensor loc(#loc40) + %5 = stablehlo.convert %4 : (tensor) -> tensor loc(#loc41) + return %5 : tensor loc(#loc) + } loc(#loc) + func.func private @wrapped(%arg0: tensor<128x128xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) 
out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]"(#loc34))) -> (tensor<128x128xf32> {mhlo.layout_mode = "default"}) { + %0 = call @apply_kernel(%arg0) : (tensor<128x128xf32>) -> tensor<128x128xf32> loc(#loc42) + return %0 : tensor<128x128xf32> loc(#loc38) + } loc(#loc38) + func.func private @apply_kernel(%arg0: tensor<128x128xf32> {mhlo.layout_mode = "default"} loc("jit(func)/jit(main)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=apply_kernel keep_unused=False inline=False]"(#loc34))) -> (tensor<128x128xf32> {mhlo.layout_mode = "default"}) { + %0 = stablehlo.custom_call @tpu_custom_call(%arg0) {backend_config = "{\22custom_call_config\22: {\22body\22: \22TUzvUgFNTElSZ29vZ2xlMy10cnVuawABJwcBAwUBAwcDFQkLDQ8RExUXGRsD27UTAbELBwsPCw8PCw8PDw8PDw8LDw9VDxMPDxMLDzMLCwsLhQsLCwsPCxMPCxMPCxMPCxcPCxcPCxcPCxcPCxcPCxcPFw8LDxMPDw8PDw8PDwsLDwsPDxMLDw8TBQWFYQEPJw8PFwcXFwUFTT0CzgYFHR8FHx1HSQUhFRGLEQUBBSMdS00dUVMdV1kdXV8dY2UdaWsdb3EFJR11dx17fWFmZmluZV9tYXA8KCkgLT4gKCk+ABWHCwMDnZ8doaMdqasDAzEzBScRBQUDCzc5Oz1BDUMNRQ8FKQEBBSsNB2FmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUtBS8FMQUzFRFPBTUXAWsRFRNVBTcXAXMVFRVbBTkXAXkJFRdhBTsXBXoqJxUZZwU9FwUKK0cVG20FPxcF6iMNFR1zBUEXHy4GGxUheQVDFx9mBw0VI38FRRcF8iMJHQ+BFwUaIgUdhScFRx0JiRcBZRUVE40VFY8VF5EVGZMVG5UVHZcVISMdmycFSQVLEQMFBU0VpQsdCacXAWcVBU8VrQsdCa8XAWkVI3RwdS5tZW1vcnlfc3BhY2U8c2VtYXBob3JlX21lbT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF7MFAgQCBAk/AQICAQIEBQUBAQELF7EBDyUXsQERJSF0cHUuZG1hX3NlbWFwaG9yZQAhdHB1LnNlbWFwaG9yZQAEpQUBEQMvBwMBBQcRAzUHAwULBQEDAQMJEAcFAwklAwIHAwsDAgcDDQ0EgwcBAwUPBJkFBQMFAyspAwMRBCsFBwkFAy0pAwMTBC0FBwsVAAcLAAMGAwEFAQA+ElFjtQ2XyxkJFUcVNWeDqxkTIyEdKS03C8dRgRUbHxshGRcVHx0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBtb2R1bGUAdHB1LnNlbV9hbGxvYwBhcml0aC5jb25zdGFudABmdW5jLmZ1bmMAdHB1LnJlZ2lvbgBmdW5jLnJldHVybgB0cHUuZW5xdWV1ZV9kbWEAdHB1LndhaXRfZG1hAHRwdS5zZW1fc2lnbmFsAHRwdS5zZW1fd2FpdAB0cHUueWllbGQAdGhpcmRfcGFydHkvcHkvamF4X3RyaXRvbi9nb29nbGUvcGFsbGFzX3RwdS9iYWNrX2NvbXBhdF90ZXN0LnB5AHRoaXJkX3BhcnR5L3B5L2Fic2wvdGVzdGluZy9hYnNsdGVzdC5weQBQYWxsYXNLZXJuZWxUZXN0LnRlc3Rfc2VtYXBob3JlX2FuZF9kbWFfMjJfMDRfMjAyNC48bG9jYWxzPi5mdW5jLjxsb2NhbHM+LmRtYV9rZXJuZWwuPGxvY2Fscz4uYm9keQBtYWluAHRoaXJkX3BhcnR5L3B5L2Fic2wvYXBwLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAL3J1bl9zY29wZWQAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQuPGxvY2Fscz4uZnVuYy48bG9jYWxzPi5kbWFfa2VybmVsAFBhbGxhc0tlcm5lbFRlc3QudGVzdF9zZW1hcGhvcmVfYW5kX2RtYV8yMl8wNF8yMDI0Ljxsb2NhbHM+LmZ1bmMAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQAX3J1bl9hbmRfZ2V0X3Rlc3RzX3Jlc3VsdABydW5fdGVzdHMAX3J1bl9pbl9hcHAuPGxvY2Fscz4ubWFpbl9mdW5jdGlvbgBfcnVuX21haW4AcnVuAF9ydW5faW5fYXBwAC9kbWFfc3RhcnRbdHJlZT1QeVRyZWVEZWYoKCosICgpLCAqLCAoKSwgKiwgKCksIE5vbmUsIE5vbmUsIE5vbmUpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL2RtYV93YWl0W3RyZWU9UHlUcmVlRGVmKCgqLCAoKSwgKiwgKCkpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AdmFsdWUAL3NlbWFwaG9yZV9zaWduYWxbYXJnc190cmVlPVB5VHJlZURlZihbKiwgKCksICosIE5vbmVdKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL3NlbWFwaG9yZV93YWl0W2FyZ3NfdHJlZT1QeVRyZWVEZWYoWyosICgpLCAqXSldAA==\22, \22serialization_format\22: 1, \22needs_layout_passes\22: true}, \22implicit_sharding\22: {\22type\22: \22MANUAL\22}}", kernel_name = "dma_kernel", operand_layouts = [dense<[1, 
0]> : tensor<2xindex>], result_layouts = [dense<[1, 0]> : tensor<2xindex>]} : (tensor<128x128xf32>) -> tensor<128x128xf32> loc(#loc43) + return %0 : tensor<128x128xf32> loc(#loc42) + } loc(#loc42) +} loc(#loc) +#loc = loc(unknown) +#loc1 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":56:10) +#loc12 = loc("third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py":58:13) +#loc13 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024..func"(#loc1)) +#loc24 = loc("PallasKernelTest.test_semaphore_and_dma_22_04_2024..func"(#loc12)) +#loc33 = loc(callsite(#loc13 at #loc32)) +#loc35 = loc(callsite(#loc24 at #loc32)) +#loc36 = loc("jit(func)/jit(main)/iota[dtype=float32 shape=(16384,) dimension=0]"(#loc33)) +#loc37 = loc("jit(func)/jit(main)/reshape[new_sizes=(128, 128) dimensions=None]"(#loc33)) +#loc39 = loc("jit(func)/jit(main)/eq"(#loc35)) +#loc40 = loc("jit(func)/jit(main)/reduce_and[axes=(0, 1)]"(#loc35)) +#loc41 = loc("jit(func)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]"(#loc35)) +#loc43 = loc("jit(func)/jit(main)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=dma_kernel kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[128,128]),) input_output_aliases=()]"(#loc34)) +""", + mlir_module_serialized=b'ML\xefR\x01StableHLO_v0.9.0\x00\x01\'\x05\x01\x03\x01\x03\x05\x03\x17\x07\t\x0b\r\x0f\x11\x13\x15\x17\x19\x1b\x03z\x02\n\x02\x1f\x01\xc9\x0f\x0b\x0b\x0b\x0f\x0f\x07\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x0f\x0f\x0b\x0b\x0f+\x0b\x0f\x0b\x0b\x0b33\x0b\x0f\x13\x0f\x0b\x13\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0f\x0b\x17\x0f\x0b\x133\x0bS\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x13\x13\x0b\x0f\x0b\x0f\x13\x0f\x0b\x13\x1b\x0b\x0b\x0f\x0b\x0f\x13\x13\x0b\x0b\x13\x0b\x039\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0f\x0b\x0f\x1b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x0b\x0b\x0bO\x0f\x0b\x0b\x13O\x01\x05\x13\x0b\x01\x05\x0b\x0f\x03\x1b\x1f\x0f\x07\x0f\x07\x07\x13\x17\x13\x07\x1b\x1f\x13\x02\xb2\x07\x1d\xc3\x1d\x05\x1d\x05\x1f\x05!\x1d7\x17\x1d\x83\x17\x1f\x05#\x05%\x05\'\x05)\x159\x1b\x05+\x15=C\x15\xbb\x1b\x11\x03\x05\x05-\x05/\x15\xa7\x1b\x03\t)+-\x1f/\x1f\x071\x051\x11\x01\x00\x053\x055\x057\x03\x0b\x0f\xcb\x11\xdb\x13\xdd\x07\xe5\x15\xe7\x03\x0b\x0f\xc9\x11\xd1\x13\xc9\x07\xd3\x15\xd5\x059\x1d\x19;\x17\x03s\x15\x1d?A\x05;\x17\x03y\t\x15EK\x1dGI\x05=\x17\x05z*\'\x15MS\x1dOQ\x05?\x17\x05\n+G\x15U[\x1dWY\x05A\x17\x05\xea#\r\x15]c\x1d_a\x05C\x17!.\x06\x1b\x15ek\x1dgi\x05E\x17!f\x07\r\x15ms\x1doq\x05G\x17\x05\xf2#\t\x15u{\x1dwy\x05I\x17\x05\x1a"\x05\x1d}\x7f\x05K\x17\x03\x81\x05\x03\x0b\x0f\xc9\x11\xd1\x13\xc9\x07\xd7\x15\xd5\x05M\x03\x13\x87\xeb\x89\xed\x8b\xef\x8d\xcb\x8f\xf1\x91\xf3\x93\xd9\x95\xcb\x97\xd9\x05O\x05Q\x05S\x05U\x05W\x05Y\x05[\x05]\x05_\x1d\x9b\x17\x05a\x03\x03#\xd7\x03\x03\xa1\xf7\x05c\x1d\xa5%\x05e\x1d\x19\xa9\x17\x03q\x15\x1d\xad%\x05g\x03\x03#\xd3\x03\x05\xb3\xf9\xb5\xfb\x05i\x05k\x1d\xb9\x1d\x05m\x1d\x19\xbd\x17\x03u\x1b\x03\x03\xc1\xfd\x05o\x05q\x03\x03\xc7\xff\x05s\x03\x03\xe9\x03\x01\x1du\x1dw#\x13\x1dy\x1d{\x1d}\x03\x03\xf5#\x11\x03\x03\xdf\r\x05\xe1\xe3\xcd\xcf\x1d\x7f\x1d\x81\x1dI\x1d\x83\r\x03\xcd\xcf\x0b\x03\x1d\x85\x1d\x87\x05\x01\x1d\x89\x1f\x15!\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x13\r\x01\t\x03\x07\x01\x1f\x07\x03\xff\x1f\x1d!\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x1d\x06\x02\x1d\x05\x8b\x01\t\x01\x02\x02)\x05\x02\x04\x02\x04\t)\x01\x0f\t)\x01\t\x1d\x01\x11\x01\x03\x
0b\x11\x03\x05\x03\x05)\x03\t\x17\x13)\x03\x04\x00\x04\t)\x05\x02\x04\x02\x04\x0f)\x03\t\r\x04F\x02\x05\x01\x11\r\'\x07\x03\x01\r\x05\x11\r3\x07\x03\x0f!\x0b\x03\xa3\x9f\x03\x19\r\x06\xab\x03\x05\x03\x01\x07\x07\t\xaf\x03\x05\x03\x03\x0f\x07\xb7\xb1\x03\x1b\x05\x03\x05\x11\x03\x01\xbf\x03\x07\x13\x17\x01\xc5\x03\x07\x05\x07\t\x07\x03\x07\x0b\x05\x07\x01\x07\x01\x17\x06\x01\x03\x07\x05\x01\x03\x03\x04\x01\x03\x05\x15\x06\x02\x02\x03\x0b\x03\x0b\x03\x04\r\x03\r\x05\x11\t5\x07\x03\x05\x0b\x03\x05\t\x07\x07\x0b\x9d\x03\x05\x03\x01\x03\x04\t\x03\x03\x05\x11\x0b\x81\x07\x03\x05\x0b\x03\x05\x0b\t\x07\x99\x85\x03\x05\x03\x01\x03\x04\x0b\x03\x03\x06\x03\x01\x05\x01\x00\xbeG\x8d\x99\x17!\xba(\x0f\x03!\x1b\x11\x11\x11#\x17Y\r/+\x1b\x85\x87\x1f\xaa\x03\x1f/!\x19!)#\x1f\x19\xb2\x03\x13\x0b\x19\t\x15G\x155gj\x03\x13%)9\x0f7\x83\x1f\x15\x1d\x15\x13Q\x81\x0f\x17\x15\x19\x17\x17\x11\x1f\x11\x11\x15\x0f\x0b\x11builtin\x00vhlo\x00module\x00return_v1\x00func_v1\x00call_v1\x00custom_call_v1\x00iota_v1\x00reshape_v1\x00compare_v1\x00constant_v1\x00reduce_v1\x00convert_v1\x00and_v1\x00third_party/py/jax_triton/googlexpallas_tpu/back_compat_test.py\x00third_party/py/absl/testing/absltest.py\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00PallasKernelTest.test_semaphore_and_dma_22_04_2024..func\x00third_party/py/absl/app.py\x00callee\x00jax.uses_shape_polymorphism\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=wrapped keep_unused=False inline=False]\x00PallasKernelTest.test_semaphore_and_dma_22_04_2024\x00_run_and_get_tests_result\x00run_tests\x00_run_in_app..main_function\x00_run_main\x00run\x00_run_in_app\x00main\x00\x00jit(func)/jit(main)/jit(wrapped)/pjit[in_shardings=(UnspecifiedValue,) out_shardings=(UnspecifiedValue,) in_layouts=(None,) out_layouts=(None,) resource_env=None donated_invars=(False,) name=apply_kernel keep_unused=False inline=False]\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00kernel_name\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00jit(func)/jit(main)/jit(wrapped)/jit(apply_kernel)/tpu_custom_call[config=CustomCallBackendConfig() kernel_name=dma_kernel kernel_regeneration_metadata=None out_avals=(ShapedArray(float32[128,128]),) input_output_aliases=()]\x00iota_dimension\x00jit(func)/jit(main)/iota[dtype=float32 shape=(16384,) dimension=0]\x00jit(func)/jit(main)/reshape[new_sizes=(128, 128) dimensions=None]\x00compare_type\x00comparison_direction\x00jit(func)/jit(main)/eq\x00value\x00jit(func)/jit(main)/reduce_and[axes=(0, 1)]\x00dimensions\x00mhlo.layout_mode\x00default\x00wrapped\x00private\x00apply_kernel\x00jax.result_info\x00\x00public\x00{"custom_call_config": {"body": 
"TUzvUgFNTElSZ29vZ2xlMy10cnVuawABJwcBAwUBAwcDFQkLDQ8RExUXGRsD27UTAbELBwsPCw8PCw8PDw8PDw8LDw9VDxMPDxMLDzMLCwsLhQsLCwsPCxMPCxMPCxMPCxcPCxcPCxcPCxcPCxcPCxcPFw8LDxMPDw8PDw8PDwsLDwsPDxMLDw8TBQWFYQEPJw8PFwcXFwUFTT0CzgYFHR8FHx1HSQUhFRGLEQUBBSMdS00dUVMdV1kdXV8dY2UdaWsdb3EFJR11dx17fWFmZmluZV9tYXA8KCkgLT4gKCk+ABWHCwMDnZ8doaMdqasDAzEzBScRBQUDCzc5Oz1BDUMNRQ8FKQEBBSsNB2FmZmluZV9tYXA8KGQwLCBkMSkgLT4gKGQwLCBkMSk+AAUtBS8FMQUzFRFPBTUXAWsRFRNVBTcXAXMVFRVbBTkXAXkJFRdhBTsXBXoqJxUZZwU9FwUKK0cVG20FPxcF6iMNFR1zBUEXHy4GGxUheQVDFx9mBw0VI38FRRcF8iMJHQ+BFwUaIgUdhScFRx0JiRcBZRUVE40VFY8VF5EVGZMVG5UVHZcVISMdmycFSQVLEQMFBU0VpQsdCacXAWcVBU8VrQsdCa8XAWkVI3RwdS5tZW1vcnlfc3BhY2U8c2VtYXBob3JlX21lbT4AI3RwdS5tZW1vcnlfc3BhY2U8dm1lbT4AF7MFAgQCBAk/AQICAQIEBQUBAQELF7EBDyUXsQERJSF0cHUuZG1hX3NlbWFwaG9yZQAhdHB1LnNlbWFwaG9yZQAEpQUBEQMvBwMBBQcRAzUHAwULBQEDAQMJEAcFAwklAwIHAwsDAgcDDQ0EgwcBAwUPBJkFBQMFAyspAwMRBCsFBwkFAy0pAwMTBC0FBwsVAAcLAAMGAwEFAQA+ElFjtQ2XyxkJFUcVNWeDqxkTIyEdKS03C8dRgRUbHxshGRcVHx0PCR0RYnVpbHRpbgBzdGFibGVfbW9zYWljAHRwdQBtb2R1bGUAdHB1LnNlbV9hbGxvYwBhcml0aC5jb25zdGFudABmdW5jLmZ1bmMAdHB1LnJlZ2lvbgBmdW5jLnJldHVybgB0cHUuZW5xdWV1ZV9kbWEAdHB1LndhaXRfZG1hAHRwdS5zZW1fc2lnbmFsAHRwdS5zZW1fd2FpdAB0cHUueWllbGQAdGhpcmRfcGFydHkvcHkvamF4X3RyaXRvbi9nb29nbGUvcGFsbGFzX3RwdS9iYWNrX2NvbXBhdF90ZXN0LnB5AHRoaXJkX3BhcnR5L3B5L2Fic2wvdGVzdGluZy9hYnNsdGVzdC5weQBQYWxsYXNLZXJuZWxUZXN0LnRlc3Rfc2VtYXBob3JlX2FuZF9kbWFfMjJfMDRfMjAyNC48bG9jYWxzPi5mdW5jLjxsb2NhbHM+LmRtYV9rZXJuZWwuPGxvY2Fscz4uYm9keQBtYWluAHRoaXJkX3BhcnR5L3B5L2Fic2wvYXBwLnB5AHN0YWJsZV9tb3NhaWMudmVyc2lvbgBkaW1lbnNpb25fc2VtYW50aWNzAGZ1bmN0aW9uX3R5cGUAc2NhbGFyX3ByZWZldGNoAHNjcmF0Y2hfb3BlcmFuZHMAc3ltX25hbWUAL3J1bl9zY29wZWQAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQuPGxvY2Fscz4uZnVuYy48bG9jYWxzPi5kbWFfa2VybmVsAFBhbGxhc0tlcm5lbFRlc3QudGVzdF9zZW1hcGhvcmVfYW5kX2RtYV8yMl8wNF8yMDI0Ljxsb2NhbHM+LmZ1bmMAUGFsbGFzS2VybmVsVGVzdC50ZXN0X3NlbWFwaG9yZV9hbmRfZG1hXzIyXzA0XzIwMjQAX3J1bl9hbmRfZ2V0X3Rlc3RzX3Jlc3VsdABydW5fdGVzdHMAX3J1bl9pbl9hcHAuPGxvY2Fscz4ubWFpbl9mdW5jdGlvbgBfcnVuX21haW4AcnVuAF9ydW5faW5fYXBwAC9kbWFfc3RhcnRbdHJlZT1QeVRyZWVEZWYoKCosICgpLCAqLCAoKSwgKiwgKCksIE5vbmUsIE5vbmUsIE5vbmUpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL2RtYV93YWl0W3RyZWU9UHlUcmVlRGVmKCgqLCAoKSwgKiwgKCkpKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AdmFsdWUAL3NlbWFwaG9yZV9zaWduYWxbYXJnc190cmVlPVB5VHJlZURlZihbKiwgKCksICosIE5vbmVdKSBkZXZpY2VfaWRfdHlwZT1EZXZpY2VJZFR5cGUuTUVTSF0AL3NlbWFwaG9yZV93YWl0W2FyZ3NfdHJlZT1QeVRyZWVEZWYoWyosICgpLCAqXSldAA==", "serialization_format": 1, "needs_layout_passes": true}, "implicit_sharding": {"type": "MANUAL"}}\x00tpu_custom_call\x00dma_kernel\x00jit(func)/jit(main)/convert_element_type[new_dtype=float32 weak_type=False]\x00', + xla_call_module_version=9, + nr_devices=1, +) # End paste diff --git a/jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py b/jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py similarity index 100% rename from jax/_src/internal_test_util/export_back_compat_test_data/pallas/cuda_add_one.py rename to jax/_src/internal_test_util/export_back_compat_test_data/pallas/triton_add_one.py diff --git a/jax/_src/internal_test_util/test_harnesses.py b/jax/_src/internal_test_util/test_harnesses.py index 4bf6d1ceb145..2c94907568d9 100644 --- a/jax/_src/internal_test_util/test_harnesses.py +++ b/jax/_src/internal_test_util/test_harnesses.py @@ -30,7 +30,7 @@ particular test, we write them as `Limitation` objects that can be reused in multiple tests and can also be used to generate documentation, e.g., the 
report of [unsupported and partially-implemented JAX -primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) +primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) The limitations are used to filter out from tests the harnesses that are known to fail. A Limitation is specific to a harness. @@ -515,7 +515,7 @@ def _make_convert_element_type_harness(name, for old_dtype in jtu.dtypes.all: # TODO(bchetioui): JAX behaves weirdly when old_dtype corresponds to floating # point numbers and new_dtype is an unsigned integer. See issue - # https://github.com/google/jax/issues/5082 for details. + # https://github.com/jax-ml/jax/issues/5082 for details. for new_dtype in (jtu.dtypes.all if not (dtypes.issubdtype(old_dtype, np.floating) or dtypes.issubdtype(old_dtype, np.complexfloating)) @@ -2336,7 +2336,7 @@ def _make_select_and_scatter_add_harness(name, # Validate padding for padding in [ # TODO(bchetioui): commented out the test based on - # https://github.com/google/jax/issues/4690 + # https://github.com/jax-ml/jax/issues/4690 # ((1, 2), (2, 3), (3, 4)) # non-zero padding ((1, 1), (1, 1), (1, 1)) # non-zero padding ]: diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index ea9da4574e3d..f1f46a5c18f7 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -21,7 +21,6 @@ from functools import partial from typing import Any -import jax from jax._src import config from jax._src import linear_util as lu from jax._src.interpreters import partial_eval as pe @@ -58,7 +57,7 @@ def _update_annotation( # Implicit arguments never have tangents, so generate the tangent part of the # type annotation from explicit arguments only. 
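# --- Editor's aside (illustrative sketch, not part of the patch): the surrounding ad.py
# hunks rename `aval.at_least_vspace()` to `aval.to_tangent_aval()` and `Zero.from_value`
# to `Zero.from_primal_value`; both map a primal value or aval to its tangent-space aval.
# Assuming the renamed method keeps the old at_least_vspace semantics:
import numpy as np
from jax._src import core, dtypes
f32_aval = core.ShapedArray((3,), np.dtype('float32'))
i32_aval = core.ShapedArray((3,), np.dtype('int32'))
assert f32_aval.to_tangent_aval().dtype == np.dtype('float32')  # inexact dtypes keep their dtype
assert i32_aval.to_tangent_aval().dtype == dtypes.float0        # integer dtypes get float0 tangents
# Zero.from_primal_value(x) is then shorthand for Zero(core.get_aval(x).to_tangent_aval()).
# --- End editor's aside.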
explicit_avals = [aval for aval, explicit in orig_type if explicit] - tan_types = [(aval.at_least_vspace(), True) + tan_types = [(aval.to_tangent_aval(), True) for nz, aval in zip(explicit_nonzeros, explicit_avals) if nz] return lu.annotate(f, (*orig_type, *tan_types)) @@ -73,7 +72,7 @@ def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True, @lu.transformation def jvpfun(instantiate, transform_stack, primals, tangents): - tangents = [Zero.from_value(t) if not isinstance(t, Zero) + tangents = [Zero.from_primal_value(t) if not isinstance(t, Zero) and dtype(t) == float0 else t for t in tangents] ctx = (source_info_util.transform_name_stack('jvp') if transform_stack else contextlib.nullcontext()) @@ -125,7 +124,7 @@ def linearize(traceable, *primals, **kwargs): jvpfun, aux = jvp(traceable, has_aux=True) in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace()) + + tuple(pe.PartialVal.unknown(get_aval(p).to_tangent_aval()) for p in primals)) _, in_tree = tree_flatten(((primals, primals), {})) jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree) @@ -167,18 +166,6 @@ def unpair_pval(pval): aval_1, aval_2 = aval return (aval_1, const_1), (aval_2, const_2) -def replace_float0s(primal, tangent): - if dtype(tangent) == float0: - return zeros_like_jaxval(primal) - else: - return tangent - -def recast_to_float0(primal, tangent): - if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0: - return Zero(get_aval(primal).at_least_vspace()) - else: - return tangent - # NOTE: The FIXMEs below are caused by primal/tangent mixups (type # errors if you will) @@ -204,7 +191,7 @@ def write_cotangent(prim, v, ct): # assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval) def read_cotangent(v): - return ct_env.pop(v, Zero(v.aval.at_least_vspace())) + return ct_env.pop(v, Zero(v.aval.to_tangent_aval())) def read_primal(v): if type(v) is Literal: @@ -296,11 +283,11 @@ def nonzero_tangent_outputs(*args, **kwargs): class JVPTrace(Trace): def pure(self, val): - tangent_zero = Zero(get_aval(val).at_least_vspace()) + tangent_zero = Zero.from_primal_value(val) return JVPTracer(self, val, tangent_zero) def lift(self, val): - tangent_zero = Zero(get_aval(val).at_least_vspace()) + tangent_zero = Zero.from_primal_value(val) return JVPTracer(self, val, tangent_zero) def sublift(self, val): @@ -344,7 +331,7 @@ def new_out_axes_thunk(): result = call_primitive.bind(_update_annotation(f_jvp, f.in_type, which_nz), *args, **new_params) primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [Zero(get_aval(p).at_least_vspace()) if t is None else t + tangent_out = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primal_out, tangent_out)] return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)] @@ -375,13 +362,11 @@ def process_custom_jvp_call(self, _, __, f_jvp, tracers, *, symbolic_zeros): primals_in = map(core.full_lower, primals_in) if not symbolic_zeros: tangents_in = map(instantiate_zeros, tangents_in) - tangents_in = map(replace_float0s, primals_in, tangents_in) else: tangents_in = map(replace_internal_symbolic_zeros, tangents_in) outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in)) primals_out, tangents_out = split_list(outs, [len(outs) // 2]) tangents_out = map(replace_rule_output_symbolic_zeros, tangents_out) - tangents_out = map(recast_to_float0, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) def post_process_custom_jvp_call(self, 
out_tracers, _): @@ -389,6 +374,9 @@ def post_process_custom_jvp_call(self, out_tracers, _): def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, symbolic_zeros): + # Local import to prevent an import cycle. + from jax._src.lax import lax + primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers) fwd_in = [(core.full_lower(p), type(t) is not Zero) for p, t in zip(primals_in, tangents_in)] @@ -396,14 +384,13 @@ def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, out_trees, res_and_primals_out = fwd.call_wrapped(*fwd_in) _, res_tree = out_trees() res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves]) - avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out] + avals_out = [raise_to_shaped(core.get_aval(x)).to_tangent_aval() for x in primals_out] # TODO(frostig,mattjj): avoid instantiating zeros when we don't have to! tangents_in = map(instantiate_zeros, tangents_in) tangents_out = custom_lin_p.bind( *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd, out_avals=avals_out, symbolic_zeros=symbolic_zeros) - tangents_out = map(jax._src.lax.lax.tie_p.bind, primals_out, tangents_out) - tangents_out = map(recast_to_float0, primals_out, tangents_out) + tangents_out = map(lax.tie_p.bind, primals_out, tangents_out) return map(partial(JVPTracer, self), primals_out, tangents_out) def post_process_custom_vjp_call(self, out_tracers, _): @@ -503,8 +490,8 @@ def linear_jvp(primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) if all(type(tangent) is Zero for tangent in tangents): if primitive.multiple_results: - return val_out, map(Zero.from_value, val_out) - return val_out, Zero.from_value(val_out) + return val_out, map(Zero.from_primal_value, val_out) + return val_out, Zero.from_primal_value(val_out) else: tangents = map(instantiate_zeros, tangents) return val_out, primitive.bind(*tangents, **params) @@ -531,7 +518,7 @@ def standard_jvp(jvprules, primitive, primals, tangents, **params): val_out = primitive.bind(*primals, **params) tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents) if rule is not None and type(t) is not Zero] - return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out)) + return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out)) def defjvp2(primitive, *jvprules): assert isinstance(primitive, Primitive) @@ -543,7 +530,7 @@ def standard_jvp2(jvprules, primitive, primals, tangents, **params): tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents) if rule is not None and type(t) is not Zero) tangents_out = list(tangents_out) - return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out)) + return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_primal_value(val_out)) def add_tangents(x, y): if type(x) is Zero: @@ -578,7 +565,7 @@ def defjvp_zero(primitive): def zero_jvp(primitive, primals, tangents, **params): r = primitive.bind(*primals, **params) - return r, Zero.from_value(r) + return r, Zero.from_primal_value(r) deflinear2(add_jaxvals_p, lambda t, *args: (t, t)) @@ -589,7 +576,7 @@ def instantiate_zeros(tangent): @lu.transformation_with_aux def traceable(in_tree, *primals_and_tangents): primals, tangents = tree_unflatten(in_tree, primals_and_tangents) - tangents = [Zero(get_aval(p).at_least_vspace()) if t is None else t + tangents = [Zero.from_primal_value(p) if t is None else t for p, t in zip(primals, 
tangents)] primals_out, tangents_out = yield (primals, tangents), {} tangents_out = [None if type(t) is Zero else t for t in tangents_out] @@ -693,7 +680,7 @@ def _jvp_jaxpr(jaxpr, nonzeros, instantiate): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate, transform_stack=False), nonzeros) - tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] + tangent_avals = [aval.to_tangent_aval() for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz] avals_in = list(it.chain(jaxpr.in_avals, tangent_avals)) jaxpr_out, avals_out, literals_out, () = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in) return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() @@ -703,7 +690,7 @@ def f_jvp_traceable(nonzeros, *primals_and_nztangents): num_primals = len(nonzeros) primals = list(primals_and_nztangents[:num_primals]) nonzero_tangents = iter(primals_and_nztangents[num_primals:]) - tangents = [next(nonzero_tangents) if nz else Zero.from_value(p) + tangents = [next(nonzero_tangents) if nz else Zero.from_primal_value(p) for p, nz in zip(primals, nonzeros)] primals_out, tangents_out = yield (primals, tangents), {} out_nonzeros = [type(t) is not Zero for t in tangents_out] diff --git a/jax/_src/interpreters/batching.py b/jax/_src/interpreters/batching.py index fbcd2c4a7a30..27cde6d31d35 100644 --- a/jax/_src/interpreters/batching.py +++ b/jax/_src/interpreters/batching.py @@ -88,6 +88,7 @@ def _jumble_flatten(jumble): elt_ty = jumble.aval.elt_ty.update(shape=tuple(new_shape)) aval = jumble.aval.replace(elt_ty=elt_ty) return (lengths, jumble.data), aval + def _jumble_unflatten(aval, x): lengths, data = x new_shape = [d.replace(lengths=lengths[d.lengths - 1]) @@ -251,7 +252,10 @@ def to_elt(trace: Trace, get_idx: GetIdx, x: Vmappable, spec: MapSpec) -> Elt: return (BatchTracer(trace, x, spec, source_info_util.current()) if spec is not None else x) else: - assert False + # TODO(mvoz): This is a terrible place to fall into if you pass + # a non jumble type in, make it clearer what went wrong. + assert False, f'Unexpected type in ELT? 
{type(x)}' + to_elt_handlers: dict[type, ToEltHandler] = {} def from_elt(trace: BatchTrace, axis_size: AxisSize, i: int, diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 31d281b88f11..af773365b12d 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -277,12 +277,7 @@ def ir_constant(val: Any) -> IrValues: raise TypeError(f"No constant handler for type: {type(val)}") def _numpy_array_constant(x: np.ndarray | np.generic) -> IrValues: - element_type = dtype_to_ir_type(x.dtype) - shape = x.shape - if x.dtype == np.bool_: - x = np.packbits(x, bitorder='little') # type: ignore - x = np.ascontiguousarray(x) - attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore + attr = _numpy_array_attribute(x) return hlo.constant(attr) @@ -344,6 +339,87 @@ def _token_constant_handler(val): return hlo.create_token() register_constant_handler(core.Token, _token_constant_handler) +# Attributes + +AttributeHandler = Callable[[Any], ir.Attribute] +_attribute_handlers: dict[type[Any], AttributeHandler] = {} + +def register_attribute_handler(type_: type[Any], handler_fun: AttributeHandler): + _attribute_handlers[type_] = handler_fun + +def get_attribute_handler(type_: type[Any]) -> AttributeHandler: + return _attribute_handlers[type_] + +def _numpy_scalar_attribute(val: Any) -> ir.Attribute: + mlir_type = dtype_to_ir_type(val.dtype) + if isinstance(mlir_type, ir.IntegerType): + return ir.IntegerAttr.get(mlir_type, val) + elif isinstance(mlir_type, ir.FloatType): + return ir.FloatAttr.get(mlir_type, val) + else: + raise TypeError(f"Unsupported scalar attribute type: {type(val)}") + +def _numpy_array_attribute(x: np.ndarray | np.generic) -> ir.Attribute: + element_type = dtype_to_ir_type(x.dtype) + shape = x.shape + if x.dtype == np.bool_: + x = np.packbits(x, bitorder='little') # type: ignore + x = np.ascontiguousarray(x) + return ir.DenseElementsAttr.get(x, type=element_type, shape=shape) # type: ignore + +def _numpy_array_attribute_handler(val: np.ndarray | np.generic) -> ir.Attribute: + if 0 in val.strides and val.size > 0: + raise ValueError( + "NumPy arrays with zero strides are not supported as MLIR attributes") + if val.dtype == dtypes.float0: + val = np.zeros(val.shape, dtype=np.bool_) + if dtypes.is_python_scalar(val) or np.isscalar(val): + return _numpy_scalar_attribute(val) + else: + return _numpy_array_attribute(val) + +register_attribute_handler(np.ndarray, _numpy_array_attribute_handler) + +for _scalar_type in [np.int8, np.int16, np.int32, np.int64, + np.uint8, np.uint16, np.uint32, np.uint64, + np.float16, np.float32, np.float64, + np.complex64, np.complex128, + np.bool_, np.longlong, dtypes.bfloat16]: + register_attribute_handler(_scalar_type, _numpy_array_attribute_handler) # type: ignore + +def _python_scalar_attribute_handler(dtype, val): + return _numpy_scalar_attribute(np.array(val, dtype)) + +for ptype, dtype in dtypes.python_scalar_dtypes.items(): + register_attribute_handler( + ptype, partial(_python_scalar_attribute_handler, dtype)) + +register_attribute_handler(str, ir.StringAttr.get) +register_attribute_handler(bytes, ir.StringAttr.get) + +def _dict_attribute_handler(val: dict[str, Any]) -> ir.Attribute: + return ir.DictAttr.get({k: ir_attribute(v) for k, v in val.items()}) + +register_attribute_handler(dict, _dict_attribute_handler) + +def _sequence_attribute_handler(val: Sequence[Any]) -> ir.Attribute: + return ir.ArrayAttr.get([ir_attribute(v) for v in val]) + +register_attribute_handler(list, 
_sequence_attribute_handler) +register_attribute_handler(tuple, _sequence_attribute_handler) + +def ir_attribute(val: Any) -> ir.Attribute: + """Convert a Python value to an MLIR attribute.""" + for t in type(val).__mro__: + handler = _attribute_handlers.get(t) + if handler: + out = handler(val) + assert isinstance(out, ir.Attribute), (type(val), out) + return out + if hasattr(val, '__jax_array__'): + return ir_attribute(val.__jax_array__()) + raise TypeError(f"No attribute handler defined for type: {type(val)}") + # Source locations def get_canonical_source_file(file_name: str, caches: TracebackCaches) -> str: @@ -407,10 +483,9 @@ def _traceback_to_location(ctx: ModuleContext, tb: xc.Traceback) -> ir.Location: return loc def _source_info_to_location( - ctx: ModuleContext, primitive: core.Primitive, params: dict[str, Any], + ctx: ModuleContext, primitive: core.Primitive, source_info: source_info_util.SourceInfo) -> ir.Location: - eqn_str = (f'{source_info.name_stack}/' - f'{core.str_eqn_compact(primitive, params)}') + eqn_str = f'{source_info.name_stack}/{primitive.name}' if config.include_full_tracebacks_in_locations.value: if source_info.traceback is None: loc = ir.Location.unknown() @@ -462,7 +537,7 @@ def dump_module_to_file(module: ir.Module, stage_name: str) -> str | None: sym_name = module.operation.attributes['sym_name'] module_name = ir.StringAttr(sym_name).value - name = f"jax_ir{id}_{_make_string_safe_for_filename(module_name)}_{stage_name}.mlir" + name = f"jax_ir{id:04d}_{_make_string_safe_for_filename(module_name)}_{stage_name}.mlir" out_dir = path.Path(out_dir_name) out_dir.mkdir(parents=True, exist_ok=True) @@ -604,6 +679,7 @@ class ModuleContext: host_callbacks: list[Any] # Keep state for the lowering of shape polymorphism shape_poly_state: ShapePolyLoweringState + all_default_mem_kind: bool # Cached primitive lowerings. cached_primitive_lowerings: dict[Any, func_dialect.FuncOp] @@ -633,7 +709,8 @@ def __init__( symbol_table: ir.SymbolTable | None = None, cached_primitive_lowerings: None | (dict[Any, func_dialect.FuncOp]) = None, traceback_caches: None | TracebackCaches = None, - shape_poly_state = None): + shape_poly_state = None, + all_default_mem_kind: bool = True): self.context = context or make_ir_context() self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) @@ -651,6 +728,7 @@ def __init__( self.host_callbacks = host_callbacks self.shape_poly_state = ( shape_poly_state or ShapePolyLoweringState((), tuple(platforms))) + self.all_default_mem_kind = all_default_mem_kind self.lowering_parameters = lowering_parameters @property @@ -873,7 +951,7 @@ class LoweringResult(NamedTuple): shape_poly_state: ShapePolyLoweringState -_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu"] +_platforms_with_donation = ["cpu", "cuda", "rocm", "tpu", "neuron"] def add_manual_axes(axis_ctx: sharding_impls.SPMDAxisContext, sharding, ndim): @@ -921,7 +999,7 @@ def _to_xla_layout(layout: DeviceLocalLayout | None | AutoLayout, return "auto" if aval is core.abstract_token: return "default" - return layout._to_xla_layout(aval.dtype) # type: ignore + return str(layout._to_xla_layout(aval.dtype)) # type: ignore def _get_mem_kind(s: JSharding | None) -> str | None: @@ -975,6 +1053,7 @@ def lower_jaxpr_to_module( result_memory_kinds = (map(_get_mem_kind, result_shardings) if result_shardings is not None else None) + # TODO(yashkatariya): Simplify the donation logic. 
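# --- Editor's aside (illustrative sketch, not part of the patch): the attribute-handler
# registry added above exposes `ir_attribute`, which recursively converts Python values
# (scalars, NumPy arrays, str/bytes, dicts, lists/tuples) into MLIR attributes. Assuming
# an active MLIR context, a hypothetical use looks like:
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
with mlir.make_ir_context():
    attr = mlir.ir_attribute({"tile_sizes": [8, 128], "mode": "fast"})
    assert isinstance(attr, ir.DictAttr)  # dict -> DictAttr, list -> ArrayAttr, str -> StringAttr
# --- End editor's aside.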
xla_donated_args = None platforms_with_donation = [p for p in platforms if p in _platforms_with_donation] @@ -988,14 +1067,12 @@ def lower_jaxpr_to_module( if num_partitions > 1 and ( result_shardings is None or all(s is None for s in result_shardings)): xla_donated_args = donated_args + donated_args = [False] * len(donated_args) if xla_donated_args is None: - input_output_aliases, donated_args = _set_up_aliases( + input_output_aliases, donated_args, xla_donated_args = _set_up_aliases( input_output_aliases, in_avals, out_avals, donated_args, - arg_memory_kinds, result_memory_kinds) - unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) - if unlowerable_effects: - raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') - if xla_donated_args is None and any(donated_args): + arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts) + if any(donated_args): unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d] msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation." if not platforms_with_donation: @@ -1003,15 +1080,13 @@ def lower_jaxpr_to_module( if unused_donations: warnings.warn("Some donated buffers were not usable:" f" {', '.join(unused_donations)}.\n{msg}") - - if xla_donated_args is not None: - assert input_output_aliases is None - if input_output_aliases is not None: - assert xla_donated_args is None - # Delete donated_args by default here, since it's not needed beyond this point del donated_args + unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects) + if unlowerable_effects: + raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}') + # HLO channels need to start at 1 channel_iter = itertools.count(1) # Create a keepalives list that will be mutated during the lowering. @@ -1034,20 +1109,22 @@ def lower_jaxpr_to_module( channel_iterator=channel_iter, host_callbacks=host_callbacks, lowering_parameters=lowering_parameters, - shape_poly_state=ShapePolyLoweringState(dim_vars, platforms)) + shape_poly_state=ShapePolyLoweringState(dim_vars, platforms), + all_default_mem_kind=all_default_mem_kind) with ctx.context, ir.Location.unknown(ctx.context): # Remove module name characters that XLA would alter. This ensures that # XLA computation preserves the module name. 
attrs = ctx.module.operation.attributes if config.use_shardy_partitioner.value: - assert (isinstance(axis_context, sharding_impls.ShardingContext) and - axis_context.mesh_shape is not None) - ctx.module.body.append( - dialects.sdy.MeshOp( - "mesh", - dialects.sdy.MeshAttr.get( - [dialects.sdy.MeshAxisAttr.get(name, size) - for name, size in axis_context.mesh_shape]))) + if (isinstance(axis_context, sharding_impls.ShardingContext) and + axis_context.mesh_shape is not None): + sdy_mesh_attr = dialects.sdy.MeshAttr.get( + [dialects.sdy.MeshAxisAttr.get(name, size) + for name, size in axis_context.mesh_shape]) + else: + sdy_mesh_attr = dialects.sdy.MeshAttr.get([]) + + ctx.module.body.append(dialects.sdy.MeshOp("mesh", sdy_mesh_attr)) module_name = _module_name_regex.sub("_", module_name) attrs["sym_name"] = ir.StringAttr.get(module_name) attrs["mhlo.num_replicas"] = i32_attr(num_replicas) @@ -1090,8 +1167,9 @@ def emit_diagnostic_info(d): ctx.shape_poly_state) -def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, - arg_memory_kinds, result_memory_kinds): +def _set_up_aliases(input_output_aliases, avals_in, avals_out, + donated_args, arg_memory_kinds, result_memory_kinds, + in_layouts, out_layouts): if input_output_aliases is None: input_output_aliases = [None] * len(avals_in) else: @@ -1120,6 +1198,7 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, if donated and aliased is None: donations[(aval, am)].append(i) + xla_donated_args = None out_donated_args = list(donated_args) for i, (aval, rm) in enumerate(zip(avals_out, result_memory_kinds)): # Only donate if memory kinds match. Relax this when the compiler can @@ -1127,10 +1206,23 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out, donated_args, key = (aval, rm) if donations.get(key, ()): input_id = donations[key].popleft() - input_output_aliases[input_id] = i out_donated_args[input_id] = False + # We can alias if XLA performs layout assignment because XLA will + # respect the aliases when assigning layouts. Its only for two + # mismatched explicitly assigned layouts that XLA will certainly fail. + if (in_layouts is None or + out_layouts is None or + in_layouts[input_id] == out_layouts[i] or + isinstance(in_layouts[input_id], AutoLayout) or + isinstance(out_layouts[i], AutoLayout)): + input_output_aliases[input_id] = i + else: + # Fallback to xla donation if layouts don't match. + if xla_donated_args is None: + xla_donated_args = [False] * len(avals_in) + xla_donated_args[input_id] = True - return input_output_aliases, out_donated_args + return input_output_aliases, out_donated_args, xla_donated_args Token = ir.Value token_type = hlo.TokenType.get @@ -1554,7 +1646,15 @@ def replicate_trailing_dims(ctx, val: ir.Value, aval) -> ir.Value: # For example: if the key.shape is (8, 2) and key_data(key).shape is (8, 2, 2), # then the sharding will be P(P.UNCONSTRAINED, P.UNCONSTRAINED, None). # The below custom call achieves the sharding like above example. 
- return wrap_with_sharding_op( + if config.use_shardy_partitioner.value: + physical_ndim = core.physical_aval(aval).ndim + s = sharding.SdyArraySharding( + mesh_name='mesh', + dimension_shardings=[sharding.SdyDimSharding(axes=[], is_closed=i >= aval.ndim) + for i in range(physical_ndim)]) + return wrap_with_sharding_op(ctx, val, aval, s) + else: + return wrap_with_sharding_op( ctx, val, aval, xc.HloSharding.replicate().to_proto(), unspecified_dims=set(range(aval.ndim))) @@ -1665,7 +1765,7 @@ def get_override_lowering_rule(primitive: core.Primitive) -> LoweringRule | None in_nodes = map(read, eqn.invars) source_info = eqn.source_info.replace( name_stack=name_stack + eqn.source_info.name_stack) - loc = _source_info_to_location(ctx, eqn.primitive, eqn.params, source_info) + loc = _source_info_to_location(ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: override_rule = get_override_lowering_rule(eqn.primitive) platform_rules: dict[str, LoweringRule] = {} @@ -1823,6 +1923,10 @@ def lower_per_platform(ctx: LoweringRuleContext, lambda o: wrap_compute_type_in_place(ctx, o.owner), filter(_is_not_block_argument, flatten_ir_values(output)), ) + map( + lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + flatten_ir_values(output), + ) return output assert len(platforms) > 1 and len(kept_rules) >= 2, (platforms, kept_rules) @@ -1864,6 +1968,10 @@ def lower_per_platform(ctx: LoweringRuleContext, lambda o: wrap_compute_type_in_place(ctx, o.owner), filter(_is_not_block_argument, out_nodes), ) + map( + lambda o: wrap_xla_metadata_in_place(ctx, o.owner), + out_nodes, + ) if inner_ctx.tokens_out is not None: assert len(ordered_effects) == len(inner_ctx.tokens_out) out_nodes = [inner_ctx.tokens_out.get(eff) @@ -2025,14 +2133,32 @@ def wrap_compute_type_in_place(ctx, op): op.operation.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(dict_attr) +def wrap_xla_metadata_in_place(ctx, op): + ctx_attributes = {} + existing_attributes = {} + if ctx.jaxpr_eqn_ctx is not None and ctx.jaxpr_eqn_ctx.xla_metadata: + for k, v in ctx.jaxpr_eqn_ctx.xla_metadata.items(): + ctx_attributes[k] = ir.StringAttr.get(str(v).lower()) + if isinstance(op, ir.Operation): + # combine with existing mhlo.frontend_attributes + op_attributes_dict = {attr.name: attr.attr for attr in op.attributes} + for k, attributes in op_attributes_dict.items(): + if k == "mhlo.frontend_attributes": + v_dict = {attr.name: attr.attr for attr in attributes} + for fa_key, fa_val in v_dict.items(): + existing_attributes[fa_key] = fa_val + op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + ctx_attributes | existing_attributes + ) + + def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, *, broadcast_dimensions) -> ir.Value: # broadcast_dimension[i] is the axis of the result where the axis i of # op is broadcast. 
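# --- Editor's aside on wrap_xla_metadata_in_place above (observation, not part of the
# patch): metadata values are stringified with str(v).lower() (so True becomes "true"),
# and because the merged dict is built as `ctx_attributes | existing_attributes`, a key
# already present in an op's mhlo.frontend_attributes wins over the equation-context value.
assert str(True).lower() == "true"
assert ({"k": "ctx"} | {"k": "existing"}) == {"k": "existing"}  # right operand takes precedence
# --- End editor's aside.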
# Lower a possibly-dynamic broadcast_in_dim if dtypes.issubdtype(aval_out.dtype, dtypes.extended): # type: ignore - elt_shape = aval_out.dtype._rules.physical_element_aval( # type: ignore - aval_out.dtype).shape # type: ignore + elt_shape = core.physical_element_aval(aval_out.dtype).shape # type: ignore trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] # type: ignore broadcast_dimensions = [*broadcast_dimensions, *trailing_dims] physical_aval_out = core.physical_aval(aval_out) @@ -2086,8 +2212,7 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices, limit_indices, strides) -> ir.Value: if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape trailing_zeros = [0] * len(elt_shape) trailing_ones = [1] * len(elt_shape) start_indices = (*start_indices, *trailing_zeros) @@ -2114,8 +2239,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, start_indices) -> ir.Value: x_aval = ctx.avals_in[0] if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape index_avals = ctx.avals_in[1:] dtype = dtypes.canonicalize_dtype( index_avals[0].dtype if index_avals else 'int64') # type: ignore @@ -2148,8 +2272,7 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, start_indices) -> ir.Value: if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape index_avals = ctx.avals_in[2:] dtype = dtypes.canonicalize_dtype( index_avals[0].dtype if index_avals else 'int64') # type: ignore diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 558f735f4403..6bc3cceb7ab7 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -35,6 +35,7 @@ from jax._src import profiler from jax._src import source_info_util from jax._src import compute_on +from jax._src import xla_metadata as xla_metadata_lib from jax._src.api_util import (flattened_fun_in_tree, flatten_fun_nokwargs, fun_sourceinfo) from jax._src.core import (Trace, Tracer, Jaxpr, Literal, get_aval, @@ -167,11 +168,6 @@ def new_instantiated_literal(self, val) -> JaxprTracer: def new_instantiated_const(self, val) -> JaxprTracer: aval = get_aval(val) - if isinstance(aval, DShapedArray): - shape = [self.new_instantiated_const(d) - if isinstance(d, Tracer) and d._trace.level < self.level else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(val)) def new_arg(self, pval: PartialVal) -> JaxprTracer: @@ -257,15 +253,9 @@ def process_call(self, primitive, f, tracers, params): # which were unknown to the first call (corresponding to in_avals). # Wrap f to perform the partial evaluation and plumb out aux data. 
- if not config.dynamic_shapes.value: - f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) - f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), - tuple(in_avals)) - else: - if f.in_type is None: - f = lu.annotate(f, tuple((a, True) for a in in_avals)) - f_, aux = trace_to_subjaxpr_nounits_dyn(f, self.main, tuple(in_knowns), - f.in_type, False) + f_ = trace_to_subjaxpr_nounits_fwd(f, self.main, False) + f_, aux = partial_eval_wrapper_nounits(f_, tuple(in_knowns), + tuple(in_avals)) # Adjust parameters (e.g. donated_invars) for the call to be evaluated now. const_params = update_params(params, in_knowns, 0) @@ -568,92 +558,6 @@ def partial_eval_wrapper_nounits( out_knowns, out_avals, out_consts = partition_pvals(out_pvals) yield (*out_consts, *res), (*maybe_fwds, out_knowns, out_avals, jaxpr, env) -@lu.transformation_with_aux -def trace_to_subjaxpr_nounits_dyn( - main: core.MainTrace, in_knowns: Sequence[bool], in_type: InputType, - instantiate: bool | Sequence[bool], - *in_consts: Any): - trace = main.with_cur_sublevel() - in_avals, which_explicit = unzip2(in_type) - - # To form input tracers from in_type, we need to first build ConstVar tracers - # for all axis sizes, so that we can then use those tracers in the shapes of - # avals for unknown inputs' tracers. We use ConstVar recipes for on-the-fly - # type agreement checking via get_referent. - in_consts_full: list[JaxprTracer | None] = [None] * len(in_type) - in_consts_iter, in_knowns_iter = iter(in_consts), iter(in_knowns) - for idx, (aval, explicit) in enumerate(in_type): - if explicit and next(in_knowns_iter): - constval = next(in_consts_iter) - if isinstance(aval, DShapedArray): - for i, d in enumerate(aval.shape): - if isinstance(d, DBIdx): - if in_consts_full[d.val] is None: - in_consts_full[d.val] = \ - JaxprTracer(trace, PartialVal.unknown(in_avals[d.val]), - ConstVar(constval.shape[i])) - assert core.same_referent(constval.shape[i], in_consts_full[d.val]) - shape = [in_consts_full[d.val] if type(d) is DBIdx else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - in_consts_full[idx] = JaxprTracer(trace, PartialVal.unknown(aval), - ConstVar(constval)) - # Check that we covered all axis sizes with ConstVar tracers. - for idx, (aval, explicit) in enumerate(in_type): - if not explicit: assert in_consts_full[idx] is not None - if isinstance(aval, DShapedArray): - assert all(type(d) is not DBIdx or in_consts_full[d.val] is not None - for d in aval.shape) - - # Next, build tracers for all unknown inputs, using the in_consts_full list - # for axis size tracers when necessary. - in_tracers = [] - in_knowns_iter = iter(in_knowns) - for aval, explicit in in_type: - if explicit and not next(in_knowns_iter): - if isinstance(aval, DShapedArray): - shape = [in_consts_full[d.val] if type(d) is DBIdx else d - for d in aval.shape] - aval = aval.update(shape=tuple(shape)) - tracer = JaxprTracer(trace, PartialVal.unknown(aval), LambdaBinding()) - in_tracers.append(tracer) - - # Merge in_consts and in_tracers and call wrapped fn with explicit arguments. - in_args = merge_lists(in_knowns, in_tracers, in_consts) - ans = yield in_args, {} - - # Instantiate outputs and build jaxpr. - if isinstance(instantiate, bool): - instantiate = [instantiate] * len(ans) - out_tracers = map(trace.full_raise, map(core.full_lower, ans)) - out_tracers = [trace.instantiate_const(trace.full_raise(t)) if inst else t - for inst, t in zip(instantiate, out_tracers)] - - # Collect known outputs. 
- out_knowns: list[bool] = [t.is_known() for t in out_tracers] - out_consts: list[Any] = [t.pval.get_known() for t in out_tracers - if t.is_known()] - - # Build the jaxpr. - out_tracers = [t for t in out_tracers if not t.is_known()] - jaxpr, res, env = tracers_to_jaxpr(in_tracers, out_tracers) - out_avals = [v.aval for v in jaxpr.outvars] - idx_map = {v: InDBIdx(i) - for i, v in enumerate(it.chain(jaxpr.constvars, jaxpr.invars))} - out_type = [(a.update(shape=tuple(idx_map.get(d, d) for d in a.shape)) # type: ignore - if type(a) is DShapedArray else a, True) for a in out_avals] - - # Which residuals are just forwarded inputs? Check obj id, then prune. - id_map = {id(c.recipe.val): i for i, c in enumerate(in_consts_full) # type: ignore - if c is not None} - fwds: list[int | None] = [id_map.get(id(c)) for c in res] - res = tuple(c for c, fwd in zip(res, fwds) if fwd is None) - - del main, in_consts, trace, in_consts_iter, in_knowns_iter, in_consts_full, \ - in_tracers, in_args, ans, out_tracers, out_avals - yield (*out_consts, *res), (fwds, out_knowns, tuple(out_type), jaxpr, env) - - custom_partial_eval_rules: dict[Primitive, Callable] = {} call_partial_eval_rules: dict[Primitive, Callable] = {} call_param_updaters: dict[Primitive, Callable] = {} @@ -898,8 +802,11 @@ def new_eqn_recipe(in_tracers: Sequence[JaxprTracer], assert ("donated_invars" in params and len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers] - ctx = ctx or JaxprEqnContext(compute_on.current_compute_type(), - config.threefry_partitionable.value) + ctx = ctx or JaxprEqnContext( + compute_on.current_compute_type(), + config.threefry_partitionable.value, + xla_metadata_lib.current_xla_metadata(), + ) return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers), out_avals, primitive, params, effects, source_info, ctx) @@ -1355,7 +1262,7 @@ def partial_eval_jaxpr_custom_rule_not_implemented( name: str, saveable: Callable[..., RematCases_], unks_in: Sequence[bool], inst_in: Sequence[bool], eqn: JaxprEqn) -> PartialEvalCustomResult: msg = (f'custom-policy remat rule not implemented for {name}, ' - 'open a feature request at https://github.com/google/jax/issues!') + 'open a feature request at https://github.com/jax-ml/jax/issues!') raise NotImplementedError(msg) @@ -1532,6 +1439,18 @@ def _prune_closed_jaxpr_outputs( def dce_jaxpr(jaxpr: Jaxpr, used_outputs: Sequence[bool], instantiate: bool | Sequence[bool] = False, ) -> tuple[Jaxpr, list[bool]]: + """Runs dead-code elimination on a given jaxpr. + + Args: + jaxpr: The jaxpr to DCE. + used_outputs: A list of bools indicating which outputs are used. + instantiate: A bool or a list of bools indicating which inputs should be + considered used, regardless of whether they are actually used in a jaxpr. + If a bool, the same value is used for all inputs. + + Returns: + A tuple of ``(new_jaxpr, used_inputs)``.
+ """ if type(instantiate) is bool: instantiate = (instantiate,) * len(jaxpr.invars) return _dce_jaxpr(jaxpr, tuple(used_outputs), tuple(instantiate)) @@ -1541,7 +1460,7 @@ def dce_jaxpr_consts(jaxpr: Jaxpr, used_outputs: Sequence[bool], instantiate: bool | Sequence[bool] = False, ) -> tuple[Jaxpr, list[bool], list[bool]]: jaxpr_ = convert_constvars_jaxpr(jaxpr) - new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs) + new_jaxpr, used_inputs_ = dce_jaxpr(jaxpr_, used_outputs, instantiate) used_consts, used_inputs = split_list(used_inputs_, [len(jaxpr.constvars)]) if sum(used_consts): new_jaxpr = convert_invars_to_constvars(new_jaxpr, sum(used_consts)) @@ -2046,11 +1965,9 @@ def process_primitive(self, primitive, tracers, params): def default_process_primitive(self, primitive, tracers, params): avals = [t.aval for t in tracers] out_avals, effects = primitive.abstract_eval(*avals, **params) - # == serve as a "not xor" here. - if not (isinstance(out_avals, (tuple,list)) == primitive.multiple_results): - raise ValueError(f"{primitive}.abstract_eval() method should return" - f" a tuple or a list if {primitive}.multiple_results" - " is true. Otherwise it shouldn't.") + if isinstance(out_avals, (tuple, list)) != primitive.multiple_results: + raise ValueError(f"{primitive}.abstract_eval() method should return " + f"a tuple or a list iff {primitive}.multiple_results.") out_avals = [out_avals] if not primitive.multiple_results else out_avals source_info = source_info_util.current() out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] @@ -2144,6 +2061,7 @@ def post_process_map(self, map_primitive, out_tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): in_avals = [t.aval for t in tracers] + in_tangent_avals = [t.to_tangent_aval() for t in in_avals] with core.new_sublevel(): fun_jaxpr, out_avals, consts, () = trace_to_subjaxpr_dynamic(fun, self.main, in_avals) closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ()) @@ -2152,7 +2070,7 @@ def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): @_memoize def jvp_jaxpr_thunk(*in_zeros): for store in jvp.stores: store and store.reset() - nz_tangent_avals, zero_avals = partition_list(in_zeros, in_avals) + nz_tangent_avals, zero_avals = partition_list(in_zeros, in_tangent_avals) jvp_, out_zeros = _jvp_jaxpr_zeros(jvp, in_zeros, tuple(zero_avals)) in_avals_ = (*in_avals, *nz_tangent_avals) jaxpr, _, out_consts, () = trace_to_subjaxpr_dynamic(jvp_, main_(), in_avals_) @@ -2770,7 +2688,7 @@ def call_padding_rule(prim, in_avals, out_avals, *args, call_jaxpr, **params): # TODO(mattjj): the following are deprecated; update callers to _nounits version -# See https://github.com/google/jax/pull/9498 +# See https://github.com/jax-ml/jax/pull/9498 @lu.transformation def trace_to_subjaxpr(main: core.MainTrace, instantiate: bool | Sequence[bool], pvals: Sequence[PartialVal]): @@ -2814,14 +2732,13 @@ def inline_jaxpr_into_trace( outvars = [Var('', v.aval) for v in eqn.outvars] src_ = (src if not eqn.source_info.name_stack else src.replace(name_stack=src.name_stack + eqn.source_info.name_stack)) - trace.frame.add_eqn(core.new_jaxpr_eqn(invars, outvars, eqn.primitive, - eqn.params, eqn.effects, src_)) + trace.frame.add_eqn(eqn.replace(invars, outvars, source_info=src_)) # type: ignore map(env.setdefault, eqn.outvars, outvars) tracer_env: dict[Var, Any] = dict(zip([*jaxpr.constvars, *jaxpr.invars], [*consts, *arg_tracers])) def new_tracer(atom): - tracer = 
DynamicJaxprTracer(trace, atom.aval, src) + tracer = tracer_env[atom] = DynamicJaxprTracer(trace, atom.aval, src) trace.frame.tracers.append(tracer) trace.frame.tracer_to_var[id(tracer)] = env[atom] return tracer diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 88297bd9204b..4c134f266da5 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -22,6 +22,7 @@ from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property +import functools import itertools as it import logging import math @@ -32,6 +33,7 @@ import jax +from jax._src import api from jax._src import api_util from jax._src import compiler from jax._src import config @@ -60,6 +62,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -87,6 +90,7 @@ class WeakRefList(list): logger = logging.getLogger(__name__) Index = Union[int, slice, tuple[Union[int, slice], ...]] +PyTreeDef = tree_util.PyTreeDef NoSharding = sharding_specs.NoSharding Chunked = sharding_specs.Chunked @@ -106,39 +110,68 @@ class WeakRefList(list): def identity(x): return x @profiler.annotate_function -def shard_args(shardings: Sequence[JSharding], args, canonicalize=True) -> Sequence[xc.ArrayImpl]: +def shard_args(shardings: Sequence[JSharding], layouts, args, + canonicalize=True) -> Sequence[xc.ArrayImpl]: # Fast path for one argument. if len(args) == 1: arg = args[0] if canonicalize: arg = xla.canonicalize_dtype(arg) - return shard_arg_handlers[type(arg)]([arg], shardings) + return shard_arg_handlers[type(arg)]([arg], shardings, layouts) - # type(arg) -> (indices, args, shardings) - batches = collections.defaultdict(lambda: ([], [], [])) # type: ignore - for i, (arg, sharding) in enumerate(safe_zip(args, shardings)): + # type(arg) -> (list[indices], list[args], list[shardings]) + batches = collections.defaultdict(lambda: ([], [], [], [])) # type: ignore + for i, (arg, sharding, layout) in enumerate(safe_zip(args, shardings, layouts)): if canonicalize: arg = xla.canonicalize_dtype(arg) batch = batches[type(arg)] batch[0].append(i) batch[1].append(arg) batch[2].append(sharding) + batch[3].append(layout) # Call `shard_arg_handlers` per batch and build a flat list of arrays returned # from each call in the same order as `args`. Since `batches` is grouped by # types, we cannot simply flatten the results and we have to use the original # indices to put each array back to its original position. 
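# --- Editor's aside (illustrative sketch, not part of the patch): shard_args now takes a
# per-argument `layouts` sequence parallel to `shardings` and `args`; passing None for an
# argument keeps the default layout and matches the previous behavior. A minimal
# single-device example, assuming the new signature:
import numpy as np
import jax
from jax._src.interpreters import pxla
x = np.arange(8, dtype=np.float32)
s = jax.sharding.SingleDeviceSharding(jax.devices()[0])
[arr] = pxla.shard_args([s], [None], [x])  # layouts=[None] -> default layout, as before
assert isinstance(arr, jax.Array)
# --- End editor's aside.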
results: list[jax.Array | None] = [None] * len(args) - for t, (indices, a, s) in batches.items(): - outs = shard_arg_handlers[t](a, s) + for t, (indices, a, s, l) in batches.items(): + outs = shard_arg_handlers[t](a, s, l) for i, out in safe_zip(indices, outs): results[i] = out - assert all(result is not None for result in results) return results -shard_arg_handlers: dict[Any, Callable[[Sequence[Any], Sequence[Any]], Sequence[Any]]] = {} +shard_arg_handlers: dict[ + Any, Callable[[Sequence[Any], Sequence[Any], Sequence[Any]], Sequence[Any]] +] = {} + + +@lru_cache(maxsize=2048) +def is_default_layout(curr_layout, sharding, aval): + if curr_layout is None or sharding is None or is_unspecified(sharding): + return True + if (aval is core.abstract_token or aval.dtype == dtypes.float0 or + dtypes.issubdtype(aval.dtype, dtypes.extended)): + return True + if isinstance(curr_layout, AutoLayout): + return False + d = sharding._device_assignment[0] + shard_shape = sharding.shard_shape(aval.shape) + try: + # TODO(yashkatariya): Replace this with normal `==` check once CPU supports + # int4. + return is_user_xla_layout_equal( + curr_layout, + DeviceLocalLayout.from_pjrt_layout( + d.client.get_default_layout(aval.dtype, shard_shape, d))) + except xe.XlaRuntimeError as e: + msg, *_ = e.args + if isinstance(msg, str) and msg.startswith("UNIMPLEMENTED"): + return True + else: + raise @lru_cache(maxsize=1024) @@ -146,34 +179,37 @@ def _get_replicated_slices(num_addressable_devices: int): return ((slice(None),),) * num_addressable_devices -def _masked_array_error(xs, shardings): +def _masked_array_error(xs, shardings, layouts): raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. " "Use arr.filled() to convert the value to a standard numpy array.") shard_arg_handlers[np.ma.MaskedArray] = _masked_array_error -def _shard_array(xs, shardings): +def _shard_np_array(xs, shardings, layouts): results = [] - for x, sharding in safe_zip(xs, shardings): + for x, sharding, layout in safe_zip(xs, shardings, layouts): devices = sharding._addressable_device_assignment if x.dtype == dtypes.float0: x = np.zeros(x.shape, dtype=np.dtype(bool)) aval = api_util.shaped_abstractify(x) - if sharding.is_fully_replicated: - shards = [x] * len(devices) + if layout is not None: + results.append(api.device_put(x, Layout(layout, sharding))) else: - indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) - shards = [x[i] for i in indices] - results.append(batched_device_put(aval, sharding, shards, devices)) + if sharding.is_fully_replicated: + shards = [x] * len(devices) + else: + indices = tuple(sharding.addressable_devices_indices_map(x.shape).values()) + shards = [x[i] for i in indices] + results.append(batched_device_put(aval, sharding, shards, devices)) return results for _t in array_types: - shard_arg_handlers[_t] = _shard_array + shard_arg_handlers[_t] = _shard_np_array -def _shard_darray(xs, shardings): - return shard_args(shardings, [x._data for x in xs]) +def _shard_darray(xs, shardings, layouts): + return shard_args(shardings, layouts, [x._data for x in xs]) shard_arg_handlers[core.DArray] = _shard_darray -def _shard_mutable_array(xs, shardings): - return shard_args(shardings, [x._buf for x in xs]) +def _shard_mutable_array(xs, shardings, layouts): + return shard_args(shardings, layouts, [x._buf for x in xs]) shard_arg_handlers[core.MutableArray] = _shard_mutable_array def batched_device_put(aval: core.ShapedArray, @@ -464,7 +500,7 @@ def process_map(self, 
map_primitive, fun, tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: msg = ("custom_jvp with symbolic_zeros=True not supported with eager pmap. " - "Please open an issue at https://github.com/google/jax/issues !") + "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros # always base main, can drop jvp in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) @@ -477,7 +513,7 @@ def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: msg = ("custom_vjp with symbolic_zeros=True not supported with eager pmap. " - "Please open an issue at https://github.com/google/jax/issues !") + "Please open an issue at https://github.com/jax-ml/jax/issues !") raise NotImplementedError(msg) del primitive, fwd, bwd, out_trees, symbolic_zeros # always base main, drop vjp in_vals, in_axes = unzip2((t.val, t.shard_axes) for t in tracers) @@ -931,6 +967,7 @@ def build_execute_fun(self): handle_outs = local_avals_to_results_handler(self.local_output_avals, self.output_shardings) handle_args = InputsHandler(self.input_shardings, + [None] * len(self.input_shardings), self.compiled.local_devices(), input_indices) execute_fun = ExecuteReplicated(self.compiled, "parallel computation", self.backend, handle_args, handle_outs, @@ -1109,12 +1146,15 @@ def _get_pmap_sharding(devices, specs): class InputsHandler: - __slots__ = ("handler", "local_devices", "in_shardings", "input_indices") + __slots__ = ("handler", "in_shardings", "in_layouts", "local_devices", + "input_indices") - def __init__(self, in_shardings, local_devices=None, input_indices=None): - self.handler = partial(shard_args, in_shardings) - self.local_devices = local_devices + def __init__(self, in_shardings, in_layouts, local_devices=None, + input_indices=None): + self.handler = partial(shard_args, in_shardings, in_layouts) self.in_shardings = in_shardings + self.in_layouts = in_layouts + self.local_devices = local_devices self.input_indices = input_indices def __call__(self, input_buffers): @@ -1122,8 +1162,9 @@ def __call__(self, input_buffers): def __str__(self): return ("InputsHandler(\n" - f"local_devices={self.local_devices},\n" f"in_shardings={self.in_shardings},\n" + f"in_layouts={self.in_layouts},\n" + f"local_devices={self.local_devices},\n" f"input_indices={self.input_indices})") @@ -1388,7 +1429,7 @@ def _hlo_shard(aval, axis_env, x, in_axis): return x elif isinstance(aval, core.ShapedArray): if dtypes.issubdtype(aval.dtype, dtypes.extended): - aval = aval.dtype._rules.physical_element_aval(aval.dtype) + aval = core.physical_element_aval(aval.dtype) dims = list(aval.shape) zero = mlir.ir_constant(np.zeros((), dtype=np.uint32)) idxs = [zero] * len(dims) @@ -1828,7 +1869,7 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "does not preserve sharded data representations and instead collects " "input and output arrays onto a single device. " "Consider removing the outer jit unless you know what you're doing. 
" - "See https://github.com/google/jax/issues/2926.") + "See https://github.com/jax-ml/jax/issues/2926.") if nreps > xb.device_count(backend): raise ValueError( @@ -1843,35 +1884,6 @@ def _raise_warnings_or_errors_for_jit_of_pmap( "extra data movement anyway, so maybe you don't want it after all).") -@lru_cache(maxsize=2048) -def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval - ) -> DeviceLocalLayout | None: - if is_unspecified_or_auto(sharding): - return None - # TODO(yashkatariya): Figure out how layouts work with extended dtypes. - if dtypes.issubdtype(aval.dtype, dtypes.extended): - return None - if not core.is_constant_shape(aval.shape): - return None - shard_shape = sharding.shard_shape(aval.shape) - d = sharding._device_assignment[0] - # If a backend doesn't implement `get_default_layout` return `None` to avoid - # cache misses. This can happen when you have `jit(f, in_shardings=s)`. On - # first call you pass it a sharded array with layout and on second call you - # pass a numpy array. The layouts should be the same to get cache hits. - try: - al = DeviceLocalLayout.from_pjrt_layout( - d.client.get_default_layout(aval.dtype, shard_shape, d)) - except: - return None - # argument does not have `.layout` property. ShapedArray, numpy array, etc - # are some examples. - if arg_layout is None: - return al if jit_in_layout is None else arg_layout # arg_layout is None - # If arg has a `.layout` property, then return device_local_layout as is. - return arg_layout.device_local_layout - - @weakref_lru_cache def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, semantic_in_shardings, semantic_out_shardings, @@ -1912,7 +1924,7 @@ def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend, out_mlir_shardings = map(_to_logical_sharding, global_out_avals, out_shardings) replicated_args = [False] * len(global_in_avals) axis_ctx = sharding_impls.ShardingContext(num_devices, device_assignment, - mesh_shape=mesh_shape_tuple) + mesh_shape_tuple) num_partitions = num_devices else: # This path is triggered for `jit(pmap)` cases. @@ -1989,7 +2001,7 @@ def are_all_shardings_default_mem_kind(da_object, shardings): except: return True for i in shardings: - if is_unspecified_or_auto(i): + if is_unspecified_or_auto(i) or i.memory_kind is None: continue if i.memory_kind != default_mem_kind: return False @@ -2172,8 +2184,6 @@ def lower_sharding_computation( devices_from_context) platforms = lowering_platforms or (backend.platform,) - # TODO(yashkatariya): Enable this when offload APIs are stable. - # transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) committed = bool( devices_from_context or @@ -2184,34 +2194,39 @@ def lower_sharding_computation( da_object = _create_da_object(tuple(device_assignment)) + transfer_mem_kind_in_jaxpr = list(jaxpr_transfer_mem_kinds(jaxpr)) all_default_mem_kind = are_all_shardings_default_mem_kind( da_object, it.chain(in_shardings, out_shardings, - [js for js, _ in unique_intermediate_shardings])) + [js for js, _ in unique_intermediate_shardings], + transfer_mem_kind_in_jaxpr)) # pytype: disable=wrong-arg-types - # TODO(yashkatariya): Remove this when XLA can propagate memory kinds or when - # JAX puts memory kinds in the types of jaxpr. 
- if not all_default_mem_kind: + if all_default_mem_kind: + propagated_out_mem_kinds = (None,) * len(global_out_avals) + else: propagated_out_mem_kinds = get_out_memory_kinds_via_propagation( closed_jaxpr, in_shardings) - else: - propagated_out_mem_kinds = (None,) * len(global_out_avals) # 2. Build up the HLO semantic_in_shardings = SemanticallyEqualShardings( in_shardings, global_in_avals) # type: ignore semantic_out_shardings = SemanticallyEqualShardings( out_shardings, global_out_avals) # type: ignore + prim_requires_devices = dispatch.jaxpr_has_prim_requiring_devices(jaxpr) - # TODO(yashkatariya): Initialize with context_mesh here? mesh_shape_tuple = None - for sharding in it.chain( - in_shardings, out_shardings, - [js for js, _ in unique_intermediate_shardings]): - if isinstance(sharding, sharding_impls.NamedSharding): - mesh_shape_tuple = sharding.mesh.shape_tuple - break + if config.use_shardy_partitioner.value or prim_requires_devices: + for sharding in it.chain(in_shardings, out_shardings, + [js for js, _ in unique_intermediate_shardings]): + if isinstance(sharding, (sharding_impls.NamedSharding, sharding_impls.AUTO)): + if (mesh_shape_tuple is not None and + mesh_shape_tuple != sharding.mesh.shape_tuple): + raise ValueError( + "mesh should be the same across the entire program. Got mesh" + f" shape for one sharding {mesh_shape_tuple} and" + f" {sharding.mesh.shape_tuple} for another") + mesh_shape_tuple = sharding.mesh.shape_tuple (module, keepalive, host_callbacks, unordered_effects, ordered_effects, nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo( @@ -2253,6 +2268,8 @@ def lower_sharding_computation( out_layouts=out_layouts, pmap_nreps=nreps, shape_poly_state=shape_poly_state, + # TODO(yashkatariya): Remove `all_default_mem_kind` after + # MemoryDescription works in OSS. all_default_mem_kind=all_default_mem_kind, all_args_info=all_args_info, pgle_profiler=pgle_profiler, @@ -2322,18 +2339,17 @@ def get_out_shardings_from_executable( ) -> Sequence[sharding_impls.GSPMDSharding] | None: from jax._src import pjit - if config.enable_memories.value: - if all_default_mem_kind: - omk = [None] * num_out_avals - else: - try: - omk = xla_executable.get_output_memory_kinds()[0] - if num_ordered_effects > 0: - omk = omk[num_ordered_effects:] - except: - omk = [None] * num_out_avals - else: + # TODO(yashkatariya): Remove `all_default_mem_kind` branch after + # MemoryDescription works in OSS. 
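Note on the mesh_shape_tuple change above: the mesh shape is now only derived when Shardy or a devices-requiring primitive needs it, and every mesh-carrying sharding must agree. A duck-typed sketch of that single-mesh validation (the sharding objects stand in for NamedSharding / AUTO):

def single_mesh_shape(shardings):
    # Walk every sharding that carries a mesh and require a consistent
    # shape_tuple; mirrors the error raised above.
    mesh_shape_tuple = None
    for s in shardings:
        mesh = getattr(s, "mesh", None)
        if mesh is None:
            continue
        if mesh_shape_tuple is not None and mesh_shape_tuple != mesh.shape_tuple:
            raise ValueError(
                "mesh should be the same across the entire program. Got mesh "
                f"shape for one sharding {mesh_shape_tuple} and "
                f"{mesh.shape_tuple} for another")
        mesh_shape_tuple = mesh.shape_tuple
    return mesh_shape_tuple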
+ if all_default_mem_kind: omk = [None] * num_out_avals + else: + try: + omk = xla_executable.get_output_memory_kinds()[0] + if num_ordered_effects > 0: + omk = omk[num_ordered_effects:] + except: + omk = [None] * num_out_avals assert len(omk) == num_out_avals, (len(omk), num_out_avals) @@ -2422,6 +2438,7 @@ def _register_out_sharding_handler( def _gspmd_to_named_sharding( out_s: sharding_impls.GSPMDSharding, orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding: + assert isinstance(orig_in_s.mesh, mesh_lib.Mesh) return sharding_impls._gspmd_to_named_sharding_via_mesh(out_s, orig_in_s.mesh) _register_out_sharding_handler( @@ -2500,17 +2517,11 @@ def maybe_recover_user_shardings( def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout, xl: DeviceLocalLayout) -> bool: - if isinstance(ul, DeviceLocalLayout) and ul._tiling is None: + if isinstance(ul, DeviceLocalLayout) and not ul._tiling: return ul.major_to_minor == xl.major_to_minor else: return ul == xl -def _check_user_xla_layout(ul, xl, what: str): - if not is_user_xla_layout_equal(ul, xl): - raise AssertionError( - f"Unexpected XLA layout override: (XLA) {xl} != {ul} " - f"(User {what} layout)") - def _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, num_ordered_effects @@ -2526,19 +2537,23 @@ def _get_layouts_from_executable( out_layouts_xla = out_layouts_xla[num_ordered_effects:] new_in_layouts = [] - for x, i in safe_zip(in_layouts_xla, in_layouts): + for x, l in safe_zip(in_layouts_xla, in_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(i, DeviceLocalLayout): - _check_user_xla_layout(i, x, "input") + if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {l} " + f"(User input layout)") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. new_in_layouts.append(x) new_out_layouts = [] - for x, o in safe_zip(out_layouts_xla, out_layouts): + for x, l in safe_zip(out_layouts_xla, out_layouts): x = DeviceLocalLayout.from_pjrt_layout(x) - if isinstance(o, DeviceLocalLayout): - _check_user_xla_layout(o, x, "output") + if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x): + raise AssertionError( + f"Unexpected XLA layout override: (XLA) {x} != {l} " + f"(User output layout)") # Always append the XLA layout because it has the full information # (tiling, etc) even if the user layout does not specify tiling. 
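Note on is_user_xla_layout_equal above: an empty tiling is now treated the same as an unspecified one (not ul._tiling instead of ul._tiling is None), in which case only the dimension order is compared against the layout XLA chose. A stand-alone sketch with a hypothetical SimpleLayout in place of DeviceLocalLayout:

from dataclasses import dataclass

@dataclass(frozen=True)
class SimpleLayout:            # hypothetical stand-in for DeviceLocalLayout
    major_to_minor: tuple
    tiling: tuple | None = None

def user_xla_layout_equal(user: SimpleLayout, xla: SimpleLayout) -> bool:
    if not user.tiling:        # None and () both mean "no tiling requested"
        return user.major_to_minor == xla.major_to_minor
    return user == xla

# A user layout without tiling matches an XLA layout that adds tiling info.
assert user_xla_layout_equal(SimpleLayout((1, 0)), SimpleLayout((1, 0), ((8, 128),)))
# With explicit tiling the layouts must match exactly.
user = SimpleLayout((1, 0), ((8,),))
xla = SimpleLayout((1, 0), ((8, 128),))
assert not user_xla_layout_equal(user, xla)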
new_out_layouts.append(x) @@ -2731,13 +2746,14 @@ class UnloadedMeshExecutable: kept_var_idx: set[int] mut: MutationData | None auto_spmd_lowering: bool - in_layouts: Sequence[DeviceLocalLayout | None] - out_layouts: Sequence[DeviceLocalLayout | None] + xla_in_layouts: Sequence[DeviceLocalLayout | None] + dispatch_in_layouts: Sequence[DeviceLocalLayout | None] + xla_out_layouts: Sequence[DeviceLocalLayout | None] all_args_info: AllArgsInfo | None pgle_profiler: profiler.PGLEProfiler | None def build_unsafe_call(self): - handle_args = InputsHandler(self.input_shardings) + handle_args = InputsHandler(self.input_shardings, self.dispatch_in_layouts) handle_outs = global_avals_to_results_handler( self.output_avals, self.output_shardings, self.committed) @@ -2753,8 +2769,8 @@ def load(self) -> MeshExecutable: self.input_avals, self.output_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, - self.in_layouts, self.out_layouts, - self.all_args_info, self) + self.xla_in_layouts, self.dispatch_in_layouts, + self.xla_out_layouts, self.all_args_info, self) @staticmethod def from_hlo(name: str, @@ -2837,8 +2853,18 @@ def from_hlo(name: str, in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap( xla_executable.local_devices(), len(in_shardings), len(out_shardings)) - in_layouts, out_layouts = _get_layouts_from_executable( + # xla_in_layouts are all either None or DeviceLocalLayout. Even default + # layout are concrete layouts and they are used in `compiled.input_layouts` + # to return concrete layouts to users. + # `dispatch_in_layouts` replaces default layouts with `None` to simplify + # dispatch logic downstream. + xla_in_layouts, xla_out_layouts = _get_layouts_from_executable( xla_executable, in_layouts, out_layouts, len(ordered_effects)) + del in_layouts, out_layouts + dispatch_in_layouts = [ + None if is_default_layout(l, s, a) else l + for l, s, a, in safe_zip(xla_in_layouts, in_shardings, global_in_avals) + ] out_shardings = maybe_recover_user_shardings( in_shardings, out_shardings, global_in_avals, global_out_avals, @@ -2863,8 +2889,9 @@ def from_hlo(name: str, kept_var_idx=kept_var_idx, mut=mut, auto_spmd_lowering=auto_spmd_lowering, - in_layouts=in_layouts, - out_layouts=out_layouts, + xla_in_layouts=xla_in_layouts, + dispatch_in_layouts=dispatch_in_layouts, + xla_out_layouts=xla_out_layouts, all_args_info=all_args_info, pgle_profiler=pgle_profiler).load() @@ -2877,9 +2904,35 @@ class MeshExecutableFastpathData(NamedTuple): out_avals: Sequence[ShapedArray] out_committed: Sequence[bool] kept_var_bitvec: Iterable[bool] - # TODO(yashkatariya): Remove once minimum jaxlib version is 0.4.24 - arg_handler_devices: Sequence[xc.Device] - arg_handler_indices: Sequence[tuple[Index | None, ...]] + in_device_local_layouts: Sequence[DeviceLocalLayout | None] + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class JitGlobalCppCacheKeys: + donate_argnums: tuple[int, ...] | None = None + donate_argnames: tuple[str, ...] | None = None + device: xc.Device | None = None + backend: str | None = None + in_shardings_treedef: PyTreeDef | None = None + in_shardings_leaves: tuple[Any, ...] | None = None + out_shardings_treedef: PyTreeDef | None = None + out_shardings_leaves: tuple[Any, ...] | None = None + in_layouts_treedef: PyTreeDef | None = None + in_layouts_leaves: tuple[Any, ...] | None = None + out_layouts_treedef: PyTreeDef | None = None + out_layouts_leaves: tuple[Any, ...] 
| None = None + use_resource_env: bool = False + + @functools.cached_property + def contains_explicit_attributes(self): + return (self.donate_argnums is not None or + self.donate_argnames is not None or + self.device is not None or + self.backend is not None or + any(not is_unspecified(i) for i in self.in_shardings_leaves) or + any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(i is not None for i in self.in_layouts_leaves) or + any(o is not None for o in self.out_layouts_leaves)) def reflatten_outputs_for_dispatch(out_tree, out_flat): @@ -2894,13 +2947,13 @@ class MeshExecutable(stages.XlaExecutable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", - "_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info", - "_unloaded_executable", + "_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts", + "_xla_out_layouts", "_all_args_info", "_unloaded_executable", ] def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, - in_layouts, out_layouts, + xla_in_layouts, dispatch_in_layouts, xla_out_layouts, all_args_info: AllArgsInfo | None = None, unloaded_executable=None): self.xla_executable = xla_executable @@ -2914,8 +2967,9 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, self._out_shardings = out_shardings self._auto_spmd_lowering = auto_spmd_lowering self._kept_var_idx = kept_var_idx - self._in_layouts = in_layouts - self._out_layouts = out_layouts + self._xla_in_layouts = xla_in_layouts + self._dispatch_in_layouts = dispatch_in_layouts + self._xla_out_layouts = xla_out_layouts self._all_args_info = all_args_info self._unloaded_executable = unloaded_executable @@ -2943,9 +2997,8 @@ def call(self, *args): all_arg_avals = map(xla.abstractify, kept_args) check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info) - # Check the GDA sharding and the input sharding. 
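Note on JitGlobalCppCacheKeys above: it combines a frozen, keyword-only dataclass with functools.cached_property so the derived contains_explicit_attributes flag is computed once per instance. A trimmed sketch of that pattern with illustrative fields:

import dataclasses
import functools

@dataclasses.dataclass(frozen=True, kw_only=True)
class CacheKeys:
    device: str | None = None
    backend: str | None = None
    in_layouts_leaves: tuple = ()

    @functools.cached_property
    def contains_explicit_attributes(self) -> bool:
        # cached_property stores the result directly in the instance __dict__,
        # so it composes with frozen=True (which only blocks __setattr__).
        return (self.device is not None or
                self.backend is not None or
                any(l is not None for l in self.in_layouts_leaves))

keys = CacheKeys(backend="cpu")
assert keys.contains_explicit_attributes
assert not CacheKeys().contains_explicit_attributes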
check_array_xla_sharding_layout_match( - args_after_dce, self._in_shardings, self._in_layouts, debug_info, + args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info, self._kept_var_idx) return self.unsafe_call(*args) # pylint: disable=not-callable @@ -2957,11 +3010,11 @@ def output_shardings(self) -> Sequence[JSharding]: def input_layouts(self): return [Layout(l, s) - for l, s in safe_zip(self._in_layouts, self._in_shardings)] + for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)] def output_layouts(self): return [Layout(l, s) - for l, s in safe_zip(self._out_layouts, self._out_shardings)] + for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)] def create_cpp_call(self, no_kwargs, in_tree, out_tree): if not (isinstance(self.unsafe_call, ExecuteReplicated) and @@ -2990,15 +3043,22 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = MeshExecutableFastpathData( self.xla_executable, out_tree_dispatch, in_shardings, self._out_shardings, out_avals, out_committed, kept_var_bitvec, - self.unsafe_call.in_handler.local_devices, - self.unsafe_call.in_handler.input_indices) + self._dispatch_in_layouts) else: fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, lambda x, s: shard_args([s], [x])[0]) + if xla_extension_version >= 286: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) + else: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, cc_shard_arg) + +def cc_shard_arg(x, sharding, layout): + return shard_args([sharding], [layout], [x])[0] def check_arg_avals_for_call(ref_avals, arg_avals, diff --git a/jax/_src/lax/ann.py b/jax/_src/lax/ann.py index f2dbd8d4fa0e..0e037ec774b5 100644 --- a/jax/_src/lax/ann.py +++ b/jax/_src/lax/ann.py @@ -373,7 +373,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, reduction_input_size_override, aggregate_to_topk) if type(tangent) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: arg_shape = arg_out.shape rank = len(arg_shape) @@ -385,7 +385,7 @@ def _approx_top_k_jvp(primals, tangents, *, k, reduction_dimension, idx = tuple( arg_out if i == reduction_dimension else iotas[i] for i in range(rank)) tangent_out = tangent[idx] - return (val_out, arg_out), (tangent_out, ad_util.Zero.from_value(arg_out)) + return (val_out, arg_out), (tangent_out, ad_util.Zero.from_primal_value(arg_out)) approx_top_k_p = core.Primitive('approx_top_k') diff --git a/jax/_src/lax/control_flow/__init__.py b/jax/_src/lax/control_flow/__init__.py index 05dcade84999..5e6fa86f706e 100644 --- a/jax/_src/lax/control_flow/__init__.py +++ b/jax/_src/lax/control_flow/__init__.py @@ -33,4 +33,4 @@ _initial_style_jaxprs_with_common_consts, _check_tree_and_avals) # TODO(mattjj): fix dependent library which expects optimization_barrier_p here -from jax._src.ad_checkpoint import optimization_barrier_p +from jax._src.lax.lax import optimization_barrier_p diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index 8161606801c2..d3065d0f96d7 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -23,7 +23,6 @@ import operator from typing import Any, TypeVar -import jax from 
jax.tree_util import tree_flatten, tree_unflatten from jax._src import ad_util from jax._src import config @@ -275,12 +274,8 @@ def cond(pred, true_fun, false_fun, *operands): num_consts = len(consts) out_ = iter(out) - def _cast_to_array(x): - _copy = isinstance(x, np.bool_) - return jax.numpy.asarray(x, copy=_copy) - out = [ - next(out_) if fwd is None else _cast_to_array(ops[fwd - num_consts]) + next(out_) if fwd is None else lax.asarray(ops[fwd - num_consts]) for fwd in in_fwd ] assert next(out_, None) is None @@ -394,7 +389,7 @@ def _cond_batching_rule(spmd_axis_name, axis_size, axis_name, main_type, args, branch_outs = [] for i, jaxpr in enumerate(branches_batched): # Perform a select on the inputs for safety of reverse-mode autodiff; see - # https://github.com/google/jax/issues/1052 + # https://github.com/jax-ml/jax/issues/1052 predicate = lax.eq(index, lax._const(index, i)) ops_ = [_bcast_select(predicate, x, lax.stop_gradient(x)) for x in ops] branch_outs.append(core.jaxpr_as_fun(jaxpr)(*ops_)) @@ -439,7 +434,7 @@ def _cond_jvp(primals, tangents, branches): out = cond_p.bind(index, *ops, *ops_dot, branches=branches_jvp) out_primals, out_tangents = split_list(out, [len(out_nz)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, out_nz)] return out_primals, out_tangents diff --git a/jax/_src/lax/control_flow/for_loop.py b/jax/_src/lax/control_flow/for_loop.py index 15249e531144..21b522b3d8bb 100644 --- a/jax/_src/lax/control_flow/for_loop.py +++ b/jax/_src/lax/control_flow/for_loop.py @@ -20,7 +20,6 @@ import operator from typing import Any, Generic, TypeVar -import jax.numpy as jnp from jax import lax from jax.api_util import flatten_fun_nokwargs from jax._src.interpreters import ad @@ -46,6 +45,7 @@ split_list, split_dict, weakref_lru_cache) from jax._src.lax.control_flow import loops from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr +import numpy as np ## JAX utilities @@ -132,7 +132,7 @@ def wrapped_body(i, refs): nsteps, = nsteps flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64)) + idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) if out_tree != tree_structure(None): @@ -202,7 +202,7 @@ def _create_jaxpr(init): return jaxpr, out_tree jaxpr, out_tree = _create_jaxpr(init) _, ys_avals = tree_unflatten(out_tree, jaxpr.out_avals) - ys = tree_map(lambda aval: jnp.zeros([length, *aval.shape], aval.dtype), + ys = tree_map(lambda aval: lax.full([length, *aval.shape], 0, aval.dtype), ys_avals) def for_body(i, refs): carry_refs, xs_refs, ys_refs = refs @@ -251,7 +251,7 @@ def body(i, state): def _for_impl_unrolled(body, nsteps, unroll, *args): remainder = nsteps % unroll - i = jnp.astype(0, dtypes.canonicalize_dtype(jnp.int64)) + i = lax.full((), 0, dtypes.canonicalize_dtype(np.int64)) state = list(args) for _ in range(remainder): @@ -340,7 +340,7 @@ def _for_jvp(primals, tangents, *, jaxpr, nsteps, reverse, which_linear, # into outputs as well. We don't care about these in AD so we throw them out. 
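Note on the jnp-to-lax substitutions in this and the surrounding control-flow hunks: they swap jax.numpy helpers for lax primitives so these modules no longer import jax.numpy. A quick check with the public APIs that the replacement spellings agree:

import jax.numpy as jnp
import numpy as np
from jax import lax

# lax.full is the primitive spelling of jnp.zeros with an explicit dtype.
a = jnp.zeros([4, 3], np.float32)
b = lax.full([4, 3], 0, np.float32)
assert a.shape == b.shape and a.dtype == b.dtype and bool((a == b).all())

# lax.convert_element_type is the primitive spelling of jnp.astype for the
# loop index.
i = lax.convert_element_type(7, np.int32)
assert i.dtype == np.int32 and int(i) == 7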
out_primals, out_tangents = split_list(out_flat, [len(primals)]) out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[for_p] = _for_jvp @@ -748,7 +748,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): """ flat_state, state_tree = tree_flatten(init_state) state_avals = map(state_utils.val_to_ref_aval, flat_state) - idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(jnp.int64)) + idx_aval = core.ShapedArray((), dtypes.canonicalize_dtype(np.int64)) jaxpr, consts, out_tree = _trace_to_jaxpr_with_refs( body, state_tree, [idx_aval, *state_avals]) if out_tree != tree_structure(None): @@ -756,7 +756,7 @@ def discharged_for_loop(nsteps, body, init_state, *, reverse: bool = False): discharged_jaxpr, discharged_consts = discharge_state(jaxpr, consts) def fori_body(i, carry): - i = jnp.astype(i, dtypes.canonicalize_dtype(jnp.int64)) + i = lax.convert_element_type(i, dtypes.canonicalize_dtype(np.int64)) if reverse: i = nsteps - i - 1 out_flat = core.eval_jaxpr(discharged_jaxpr, discharged_consts, diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 24029b92873e..7a9596bf2c0d 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -22,7 +22,6 @@ from typing import Any, TypeVar import weakref -import jax from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import api @@ -42,6 +41,7 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import pxla +from jax._src import sharding_impls as sharding from jax._src.interpreters import xla from jax._src.lax import lax from jax._src.lax import slicing @@ -50,9 +50,9 @@ _abstractify, _avals_short, _initial_style_jaxpr, _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros, _typecheck_param) +from jax._src.lax.other import logaddexp from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy.ufuncs import logaddexp from jax._src.state import discharge as state_discharge from jax._src.traceback_util import api_boundary from jax._src.tree_util import equality_errors @@ -67,6 +67,7 @@ unzip2, weakref_lru_cache, ) +from jax._src import xla_bridge as xb from jax.tree_util import ( keystr, tree_flatten, @@ -85,6 +86,9 @@ ### Helper functions +def _stack(arrs: Sequence[Array], axis: int=0) -> Array: + return lax.concatenate([lax.expand_dims(arr, (axis,)) for arr in arrs], dimension=axis) + def _promote_weak_typed_inputs(in_vals, in_avals, out_avals): """Promote weakly-typed in_vals to be compatible with out_avals. @@ -224,7 +228,11 @@ def scan(f, init, xs, length=None): if not hasattr(x, 'shape')))) from err if length is not None: - length = int(length) + try: + length = int(length) + except core.ConcretizationTypeError as err: + msg = 'The `length` argument to `scan` expects a concrete `int` value.' 
+ raise core.ConcretizationTypeError(length, msg) from None # type: ignore[arg-type] if not all(length == l for l in lengths): msg = ("scan got `length` argument of {} which disagrees with " "leading axis sizes {}.") @@ -250,7 +258,7 @@ def scan(f, init, xs, length=None): xs_slice = [slicing.index_in_dim(x, i, keepdims=False) for x in xs_flat] carry, y = f(carry, tree_unflatten(xs_tree, xs_slice)) ys.append(y) - stack = lambda *ys: jax.numpy.stack(ys) + stack = lambda *ys: _stack(ys) stacked_y = tree_map(stack, *maybe_reversed(ys)) return carry, stacked_y @@ -268,7 +276,8 @@ def _create_jaxpr(init): if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." raise TypeError(msg.format(tree_unflatten(out_tree, jaxpr.out_avals))) - carry_avals_out = jaxpr.out_avals[:out_tree_children[0].num_leaves] + _, carry_avals_out, _ = split_list( + jaxpr.out_avals, [len(attrs_tracked), out_tree_children[0].num_leaves]) return (init_flat, carry_avals, carry_avals_out, init_tree, in_flat, jaxpr, consts, out_tree, out_tree_children, attrs_tracked) @@ -291,6 +300,10 @@ def _create_jaxpr(init): raise NotImplementedError( f'Effects not supported in `scan`: {disallowed_effects}') + unroll = core.concrete_or_error( + None, unroll, + "The `unroll` argument to `scan` expects a concrete `int` or `bool` " + "value.") if isinstance(unroll, bool): unroll = max(length, 1) if unroll else 1 if unroll < 1: @@ -440,11 +453,11 @@ def inner(n, carry, xs): carry, y = split_list(carry_y, [num_carry]) ys.append(y) ys = list(reversed(ys)) if reverse else ys - return carry, _map(jax.numpy.stack, zip(*ys)) + return carry, _map(_stack, zip(*ys)) if num_trips: i = lax._const(num_trips, 0) - _, carry, yss = jax.lax.while_loop(cond_fun, body_fun, (i, carry, yss)) + _, carry, yss = while_loop(cond_fun, body_fun, (i, carry, yss)) if unroll != 1: ys = [lax.reshape(ys, (num_trips * unroll, *ys.shape[2:])) for ys in yss] else: @@ -534,7 +547,7 @@ def _scan_jvp(primals, tangents, reverse, length, jaxpr, num_consts, num_carry, carry, carry_dot, ys, ys_dot = split_list(out_flat, [num_carry, len(init_dot), num_ys]) primals_out = carry + ys tangents_out_iter = iter(carry_dot + ys_dot) - tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_value(p) + tangents_out = [next(tangents_out_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(primals_out, nonzeros_out)] return primals_out, tangents_out @@ -685,9 +698,9 @@ def _scan_partial_eval(trace, *tracers, reverse, length, num_consts, num_carry, def _maybe_put(x): if isinstance(x, np.ndarray): aval = shaped_abstractify(x) - s = jax.sharding.SingleDeviceSharding(jax.local_devices(backend='cpu')[0]) + s = sharding.SingleDeviceSharding(xb.local_devices(backend='cpu')[0]) result_handler = pxla.global_aval_to_result_handler(aval, s, False) - return result_handler(pxla.shard_args([s], [x])) + return result_handler(pxla.shard_args([s], [None], [x])) else: return x @@ -702,7 +715,7 @@ def _scan_transpose(cts, *args, reverse, length, num_consts, if xs_lin != [True] * (len(xs_lin) - num_eres) + [False] * num_eres: raise NotImplementedError if not all(init_lin): - pass # TODO(mattjj): error check https://github.com/google/jax/issues/1963 + pass # TODO(mattjj): error check https://github.com/jax-ml/jax/issues/1963 consts, _, xs = split_list(args, [num_consts, num_carry]) ires, _ = split_list(consts, [num_ires]) @@ -1156,7 +1169,7 @@ def _scan_state_discharge_rule(in_avals, out_avals, *args, jaxpr, num_consts, if discharged_consts: raise 
NotImplementedError("Discharged jaxpr has consts. If you see this, " "please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") def wrapped(*wrapped_args): val_consts, ref_consts_in, carry_in, val_xs, ref_xs_in = split_list_checked(wrapped_args, [n_val_consts, n_ref_consts, n_carry, n_val_xs, n_ref_xs]) @@ -1505,7 +1518,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, out_carry, out_carry_dot = split_list(out, [num_carry]) out_tangents_iter = iter(out_carry_dot) - out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_carry, nonzeros_out)] return out_carry, out_tangents @@ -1825,7 +1838,7 @@ def _while_discharge_rule(in_avals, out_avals, *args, cond_jaxpr, body_jaxpr, if body_jaxpr_consts: raise NotImplementedError("Body jaxpr has consts. If you see this error, " "please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") # body_jaxpr has the signature (*body_consts, *carry) -> carry. # Some of these body_consts are actually `Ref`s so when we discharge # them, they also turn into outputs, effectively turning those consts into @@ -2135,12 +2148,12 @@ def map(f, xs): """ if batch_size is not None: scan_xs, remainder_xs = _batch_and_remainder(xs, batch_size) - g = lambda _, x: ((), jax.vmap(f)(x)) + g = lambda _, x: ((), api.vmap(f)(x)) _, scan_ys = scan(g, (), scan_xs) - remainder_ys = jax.vmap(f)(remainder_xs) + remainder_ys = api.vmap(f)(remainder_xs) flatten = lambda x: x.reshape(-1, *x.shape[2:]) ys = tree_map( - lambda x, y: jax.numpy.concatenate([flatten(x), y], axis=0), scan_ys, remainder_ys, + lambda x, y: lax.concatenate([flatten(x), y], dimension=0), scan_ys, remainder_ys, ) else: g = lambda _, x: ((), f(x)) @@ -2158,7 +2171,7 @@ def _rng_bit_generator_batching_rule(batched_args, batch_dims, *, shape, dtype, key = keys[0] new_key, bits = lax.rng_bit_generator_p.bind(key, shape=(batch_size, *shape), dtype=dtype, algorithm=algorithm) - new_keys = jax.lax.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) + new_keys = slicing.dynamic_update_index_in_dim(keys, new_key, 0, axis=0) return (new_keys, bits), (0, 0) batching.primitive_batchers[lax.rng_bit_generator_p] = _rng_bit_generator_batching_rule diff --git a/jax/_src/lax/control_flow/solves.py b/jax/_src/lax/control_flow/solves.py index 4d55907f6b37..4e0f5086b121 100644 --- a/jax/_src/lax/control_flow/solves.py +++ b/jax/_src/lax/control_flow/solves.py @@ -16,11 +16,12 @@ from functools import partial import operator -import jax from jax.tree_util import (tree_flatten, treedef_children, tree_leaves, tree_unflatten, treedef_tuple) from jax._src import ad_util +from jax._src import api from jax._src import core +from jax._src import custom_derivatives from jax._src import linear_util as lu from jax._src.core import raise_to_shaped from jax._src.interpreters import ad @@ -99,7 +100,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False): _check_tree("solve", "initial_guess", solution_tree, in_tree, has_aux) def linearize_and_solve(x, b): - unchecked_zeros, f_jvp = jax.linearize(f, x) + unchecked_zeros, f_jvp = api.linearize(f, x) return tangent_solve(f_jvp, b) l_and_s_jaxpr, l_and_s_consts, out_tree = _initial_style_jaxpr( @@ -115,7 +116,7 @@ def linearize_and_solve(x, b): return tree_unflatten(solution_tree, solution_flat) -@partial(jax.custom_jvp, 
nondiff_argnums=(0, 1)) +@partial(custom_derivatives.custom_jvp, nondiff_argnums=(0, 1)) def _custom_root(const_lengths, jaxprs, *args): params, initial_guess = _split_root_args(args, const_lengths) solution = core.jaxpr_as_fun(jaxprs.solve)(*(params.solve + initial_guess)) @@ -169,7 +170,7 @@ def _split_linear_solve_args(args, const_lengths): def _transpose_one_output(linear_fun, primals): - transpose_fun = jax.linear_transpose(linear_fun, primals) + transpose_fun = api.linear_transpose(linear_fun, primals) def transposed_fun(x): (y,) = transpose_fun(x) return y @@ -315,7 +316,7 @@ def _tangent_linear_map(func, params, params_dot, *x): this function computes ``∂A @ x``. """ assert any(type(p) is not ad_util.Zero for p in params_dot) - zeros = _map(ad_util.Zero.from_value, x) + zeros = _map(ad_util.Zero.from_primal_value, x) _, out_tangent = ad.jvp(lu.wrap_init(func)).call_wrapped( params + list(x), params_dot + zeros) return out_tangent @@ -351,7 +352,7 @@ def _custom_linear_solve_jvp(primals, tangents, const_lengths, jaxprs): # split into x tangents and aux tangents (these become zero) dx_leaves, daux_leaves = split_list(x_dot, [num_x_leaves]) - daux_leaves = _map(ad_util.Zero.from_value, daux_leaves) + daux_leaves = _map(ad_util.Zero.from_primal_value, daux_leaves) x_dot = dx_leaves + daux_leaves diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index 2b2ad5bbb515..0e41fe5bb18f 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -114,7 +114,7 @@ def conv_general_dilated( - the input and output feature dimensions in rhs with the characters 'I' and 'O' respectively, and - spatial dimension correspondences between lhs, rhs, and the output using - any distinct characters. + any distinct characters. The examples below use 'W' and 'H'. For example, to indicate dimension numbers consistent with the ``conv`` function with two spatial dimensions, one could use ``('NCHW', 'OIHW', diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index a1cce3500df1..36553e512cd7 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -23,6 +23,7 @@ from jax import lax from jax._src import dispatch +from jax._src import dtypes from jax._src.api import jit, linear_transpose, ShapeDtypeStruct from jax._src.core import Primitive, is_constant_shape from jax._src.interpreters import ad @@ -30,7 +31,6 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client -from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact __all__ = [ "fft", @@ -61,9 +61,9 @@ def fft(x, fft_type: xla_client.FftType | str, fft_lengths: Sequence[int]): if typ == xla_client.FftType.RFFT: if np.iscomplexobj(x): raise ValueError("only real valued inputs supported for rfft") - x, = promote_dtypes_inexact(x) + x = lax.convert_element_type(x, dtypes.to_inexact_dtype(dtypes.dtype(x))) else: - x, = promote_dtypes_complex(x) + x = lax.convert_element_type(x, dtypes.to_complex_dtype(dtypes.dtype(x))) if len(fft_lengths) == 0: # XLA FFT doesn't support 0-rank. 
return x @@ -157,7 +157,7 @@ def _irfft_transpose(t, fft_lengths): out = scale * lax.expand_dims(mask, range(x.ndim - 1)) * x assert out.dtype == _complex_dtype(t.dtype), (out.dtype, t.dtype) # Use JAX's convention for complex gradients - # https://github.com/google/jax/issues/6223#issuecomment-807740707 + # https://github.com/jax-ml/jax/issues/6223#issuecomment-807740707 return lax.conj(out) def _fft_transpose_rule(t, operand, fft_type, fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index c0ed971770b4..f51f0436b7a9 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -22,12 +22,11 @@ import itertools import math import operator -from typing import Any, ClassVar, TypeVar, Union, cast as type_cast, overload, TYPE_CHECKING +from typing import Any, NamedTuple, TypeVar, Union, cast as type_cast, overload import warnings import numpy as np -import jax from jax import tree_util from jax.sharding import Sharding from jax.tree_util import tree_map @@ -42,6 +41,7 @@ from jax._src import dtypes from jax._src import effects from jax._src import linear_util as lu +from jax._src import pjit from jax._src import pretty_printer as pp from jax._src import source_info_util from jax._src import state @@ -62,10 +62,11 @@ standard_multi_result_abstract_eval, standard_primitive) from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.sharding_impls import PmapSharding +from jax._src.sharding_impls import PmapSharding, NamedSharding, PartitionSpec from jax._src.typing import Array, ArrayLike, DimSize, DuckTypedArray, DTypeLike, Shape from jax._src.util import (cache, safe_zip, safe_map, canonicalize_axis, split_list, NumpyComplexWarning) @@ -84,6 +85,10 @@ map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip +def _matrix_transpose(x: Array) -> Array: + assert x.ndim >= 2 + return transpose(x, [*range(x.ndim - 2), x.ndim - 1, x.ndim - 2]) + def _clip_int_to_valid_range(val: DimSize, dtype, where: str) -> int: info = np.iinfo(dtype) val = core.concrete_dim_or_error(val, where) @@ -132,7 +137,7 @@ def asarray(x: ArrayLike) -> Array: if isinstance(x, Array): return x if isinstance(x, (np.ndarray, np.generic, bool, int, float, builtins.complex)): - return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) + return _convert_element_type(x, weak_type=dtypes.is_weakly_typed(x)) # type: ignore[unused-ignore,bad-return-type] else: raise TypeError(f"asarray: expected ArrayLike, got {x} of type {type(x)}.") @@ -515,7 +520,7 @@ def convert_element_type(operand: ArrayLike, Returns: An array with the same shape as `operand`, cast elementwise to `new_dtype`. 
""" - return _convert_element_type(operand, new_dtype, weak_type=False) + return _convert_element_type(operand, new_dtype, weak_type=False) # type: ignore[unused-ignore,bad-return-type] def _convert_element_type( operand: ArrayLike, @@ -525,17 +530,30 @@ def _convert_element_type( if hasattr(operand, '__jax_array__'): operand = operand.__jax_array__() - if (dtypes.issubdtype(new_dtype, dtypes.extended) or - dtypes.issubdtype(getattr(operand, 'dtype', None), dtypes.extended)): - return convert_element_type_p.bind( - operand, new_dtype=new_dtype, weak_type=bool(weak_type), - sharding=sharding) - - new_dtype = type_cast(DTypeLike | None, new_dtype) - # Don't canonicalize old_dtype because x64 context might cause # un-canonicalized operands to be passed in. old_dtype = dtypes.dtype(operand, canonicalize=False) + + if (isinstance(new_dtype, dtypes.ExtendedDType) or + isinstance(old_dtype, dtypes.ExtendedDType)): + if sharding is not None or weak_type: raise NotImplementedError + if new_dtype == old_dtype: return operand + if (isinstance(new_dtype, dtypes.ExtendedDType) and + isinstance(old_dtype, dtypes.ExtendedDType)): + old_rep_dtype = core.physical_element_aval(old_dtype).dtype + new_rep_dtype = core.physical_element_aval(new_dtype).dtype + raise ValueError( + "cannot directly convert between extended dtypes: from " + f"{dtype_to_string(old_dtype)} to {dtype_to_string(new_dtype)}. " + "Instead, convert to and from their representation dtypes, e.g.:\n" + f"{dtype_to_string(old_dtype)} -> {dtype_to_string(old_rep_dtype)} " + f"-> {dtype_to_string(new_rep_dtype)} -> {dtype_to_string(new_dtype)}") + if isinstance(new_dtype, dtypes.ExtendedDType): + return to_edtype_p.bind(operand, edtype=new_dtype) + return from_edtype_p.bind(operand, dtype=np.dtype(new_dtype)) + + new_dtype = type_cast(DTypeLike | None, new_dtype) + old_weak_type = dtypes.is_weakly_typed(operand) if new_dtype is None: new_dtype = old_dtype @@ -634,64 +652,42 @@ def concatenate(operands: Array | Sequence[ArrayLike], dimension: int) -> Array: _precision_strings: dict[Any, Precision] = {} -# TODO(b/333851820): pytype does not properly handle _missing_ in enums. -# We work around that by defining `Precision` as a normal class. -if TYPE_CHECKING: - - class Precision: - DEFAULT: ClassVar[Precision] - HIGH: ClassVar[Precision] - HIGHEST: ClassVar[Precision] - - def __new__(cls, value: Precision | int | str | None) -> Precision: - raise NotImplementedError - - @property - def name(self) -> str: - raise NotImplementedError - - @property - def value(self) -> int: - raise NotImplementedError - -else: - - class Precision(enum.Enum): - """Precision enum for lax matrix multiply related functions. - - The device-dependent `precision` argument to JAX functions generally - controls the tradeoff between speed and accuracy for array computations on - accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends. - This only has an effect on float32 computations, and does not affect the - input/output datatypes. Members are: - - DEFAULT: - Fastest mode, but least accurate. On TPU: performs float32 computations in - bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 - GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: - ``'default'``, ``'fastest'``. - HIGH: - Slower but more accurate. On TPU: performs float32 computations in 3 - bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise - float32. Aliases: ``'high'``.. - HIGHEST: - Slowest but most accurate. 
On TPU: performs float32 computations in 6 - bfloat16. Aliases: ``'highest'``. On GPU: uses float32. - """ +class Precision(enum.Enum): + """Precision enum for lax matrix multiply related functions. + + The device-dependent `precision` argument to JAX functions generally + controls the tradeoff between speed and accuracy for array computations on + accelerator backends, (i.e. TPU and GPU). Has no impact on CPU backends. + This only has an effect on float32 computations, and does not affect the + input/output datatypes. Members are: + + DEFAULT: + Fastest mode, but least accurate. On TPU: performs float32 computations in + bfloat16. On GPU: uses tensorfloat32 if available (e.g. on A100 and H100 + GPUs), otherwise standard float32 (e.g. on V100 GPUs). Aliases: + ``'default'``, ``'fastest'``. + HIGH: + Slower but more accurate. On TPU: performs float32 computations in 3 + bfloat16 passes. On GPU: uses tensorfloat32 where available, otherwise + float32. Aliases: ``'high'``.. + HIGHEST: + Slowest but most accurate. On TPU: performs float32 computations in 6 + bfloat16. Aliases: ``'highest'``. On GPU: uses float32. + """ - DEFAULT = 0 - HIGH = 1 - HIGHEST = 2 + DEFAULT = 0 + HIGH = 1 + HIGHEST = 2 - @classmethod - def _missing_(cls, value: object) -> Precision | None: - return _precision_strings.get(value) + @classmethod + def _missing_(cls, value: object) -> Precision | None: + return _precision_strings.get(value) - def __repr__(self) -> str: - return f'{self.__class__.__name__}.{self.name}' + def __repr__(self) -> str: + return f'{self.__class__.__name__}.{self.name}' - def __str__(self) -> str: - return self.name + def __str__(self) -> str: + return self.name _precision_strings['highest'] = Precision.HIGHEST @@ -713,15 +709,204 @@ def __str__(self) -> str: None, ] + +class DotAlgorithm(NamedTuple): + """Specify the algorithm used for computing dot products. + + When used as input to :func:`~jax.lax.dot_general`, this data structure is + used for controlling the properties of the algorithm used for computing the + dot product. This API controls the precision used for the computation, and + allows users to access hardware-specific accelerations. + + Support for these algorithms is platform dependent, and using an unsupported + algorithm will raise a Python exception when the computation is compiled. The + algorithms that are known to be supported on at least some platforms are + listed in the :class:`~jax.lax.DotAlgorithm.Preset` enum, and these are a + good starting point for experimenting with this API. + + A "dot algorithm" is specified by the following parameters: + + * ``lhs_precision_type`` and ``rhs_precision_type``, the data types that the + LHS and RHS of the operation are rounded to. + * ``accumulation_type`` the data type used for accumulation. + * ``lhs_component_count``, ``rhs_component_count``, and + ``num_primitive_operations`` apply to algorithms that decompose the LHS + and/or RHS into multiple components and execute multiple operations on + those values, usually to emulate a higher precision. For algorithms with no + decomposition, these values should be set to ``1``. + * ``allow_imprecise_accumulation`` to specify if accumulation in lower + precision is permitted for some steps (e.g. + ``CUBLASLT_MATMUL_DESC_FAST_ACCUM``). + + The `StableHLO spec `_ for + the dot operation doesn't require that the precision types be the same as the + storage types for the inputs or outputs, but some plaforms may require that + these types match. 
Furthermore, the return type of + :func:`~jax.lax.dot_general` is always defined by the ``accumulation_type`` + parameter of the input algorithm, if specified. + + Examples: + + Accumulate two 16-bit floats using a 32-bit float accumulator: + + >>> algorithm = DotAlgorithm( + ... lhs_precision_type=np.float16, + ... rhs_precision_type=np.float16, + ... accumulation_type=np.float32, + ... ) + >>> lhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + >>> rhs = jnp.array([1.0, 2.0, 3.0, 4.0], dtype=np.float16) + >>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float32) + + Or, equivalently, using a preset: + + >>> algorithm = DotAlgorithm.Preset.F16_F16_F32 + >>> dot(lhs, rhs, algorithm=algorithm) # doctest: +SKIP + array([ 1., 4., 9., 16.], dtype=float32) + """ + + lhs_precision_type: DTypeLike + rhs_precision_type: DTypeLike + accumulation_type: DTypeLike + lhs_component_count: int = 1 + rhs_component_count: int = 1 + num_primitive_operations: int = 1 + allow_imprecise_accumulation: bool = False + + def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm: + del lhs_dtype, rhs_dtype # unused + return hlo.DotAlgorithm.get( + mlir.dtype_to_ir_type(dtypes.dtype(self.lhs_precision_type)), + mlir.dtype_to_ir_type(dtypes.dtype(self.rhs_precision_type)), + mlir.dtype_to_ir_type(dtypes.dtype(self.accumulation_type)), + self.lhs_component_count, + self.rhs_component_count, + self.num_primitive_operations, + self.allow_imprecise_accumulation, + ) + + # mypy doesn't currently support nested classes in a NamedTuple definition. + class Preset(enum.Enum): # type: ignore[misc] + DEFAULT = 0 + ANY_F8_ANY_F8_F32 = 1 + ANY_F8_ANY_F8_F32_FAST_ACCUM = 2 + F16_F16_F16 = 3 + F16_F16_F32 = 4 + BF16_BF16_BF16 = 5 + BF16_BF16_F32 = 6 + BF16_BF16_F32_X3 = 7 + BF16_BF16_F32_X6 = 8 + TF32_TF32_F32 = 9 + TF32_TF32_F32_X3 = 10 + F32_F32_F32 = 11 + F64_F64_F64 = 12 + + def __repr__(self) -> str: + return f'{self.__class__.__name__}.{self.name}' + + def __str__(self) -> str: + return self.name + + @property + def accumulation_type(self) -> DTypeLike: + match self: + case DotAlgorithm.Preset.DEFAULT: + raise TypeError( + "The default dot algorithm does not have an accumulation type.") + case DotAlgorithm.Preset.F16_F16_F16: + return np.float16 + case DotAlgorithm.Preset.BF16_BF16_BF16: + return dtypes.bfloat16 + case DotAlgorithm.Preset.F64_F64_F64: + return np.float64 + case _: + return np.float32 + + def _convert_to_hlo_attr(self, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: + if self == DotAlgorithm.Preset.DEFAULT: + return None + + if self in (DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, + DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM): + fp8_dtypes = (np.dtype(dtypes.float8_e4m3b11fnuz), + np.dtype(dtypes.float8_e4m3fn), + np.dtype(dtypes.float8_e4m3fnuz), + np.dtype(dtypes.float8_e5m2), + np.dtype(dtypes.float8_e5m2fnuz)) + if lhs_dtype not in fp8_dtypes or rhs_dtype not in fp8_dtypes: + raise ValueError( + f"The dot algorithm '{self}' requires both inputs to have float8 " + f"dtypes. 
Got {lhs_dtype} and {rhs_dtype} instead.") + lhs = mlir.dtype_to_ir_type(dtypes.dtype(lhs_dtype)) + rhs = mlir.dtype_to_ir_type(dtypes.dtype(rhs_dtype)) + acc = ir.F32Type.get() + return hlo.DotAlgorithm.get( + lhs, rhs, acc, 1, 1, 1, + self == DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM) + + else: + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + f64 = ir.F64Type.get() + bf16 = ir.BF16Type.get() + tf32 = ir.FloatTF32Type.get() + match self: + case DotAlgorithm.Preset.F16_F16_F16: + return hlo.DotAlgorithm.get(f16, f16, f16, 1, 1, 1, False) + case DotAlgorithm.Preset.F16_F16_F32: + return hlo.DotAlgorithm.get(f16, f16, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.BF16_BF16_BF16: + return hlo.DotAlgorithm.get(bf16, bf16, bf16, 1, 1, 1, False) + case DotAlgorithm.Preset.BF16_BF16_F32: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.BF16_BF16_F32_X3: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 3, False) + case DotAlgorithm.Preset.BF16_BF16_F32_X6: + return hlo.DotAlgorithm.get(bf16, bf16, f32, 1, 1, 6, False) + case DotAlgorithm.Preset.TF32_TF32_F32: + return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.TF32_TF32_F32_X3: + return hlo.DotAlgorithm.get(tf32, tf32, f32, 1, 1, 3, False) + case DotAlgorithm.Preset.F32_F32_F32: + return hlo.DotAlgorithm.get(f32, f32, f32, 1, 1, 1, False) + case DotAlgorithm.Preset.F64_F64_F64: + return hlo.DotAlgorithm.get(f64, f64, f64, 1, 1, 1, False) + case _: + raise NotImplementedError("unreachable") + + +DotAlgorithmLike = Union[ + DotAlgorithm, + DotAlgorithm.Preset, + str, + None, +] +_DotAlgorithmLike = Union[ + DotAlgorithm, + DotAlgorithm.Preset, + None, +] +DotTransposeAlgorithmLike = Union[ + DotAlgorithmLike, + tuple[DotAlgorithmLike, DotAlgorithmLike], +] +DotTransposeAlgorithm = tuple[_DotAlgorithmLike, _DotAlgorithmLike] + + def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + algorithm: DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array: """Vector/vector, matrix/vector, and matrix/matrix multiplication. Wraps XLA's `Dot `_ operator. - For more general contraction, see the `dot_general` operator. + For more general contraction, see the :func:`jax.lax.dot_general` operator. Args: lhs: an array of dimension 1 or 2. @@ -733,6 +918,17 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, preferred_element_type: Optional. Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. + algorithm: Optional. Specify the algorithm used for accumulating the dot + product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument + cannot be used with ``precision`` or ``preferred_element_type``. + transpose_algorithm: Optional. This allows specifying the algorithm used when + this operation is transposed, typically as part of reverse-mode automatic + differentiation. This argument can either be a single + :class:`~jax.lax.DotAlgorithm` or a tuple of two + :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the + algorithm for transposing the LHS and RHS, respectively. + ``transpose_algorithm`` must be explicitly specified when transposing a + dot product where a specific ``algorithm`` was used on the forward pass. 
Returns: An array containing the product. @@ -740,7 +936,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, if 1 <= lhs.ndim <= 2 and 1 <= rhs.ndim <= 2 and core.definitely_equal(lhs.shape[-1], rhs.shape[0]): return dot_general(lhs, rhs, (((lhs.ndim - 1,), (0,)), ((), ())), precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + algorithm=algorithm, + transpose_algorithm=transpose_algorithm) else: raise TypeError("Incompatible shapes for dot: got {} and {}.".format( lhs.shape, rhs.shape)) @@ -751,7 +949,9 @@ def dot(lhs: Array, rhs: Array, precision: PrecisionLike = None, def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, - preferred_element_type: DTypeLike | None = None) -> Array: + preferred_element_type: DTypeLike | None = None, + algorithm: DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithmLike = None) -> Array: """General dot product/contraction operator. Wraps XLA's `DotGeneral @@ -778,6 +978,17 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN preferred_element_type: Optional. Either ``None``, which means the default accumulation type for the input types, or a datatype, indicating to accumulate results to and return a result with that datatype. + algorithm: Optional. Specify the algorithm used for accumulating the dot + product. See :class:`~jax.lax.DotAlgorithm` for more details. This argument + cannot be used with ``precision`` or ``preferred_element_type``. + transpose_algorithm: Optional. This allows specifying the algorithm used when + this operation is transposed, typically as part of reverse-mode automatic + differentiation. This argument can either be a single + :class:`~jax.lax.DotAlgorithm` or a tuple of two + :class:`~jax.lax.DotAlgorithm`s, in which case the two elements define the + algorithm for transposing the LHS and RHS, respectively. + ``transpose_algorithm`` must be explicitly specified when transposing a + dot product where a specific ``algorithm`` was used on the forward pass. Returns: An array whose first dimensions are the (shared) batch dimensions, followed by @@ -795,7 +1006,9 @@ def dot_general(lhs: ArrayLike, rhs: ArrayLike, dimension_numbers: DotDimensionN return dot_general_p.bind(lhs, rhs, dimension_numbers=(cdims, bdims), precision=canonicalize_precision(precision), - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + algorithm=canonicalize_dot_algorithm(algorithm), + transpose_algorithm=canonicalize_dot_transpose_algorithm(transpose_algorithm)) def ragged_dot( @@ -1094,7 +1307,7 @@ def comp(x, y): if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. Please open an issue " - "at https://github.com/google/jax.") + "at https://github.com/jax-ml/jax.") return jaxpr, tuple(consts) @cache() @@ -1107,7 +1320,7 @@ def _variadic_reduction_jaxpr(computation, flat_avals, aval_tree): if any(isinstance(c, core.Tracer) for c in consts): raise NotImplementedError( "Reduction computations can't close over Tracers. 
Please open an issue " - "at https://github.com/google/jax.") + "at https://github.com/jax-ml/jax.") return core.ClosedJaxpr(jaxpr, consts), out_tree() def _get_monoid_reducer(monoid_op: Callable, @@ -1349,7 +1562,7 @@ def broadcasted_iota(dtype: DTypeLike, shape: Shape, dimension: int) -> Array: return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape), dimension=dimension) -def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize) -> Array: +def _eye(dtype: DTypeLike, shape: Shape, offset: DimSize = 0) -> Array: """Like numpy.eye, create a 2D array with ones on a diagonal.""" offset = _clip_int_to_valid_range(offset, np.int32, "argument `offset` of jax.numpy.eye") @@ -1390,18 +1603,43 @@ def stop_gradient(x: T) -> T: argument `x` unchanged. However, ``stop_gradient`` prevents the flow of gradients during forward or reverse-mode automatic differentiation. If there are multiple nested gradient computations, ``stop_gradient`` stops gradients - for all of them. - - For example: - - >>> jax.grad(lambda x: x**2)(3.) - Array(6., dtype=float32, weak_type=True) - >>> jax.grad(lambda x: jax.lax.stop_gradient(x)**2)(3.) - Array(0., dtype=float32, weak_type=True) - >>> jax.grad(jax.grad(lambda x: x**2))(3.) - Array(2., dtype=float32, weak_type=True) - >>> jax.grad(jax.grad(lambda x: jax.lax.stop_gradient(x)**2))(3.) - Array(0., dtype=float32, weak_type=True) + for all of them. For some discussion of where this is useful, refer to + :ref:`stopping-gradients`. + + Args: + x: array or pytree of arrays + + Returns: + input value is returned unchanged, but within autodiff will be treated as + a constant. + + Examples: + Consider a simple function that returns the square of the input value: + + >>> def f1(x): + ... return x ** 2 + >>> x = jnp.float32(3.0) + >>> f1(x) + Array(9.0, dtype=float32) + >>> jax.grad(f1)(x) + Array(6.0, dtype=float32) + + The same function with ``stop_gradient`` around ``x`` will be equivalent + under normal evaluation, but return a zero gradient because ``x`` is + effectively treated as a constant: + + >>> def f2(x): + ... return jax.lax.stop_gradient(x) ** 2 + >>> f2(x) + Array(9.0, dtype=float32) + >>> jax.grad(f2)(x) + Array(0.0, dtype=float32) + + This is used in a number of places within the JAX codebase; for example + :func:`jax.nn.softmax` internally normalizes the input by its maximum + value, and this maximum value is wrapped in ``stop_gradient`` for + efficiency. Refer to :ref:`stopping-gradients` for more discussion of + the applicability of ``stop_gradient``. """ def stop(x): # only bind primitive on inexact dtypes, to avoid some staging @@ -1731,13 +1969,54 @@ def broadcasting_shape_rule(name, *avals): return tuple(result_shape) +def broadcasting_sharding_rule(name, *avals): + shapes = [aval.shape for aval in avals if aval.shape] + if not shapes: + return () + if len({len(shape) for shape in shapes}) != 1: + msg = '{}: arrays must have same number of dimensions, got {}.' + raise TypeError(msg.format(name, ', '.join(map(str, map(tuple, shapes))))) + + specs = [a.sharding.spec for a in avals if a.shape] + + mesh = None + for a in avals: + if a.shape: + mesh = a.sharding.mesh + if mesh is not None and mesh != a.sharding.mesh: + raise ValueError( + f'Mesh for all inputs should be equal. 
Got one mesh: {mesh} and' + f' another mesh: {a.sharding.mesh}') + assert mesh is not None + + result_specs = [] + for ss, ds in zip(zip(*specs), zip(*shapes)): + if all(s == ss[0] for s in ss[1:]): + # if all dimension shardings are same, the resulting dimension sharding is + # the same. + result_specs.append(ss[0]) + else: + non_trivial_s = [s for s, d in zip(ss, ds) + if not (core.definitely_equal(d, 1) and s is None)] + if not non_trivial_s: + result_specs.append(None) + elif all(non_trivial_s[0] == s for s in non_trivial_s[1:]): + result_specs.append(non_trivial_s[0]) + else: + raise TypeError(f'{name} got incompatible shardings for broadcasting: ' + f'{", ".join(map(str, map(tuple, specs)))}.') + return NamedSharding(mesh, PartitionSpec(*result_specs)) + + def naryop(result_dtype, accepted_dtypes, name, allow_extended_dtype=False, require_same_dtypes=False): dtype_rule = partial(naryop_dtype_rule, result_dtype, accepted_dtypes, name, allow_extended_dtype=allow_extended_dtype, require_same=require_same_dtypes) shape_rule = partial(broadcasting_shape_rule, name) - prim = standard_primitive(shape_rule, dtype_rule, name) + sharding_rule = partial(broadcasting_sharding_rule, name) + prim = standard_primitive(shape_rule, dtype_rule, name, + sharding_rule=sharding_rule) batching.defbroadcasting(prim) pe.def_trivial_padding(prim) return prim @@ -1794,6 +2073,23 @@ def broadcast_hlo( out.append(arg) return out +def multi_sharding_in_dim(ctx, ops, in_avals, out_aval): + out = [] + for op, in_aval in zip(ops, in_avals): + if in_aval.sharding == out_aval.sharding or in_aval.sharding is None: + out.append(op) + else: + # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains + # CompilerShardingAxis, then specify `unspecified_dims` via + # `wrap_with_sharding_op`. + if config.use_shardy_partitioner.value: + sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim) + else: + sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto() + out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp)) + return out + + def _nary_lower_hlo(op: Callable, ctx, *args: ir.Value, explicit_type=False, **params) -> Sequence[ir.Value]: @@ -1804,13 +2100,22 @@ def _nary_lower_hlo(op: Callable, ctx, """ del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out - broadcasted_args = mlir.multi_broadcast_in_dim( - ctx, args, avals_in, aval_out.shape) + args = mlir.multi_broadcast_in_dim(ctx, args, avals_in, aval_out.shape) # type: ignore + if config.sharding_in_types.value: + args = multi_sharding_in_dim(ctx, args, avals_in, aval_out) if explicit_type: - return [op(mlir.aval_to_ir_type(aval_out), *broadcasted_args)] + out = op(mlir.aval_to_ir_type(aval_out), *args) else: - return [op(*broadcasted_args)] + out = op(*args) + if config.sharding_in_types.value: + if config.use_shardy_partitioner.value: + out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim) + else: + out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto() + return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)] + else: + return [out] _float = {np.floating} @@ -1965,7 +2270,15 @@ def _tan_impl(x): tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) +# TODO(b/368011034): Remove after jaxlib 0.4.34 release. 
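The per-dimension merge performed by ``broadcasting_sharding_rule`` above can be restated as a small standalone sketch; ``merge_dim_specs`` and the example specs below are hypothetical names used purely for illustration and are not part of the JAX API:

    # For each output dimension the operand dimension shardings must agree,
    # except that size-1 dimensions sharded as None are ignored (they are the
    # dimensions being broadcast).
    def merge_dim_specs(dim_specs, dim_sizes):
      if all(s == dim_specs[0] for s in dim_specs[1:]):
        return dim_specs[0]
      nontrivial = [s for s, d in zip(dim_specs, dim_sizes)
                    if not (d == 1 and s is None)]
      if not nontrivial:
        return None
      if all(s == nontrivial[0] for s in nontrivial[1:]):
        return nontrivial[0]
      raise TypeError(f"incompatible shardings for broadcasting: {dim_specs}")

    # Specs ('x', None) on shape (4, 1) and ('x', 'y') on shape (4, 8)
    # merge to ('x', 'y'):
    specs, shapes = [("x", None), ("x", "y")], [(4, 1), (4, 8)]
    merged = tuple(merge_dim_specs(ss, ds)
                   for ss, ds in zip(zip(*specs), zip(*shapes)))
    assert merged == ("x", "y")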
In 0.4.33, this +# lowering is mostly supported, but it fails on export or with the PJRT plugin +# because those modes target an older StableHLO version, and the +# compatibility updates from https://github.com/openxla/xla/pull/16649 aren't +# included in the 0.4.33 release. +if jaxlib_version <= (0, 4, 33): + mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.tan)) +else: + mlir.register_lowering(tan_p, partial(_nary_lower_hlo, hlo.tan)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): @@ -2135,7 +2448,11 @@ def _pow_dtype_rule(x, y): def _pow_jvp_lhs(g, ans, x, y): y_dtype = dtypes.dtype(y) - x, y = jax._src.numpy.util.promote_dtypes_numeric(x, y) # TODO replace this + result_dtype = dtypes.result_type(x, y) + if result_dtype == bool: + result_dtype = 'int32' + x = convert_element_type(x, result_dtype) + y = convert_element_type(y, result_dtype) if dtypes.issubdtype(y_dtype, np.integer): if x.shape != y.shape: shape = broadcast_shapes(x.shape, y.shape) @@ -2247,7 +2564,7 @@ def _add_jvp(primals, tangents): xdot, ydot = tangents primal_out = add(x, y) if type(xdot) is type(ydot) is ad_util.Zero: - return primal_out, ad_util.Zero.from_value(primal_out) + return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: return primal_out, _maybe_broadcast(primal_out.shape, ydot) elif type(ydot) is ad_util.Zero: @@ -2278,7 +2595,7 @@ def _sub_jvp(primals, tangents): xdot, ydot = tangents primal_out = sub(x, y) if type(xdot) is type(ydot) is ad_util.Zero: - return primal_out, ad_util.Zero.from_value(primal_out) + return primal_out, ad_util.Zero.from_primal_value(primal_out) if type(xdot) is ad_util.Zero: return primal_out, _maybe_broadcast(primal_out.shape, neg(ydot)) elif type(ydot) is ad_util.Zero: @@ -2467,16 +2784,12 @@ def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type, sharding): return operand.shape +def _convert_element_type_sharding_rule(operand, *, new_dtype, weak_type, + sharding): + return sharding + def _convert_element_type_dtype_rule(operand, *, new_dtype, weak_type, sharding): - if (operand.dtype != new_dtype and - ((dtypes.issubdtype(operand.dtype, dtypes.extended) and - not operand.dtype._rules.convert_from(operand.dtype, new_dtype)) or - (dtypes.issubdtype(new_dtype, dtypes.extended) and - not new_dtype._rules.convert_to(operand.dtype, new_dtype)))): - raise ValueError( - f"Cannot convert_element_type from {dtype_to_string(operand.dtype)} " - f"to {dtype_to_string(new_dtype)}") return new_dtype def _convert_element_type_weak_type_rule(operand, *, new_dtype, weak_type, @@ -2496,13 +2809,13 @@ def _convert_element_type_transpose_rule(ct, operand, *, new_dtype, weak_type, return [convert_element_type_p.bind( ct, new_dtype=old_dtype, weak_type=old_weak_type, sharding=sharding)] -def _convert_element_type_jvp_rule(tangent, operand , *, new_dtype, weak_type, - sharding): - if core.primal_dtype_to_tangent_dtype(new_dtype) == dtypes.float0: - tangent_aval = core.raise_to_shaped(core.get_aval(tangent)) - return ad_util.Zero(tangent_aval.update(dtype=dtypes.float0, weak_type=False)) +def _convert_element_type_jvp_rule(tangent, primal_result, operand, *, + new_dtype, weak_type, sharding): + new_tangent_dtype = core.primal_dtype_to_tangent_dtype(new_dtype) + if new_tangent_dtype == dtypes.float0: + return ad_util.Zero.from_primal_value(primal_result) else: - return convert_element_type_p.bind(tangent, new_dtype=new_dtype, + return convert_element_type_p.bind(tangent, new_dtype=new_tangent_dtype, 
weak_type=weak_type, sharding=sharding) def _convert_elt_type_folding_rule(consts, eqn): @@ -2553,15 +2866,16 @@ def _convert_element_type_bind(operand, *, new_dtype, weak_type, sharding): new_dtype=new_dtype, weak_type=weak_type, sharding=sharding) if sharding is not None: - operand = jax.lax.with_sharding_constraint(operand, sharding) + operand = pjit.with_sharding_constraint(operand, sharding) return operand convert_element_type_p.def_custom_bind(_convert_element_type_bind) convert_element_type_p.def_impl(partial(dispatch.apply_primitive, convert_element_type_p)) convert_element_type_p.def_abstract_eval( partial(standard_abstract_eval, convert_element_type_p, _convert_element_type_shape_rule, _convert_element_type_dtype_rule, - _convert_element_type_weak_type_rule)) -ad.defjvp(convert_element_type_p, _convert_element_type_jvp_rule) + _convert_element_type_weak_type_rule, + _convert_element_type_sharding_rule)) +ad.defjvp2(convert_element_type_p, _convert_element_type_jvp_rule) ad.primitive_transposes[convert_element_type_p] = _convert_element_type_transpose_rule batching.defvectorized(convert_element_type_p) pe.const_fold_rules[convert_element_type_p] = _convert_elt_type_folding_rule @@ -2584,6 +2898,91 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type, mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) +def _to_edtype_abstract_eval(x, *, edtype): + assert (isinstance(edtype, dtypes.ExtendedDType) and + not isinstance(x.dtype, dtypes.ExtendedDType)) + # For backward compatibility, if the edtype rules have a `convert_to` method, + # use that rather than looking for an `allow_conversion: bool` attribute. + if convert_to := getattr(edtype._rules, 'convert_to', None): + allow_conversion = convert_to(x.dtype, edtype) + else: + allow_conversion = edtype._rules.allow_conversion + if not allow_conversion: + raise ValueError( + f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " + f"to {dtype_to_string(edtype)}") + rep_aval = core.physical_element_aval(edtype) + if x.dtype != rep_aval.dtype: + raise ValueError( + "can only convert to extended dtype from its representation dtype, " + f"but tried to convert from {dtype_to_string(x.dtype)} to " + f"{dtype_to_string(edtype)} which doesn't match the representation type " + f"{dtype_to_string(rep_aval.dtype)}.") + if x.ndim < rep_aval.ndim: + raise ValueError( + "can only convert to extended dtype from an array of its " + f"representation type, but the extended dtype {dtype_to_string(edtype)}" + f" has a representation shape {rep_aval.shape} (rank {rep_aval.ndim}) " + f"while the given representation array has shape {x.shape} (rank " + f"{x.ndim} < {rep_aval.ndim}).") + n = x.ndim - rep_aval.ndim + shape_prefix, shape_suffix = x.shape[:n], x.shape[n:] + if shape_suffix != rep_aval.shape: + raise ValueError( + "can only convert to extended dtype from an array of its " + f"representation type, but the extended dtype {dtype_to_string(edtype)}" + f" has a representation shape {rep_aval.shape} while the given " + f"representation array has shape {x.shape}, so the shape suffix " + f"does not match: given {shape_suffix} but required {rep_aval.shape}.") + return core.raise_to_shaped(x).update(shape=shape_prefix, dtype=edtype) + +to_edtype_p = Primitive('to_edtype') +to_edtype_p.def_impl(partial(dispatch.apply_primitive, to_edtype_p)) +to_edtype_p.def_abstract_eval(_to_edtype_abstract_eval) +ad.defjvp(to_edtype_p, + lambda t, x, edtype: + convert_element_type(t, 
core.primal_dtype_to_tangent_dtype(edtype))) +ad.primitive_transposes[to_edtype_p] = \ + lambda ct, x, edtype: [from_edtype_p.bind(ct, dtype=x.aval.dtype)] # type: ignore +batching.defvectorized(to_edtype_p) +mlir.register_lowering(to_edtype_p, lambda _, x, **__: [x]) + + +def _from_edtype_abstract_eval(x, *, dtype): + assert (isinstance(x.dtype, dtypes.ExtendedDType) and + not isinstance(dtype, dtypes.ExtendedDType)) + if convert_from := getattr(x.dtype._rules, 'convert_from', None): + allow_conversion = convert_from(x.dtype, dtype) + else: + allow_conversion = x.dtype._rules.allow_conversion + if not allow_conversion: + raise ValueError( + f"Cannot convert_element_type from {dtype_to_string(x.dtype)} " + f"to {dtype_to_string(dtype)}") + rep_aval = core.physical_element_aval(x.dtype) + if rep_aval.dtype != dtype: + raise ValueError( + "can only convert from extended dtype to its representation dtype, " + f"but tried to convert from {dtype_to_string(x.dtype)} to " + f"{dtype_to_string(dtype)} which doesn't match the representation type " + f"{dtype_to_string(rep_aval.dtype)}.") + if all(isinstance(d, int) for d in x.shape): + return core.ShapedArray(shape=(*x.shape, *rep_aval.shape), dtype=dtype) + else: + raise NotImplementedError + +from_edtype_p = Primitive('from_edtype') +from_edtype_p.def_impl(partial(dispatch.apply_primitive, from_edtype_p)) +from_edtype_p.def_abstract_eval(_from_edtype_abstract_eval) +ad.defjvp(from_edtype_p, + lambda t, x, dtype: + convert_element_type(t, core.primal_dtype_to_tangent_dtype(dtype))) +ad.primitive_transposes[from_edtype_p] = \ + lambda ct, x, dtype: [to_edtype_p.bind(ct, edtype=x.dtype)] +batching.defvectorized(from_edtype_p) +mlir.register_lowering(from_edtype_p, lambda _, x, **__: [x]) + + def _bitcast_convert_type_shape_rule(operand, *, new_dtype): old_dtype = dtypes.canonicalize_dtype(operand.dtype) new_dtype = dtypes.canonicalize_dtype(new_dtype) @@ -2656,7 +3055,9 @@ def _validate_preferred_element_type(input_dtype, preferred_element_type): def _dot_general_shape_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): (lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch) = dimension_numbers if not all(np.all(np.greater_equal(d, 0)) and np.all(np.less(d, lhs.ndim)) for d in (lhs_contracting, lhs_batch)): @@ -2732,7 +3133,10 @@ def tuple_delete(tup, idx): def _dot_general_dtype_rule(lhs, rhs, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): + del dimension_numbers, precision # unused # We're mostly matching XLA's logic here, namely in shape_inference.cc and # primitive_util.h's HigherPrecisionType, e.g. 
# https://github.com/openxla/xla/blob/ea3a841768d0dcf192e5820c9b25c34c73f2226a/xla/primitive_util.h#L329 @@ -2754,6 +3158,21 @@ def type_properties(dt): f"lax.dot_general argument type error: {lhs.dtype}, {rhs.dtype}") result_dtype = lhs.dtype + if transpose_algorithm is not None and algorithm is None: + raise ValueError( + "When the algorithm argument to dot_general is None, the " + "transpose_algorithm argument is unused and must also be None.") + + if algorithm is not None and algorithm != DotAlgorithm.Preset.DEFAULT: + if preferred_element_type is not None: + raise ValueError( + "The preferred_element_type and algorithm arguments to dot_general " + "cannot both be specified.") + + # This is used to ensure that the output type is equal to the accumulation + # type whenever an algorithm is specified. + preferred_element_type = algorithm.accumulation_type + return _maybe_upcast(result_dtype, preferred_element_type) def _bit_width(d): @@ -2777,6 +3196,8 @@ def _maybe_upcast(result_dtype, preferred_element_type): def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None, swap_ans=False): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers x_ndim = x.aval.ndim @@ -2789,20 +3210,35 @@ def _dot_general_transpose_lhs(g, x, y, *, dimension_numbers, precision, dims = ((ans_y, y_kept), (ans_batch, y_batch)) x_contract_sorted_by_y = list(np.take(x_contract, np.argsort(y_contract))) out_axes = np.argsort(list(x_batch) + x_kept + x_contract_sorted_by_y) + if algorithm is not None: + if transpose_algorithm is None or transpose_algorithm[0] is None: + raise ValueError( + "When a dot_general algorithm is specified on the forward pass, " + "transpose_algorithm must be specified for the backward pass.") + lhs_alg, rhs_alg = transpose_algorithm + transpose_algorithm = (algorithm, rhs_alg) + algorithm = lhs_alg x_bar = transpose(dot_general(g, y, dims, precision=precision, - preferred_element_type=preferred_element_type), + preferred_element_type=preferred_element_type, + algorithm=algorithm, + transpose_algorithm=transpose_algorithm), tuple(out_axes)) if x_bar.dtype != x.aval.dtype: x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type) return x_bar def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): (x_contract, y_contract), (x_batch, y_batch) = dimension_numbers swapped_dimension_numbers = ((y_contract, x_contract), (y_batch, x_batch)) + transpose_algorithm = None if transpose_algorithm is None else ( + transpose_algorithm[1], transpose_algorithm[0]) y_bar = _dot_general_transpose_lhs( g, y, x, dimension_numbers=swapped_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type, + preferred_element_type=preferred_element_type, algorithm=algorithm, + transpose_algorithm=transpose_algorithm, swap_ans=True) if y_bar.dtype != y.aval.dtype: y_bar = _convert_element_type(y_bar, y.aval.dtype, y.aval.weak_type) @@ -2810,7 +3246,9 @@ def _dot_general_transpose_rhs(g, x, y, *, dimension_numbers, precision, def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, precision, - preferred_element_type: DTypeLike | None): + preferred_element_type: DTypeLike | None, + algorithm: 
_DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None): lhs, rhs = batched_args lbd, rbd = batch_dims (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers @@ -2836,7 +3274,9 @@ def _dot_general_batch_rule(batched_args, batch_dims, *, dimension_numbers, rhs_shape = np.shape(rhs) batched_out = dot_general(lhs, rhs, new_dimension_numbers, precision=precision, - preferred_element_type=preferred_element_type) + preferred_element_type=preferred_element_type, + algorithm=algorithm, + transpose_algorithm=transpose_algorithm) result_batch_dim = batching.shape_as_bdim( result_stack_dim, _dot_general_shape_computation(lhs_shape, rhs_shape, new_dimension_numbers)) @@ -2933,9 +3373,19 @@ def precision_attr(precision: Precision) -> ir.ArrayAttr: [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) +def dot_algorithm_attr(algorithm: _DotAlgorithmLike, lhs_dtype: DTypeLike, + rhs_dtype: DTypeLike) -> hlo.DotAlgorithm | None: + if algorithm is None: + return None + return algorithm._convert_to_hlo_attr(lhs_dtype, rhs_dtype) + + def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, precision, preferred_element_type: np.dtype | None, + algorithm: _DotAlgorithmLike = None, + transpose_algorithm: DotTransposeAlgorithm | None = None, platform: str = "default"): + del transpose_algorithm # unused def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): fp8_dtypes = (dtypes.float8_e4m3fn, dtypes.float8_e5m2, dtypes.float8_e5m2fnuz, dtypes.float8_e4m3fnuz) @@ -2976,13 +3426,30 @@ def _is_fp8_mixed_precision_matmul(_lhs_dtypes, _rhs_dtypes): rhs_batching_dimensions=list(rhs_batch), lhs_contracting_dimensions=list(lhs_contracting), rhs_contracting_dimensions=list(rhs_contracting)) + + if algorithm is not None and precision not in { + None, Precision.DEFAULT, (Precision.DEFAULT, Precision.DEFAULT)}: + raise ValueError( + "The dot_general precision must be None or DEFAULT when an algorithm " + "is specified.") + if jaxlib_version <= (0, 4, 33): + if algorithm is not None: + raise ValueError( + "The dot_general algorithm parameter is only supported for jaxlib " + "versions larger than 0.4.33.") + algorithm_kwargs = {} + else: + algorithm_kwargs = {"algorithm": dot_algorithm_attr(algorithm, lhs_dtype, + rhs_dtype)} return [ hlo.dot_general( mlir.aval_to_ir_type(aval_out), lhs, rhs, dot_dnums, - precision_config=precision_attr(precision)) + precision_config=precision_attr(precision), + **algorithm_kwargs, + ) ] mlir.register_lowering(dot_general_p, _dot_general_lower) @@ -3007,11 +3474,13 @@ def _ragged_dot_shape_rule(lhs: Array, rhs: Array, group_sizes: Array, **_) -> S _RAGGED_DOT_DOT_DIMENSION_NUMBERS: DotDimensionNumbers = (([2, 0], [1, 0]), ([], [])) def _ragged_dot_dtype_rule(lhs: Array, rhs: Array, group_sizes: Array, - precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: + precision, preferred_element_type: DTypeLike | None, **_) -> np.dtype: if not dtypes.issubdtype(group_sizes.dtype, np.integer): raise TypeError("ragged_dot requires that group_sizes.dtype is subtype of np.integer.") # defer the output dtype to dot_general, which is part of the _ragged_dot_impl. 
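The dtype-rule checks above constrain how ``algorithm`` interacts with the pre-existing arguments; a hedged sketch of how those constraints surface to users (the error messages are paraphrased in the comments, and these checks fire at trace time independent of the jaxlib version gate in the lowering):

    import jax
    import jax.numpy as jnp

    x = jnp.ones((8, 8), dtype=jnp.float32)

    # transpose_algorithm is only meaningful when algorithm is also given.
    try:
      jax.lax.dot(x, x, transpose_algorithm=("F32_F32_F32", "F32_F32_F32"))
    except ValueError:
      pass  # "... transpose_algorithm ... is unused and must also be None."

    # An explicit algorithm fixes the accumulation type itself, so it cannot be
    # combined with preferred_element_type (nor with a non-default precision).
    try:
      jax.lax.dot(x, x, algorithm="F32_F32_F32",
                  preferred_element_type=jnp.float32)
    except ValueError:
      pass  # "... preferred_element_type and algorithm ... cannot both be specified."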
- return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, precision=precision, preferred_element_type=preferred_element_type) + return _dot_general_dtype_rule(lhs, rhs, dimension_numbers=_RAGGED_DOT_DOT_DIMENSION_NUMBERS, + precision=precision, preferred_element_type=preferred_element_type, + algorithm=None, transpose_algorithm=None) def _ragged_dot_jvp_rule( @@ -3043,7 +3512,7 @@ def _ragged_dot_jvp_rule( preferred_element_type=preferred_element_type, ) if type(dx) is not ad_util.Zero - else jax.numpy.zeros_like(primal_out) + else _zeros(primal_out) ) dy_out = ( ragged_dot( @@ -3054,13 +3523,30 @@ def _ragged_dot_jvp_rule( preferred_element_type=preferred_element_type, ) if type(dy) is not ad_util.Zero - else jax.numpy.zeros_like(primal_out) + else _zeros(primal_out) ) tangent_out = dx_out + dy_out return primal_out, tangent_out +def _ragged_to_dense(x, y, group_sizes): + from jax._src.lax import control_flow # avoid circular imports + shape = (y.shape[0], x.shape[0], x.shape[1]) + x = broadcast_in_dim(x, shape, [1, 2]) + iota = broadcasted_iota(group_sizes.dtype, shape, 1) + group_ends = control_flow.cumsum(group_sizes) + group_starts = concatenate( + [_zeros(group_sizes)[:1], group_ends[:-1]], + dimension=0, + ) + group_ends = broadcast_in_dim(group_ends, shape, (0,)) + group_starts = broadcast_in_dim(group_starts, shape, (0,)) + mask = bitwise_and(group_starts <= iota, iota < group_ends) + x = select(mask, x, _zeros(x)) + return x + + def _ragged_dot_transpose_rule( ct, *operands, precision, preferred_element_type, group_offset ): @@ -3068,28 +3554,10 @@ def _ragged_dot_transpose_rule( if group_offset is not None: raise NotImplementedError('Unimplemented group_offset support.') - def ragged_to_dense(x, group_sizes): - group_count = group_sizes.shape[0] - shape = (group_count, x.shape[0], x.shape[1]) - x_broadcasted = jax.lax.broadcast_in_dim(x, shape, (1, 2)) - iota = jax.lax.broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = jax.lax.cumsum(group_sizes) - group_starts = concatenate( - [ - np.zeros_like([group_ends[0]], dtype=group_sizes.dtype), - group_ends[:-1], - ], - 0, - ) - group_ends = jax.lax.broadcast_in_dim(group_ends, shape, (0,)) - group_starts = jax.lax.broadcast_in_dim(group_starts, shape, (0,)) - mask = (group_starts <= iota) & (iota < group_ends) - return jax.numpy.where(mask, x_broadcasted, 0) - if ad.is_undefined_primal(y): grad_x = None else: - y_t = jax.numpy.matrix_transpose(y) + y_t = _matrix_transpose(y) grad_x = ragged_dot( ct, y_t, @@ -3101,10 +3569,11 @@ def ragged_to_dense(x, group_sizes): if ad.is_undefined_primal(x): grad_y = None else: - x_dense = ragged_to_dense(x, gs) - ct_dense = ragged_to_dense(ct, gs) + y = y.aval if ad.is_undefined_primal(y) else y + x_dense = _ragged_to_dense(x, y, group_sizes=gs) + ct_dense = _ragged_to_dense(ct, y, group_sizes=gs) dimension_numbers = (([1], [1]), ([0], [0])) - grad_y = jax.lax.dot_general( + grad_y = dot_general( x_dense, ct_dense, dimension_numbers, @@ -3131,17 +3600,7 @@ def _ragged_dot_impl( ) -> Array: if group_offset is not None: raise NotImplementedError("Unimplemented group_offset support.") - shape = (rhs.shape[0], lhs.shape[0], lhs.shape[1]) - lhs = broadcast_in_dim(lhs, shape, [1, 2]) - iota = broadcasted_iota(group_sizes.dtype, shape, 1) - group_ends = jax.lax.cumsum(group_sizes) - group_starts = concatenate( - [_zeros(group_sizes)[:1], group_ends[:-1]], dimension=0, - ) - group_ends = broadcast_in_dim(group_ends, shape, (0,)) - group_starts = 
broadcast_in_dim(group_starts, shape, (0,)) - mask = bitwise_and(group_starts <= iota, iota < group_ends) - lhs = select(mask, lhs, _zeros(lhs)) + lhs = _ragged_to_dense(lhs, rhs, group_sizes=group_sizes) return dot_general( lhs, rhs, @@ -3307,7 +3766,7 @@ def _broadcast_in_dim_jvp_rule(primals, tangents, *, shape, broadcast_dimensions y = broadcast_in_dim_p.bind(operand, *dyn_shape, shape=shape, broadcast_dimensions=broadcast_dimensions) if type(operand_dot) is ad_util.Zero: - y_dot = ad_util.Zero.from_value(y) + y_dot = ad_util.Zero.from_primal_value(y) else: y_dot = broadcast_in_dim_p.bind(operand_dot, *dyn_shape, shape=shape, broadcast_dimensions=broadcast_dimensions) @@ -3784,8 +4243,7 @@ def _transpose_batch_rule(batched_args, batch_dims, *, permutation): def _transpose_lower(ctx, x, *, permutation): aval_out, = ctx.avals_out if dtypes.issubdtype(aval_out.dtype, dtypes.extended): - elt_shape = aval_out.dtype._rules.physical_element_aval( - aval_out.dtype).shape + elt_shape = core.physical_element_aval(aval_out.dtype).shape trailing_dims = [aval_out.ndim + i for i in range(len(elt_shape))] permutation = [*permutation, *trailing_dims] return [hlo.transpose(x, mlir.dense_int_array(permutation))] @@ -4343,7 +4801,7 @@ def _canonicalize_float_for_sort(x): # and NaNs in the output. result = select(eq(x, _zero(x)), _zeros(x), x) - with jax.debug_nans(False): + with config.debug_nans(False): result = select(_isnan(x), full_like(result, np.nan), result) return result @@ -4477,7 +4935,7 @@ def _top_k_jvp(primals, tangents, *, k): tangent, = tangents primals_out = top_k(operand, k) if type(tangent) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(primals_out[0]) + tangent_out = ad_util.Zero.from_primal_value(primals_out[0]) else: _, k_idxs = primals_out idx_shape = k_idxs.shape @@ -4496,7 +4954,7 @@ def _top_k_jvp(primals, tangents, *, k): collapsed_slice_dims=tuple(range(rank)), start_index_map=tuple(range(rank))) tangent_out = slicing.gather(tangent, gather_indices, dnums, slice_sizes) - return primals_out, (tangent_out, ad_util.Zero.from_value(primals_out[1])) + return primals_out, (tangent_out, ad_util.Zero.from_primal_value(primals_out[1])) def _top_k_batch_rule(batched_args, batch_dims, *, k): operand, = batched_args @@ -4532,7 +4990,7 @@ def _top_k_lower(ctx, operand, k): def _stop_gradient_jvp_rule(primals, tangents): # if we don't call stop_gradient here, we'd only peel off one autodiff tracer x, = primals - return stop_gradient(x), ad_util.Zero.from_value(x) + return stop_gradient(x), ad_util.Zero.from_primal_value(x) def _stop_gradient_batch_rule(batched_args, batch_dims): x, = batched_args @@ -4854,11 +5312,11 @@ def _copy_impl_pmap_sharding(sharded_dim, *args, **kwargs): return tree_util.tree_unflatten(p.out_tree(), out_flat) -# TODO(https://github.com/google/jax/issues/13552): Look into making this a +# TODO(https://github.com/jax-ml/jax/issues/13552): Look into making this a # method on jax.Array so that we can bypass the XLA compilation here. 
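For context on the ``_ragged_to_dense`` helper factored out above: ``ragged_dot`` performs one matmul per group, with ``group_sizes`` partitioning the rows of ``lhs``. A small sketch of that semantics, using arbitrary example shapes and values and a loose tolerance to accommodate backend-default matmul precision:

    import jax
    import jax.numpy as jnp
    import numpy as np

    m, k, n, g = 10, 4, 3, 2
    lhs = jnp.arange(m * k, dtype=jnp.float32).reshape(m, k)
    rhs = jnp.arange(g * k * n, dtype=jnp.float32).reshape(g, k, n)
    group_sizes = jnp.array([6, 4], dtype=jnp.int32)  # must sum to m

    out = jax.lax.ragged_dot(lhs, rhs, group_sizes)   # shape (m, n)

    # Reference semantics: the rows of group i are multiplied by rhs[i].
    ref = np.concatenate([np.asarray(lhs[:6]) @ np.asarray(rhs[0]),
                          np.asarray(lhs[6:]) @ np.asarray(rhs[1])])
    np.testing.assert_allclose(np.asarray(out), ref, rtol=1e-2)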
def _copy_impl(prim, *args, **kwargs): a, = args - if isinstance(a, jax.Array) and isinstance(a.sharding, PmapSharding): + if isinstance(a, Array) and isinstance(a.sharding, PmapSharding): sharded_dim = _which_dim_sharded(a.sharding) if sharded_dim is None: return dispatch.apply_primitive(prim, *args, **kwargs) @@ -5027,13 +5485,15 @@ def padtype_to_pads(in_shape, window_shape, window_strides, padding): for d in (out_shape - 1) * window_strides + window_shape - in_shape) if padding == PaddingType.SAME: - return [ + pads = [ (pad_size // 2, pad_size - pad_size // 2) for pad_size in pad_sizes ] else: - return [ + pads = [ (pad_size - pad_size // 2, pad_size // 2) for pad_size in pad_sizes ] + # Avoids verbose numpy scalars in jaxprs. + return [p.item() if isinstance(p, np.generic) else p for p in pads] elif padding == PaddingType.VALID: return [(0, 0)] * len(in_shape) else: @@ -5217,6 +5677,29 @@ def canonicalize_precision(precision: PrecisionLike) -> tuple[Precision, Precisi "a lax.Precision value or a tuple of two lax.Precision values or " f"strings; got {precision}.") +def canonicalize_dot_algorithm(algorithm: DotAlgorithmLike) -> _DotAlgorithmLike: + if isinstance(algorithm, str): + algorithm = DotAlgorithm.Preset[algorithm] + if algorithm is None or algorithm == DotAlgorithm.Preset.DEFAULT: + return None + return algorithm + +def canonicalize_dot_transpose_algorithm( + algorithm: DotTransposeAlgorithmLike) -> DotTransposeAlgorithm | None: + if algorithm is None: + return None + elif isinstance(algorithm, DotAlgorithm): + return (algorithm, algorithm) + elif isinstance(algorithm, tuple): + if len(algorithm) != 2: + raise ValueError( + "The transpose_algorithm argument must be a single value or a tuple " + f"of two values; got {algorithm}.") + return (canonicalize_dot_algorithm(algorithm[0]), + canonicalize_dot_algorithm(algorithm[1])) + algorithm = canonicalize_dot_algorithm(algorithm) + return (algorithm, algorithm) + def _balanced_eq(x, z, y): return div(select(_eq_meet(x, z), _ones(z), _zeros(z)), select(_eq_meet(y, z), _twos(z), _ones(z))) @@ -5260,6 +5743,8 @@ def _empty_lower(ctx, *, dtype): class BIntRules: + allow_conversion: bool = True + @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), np.dtype('int32')) @@ -5286,13 +5771,61 @@ def handler(bufs): return core.DArray(aval, phys_handler(bufs)) return handler - @staticmethod - def convert_from(bint_dtype, other_dtype) -> bool: - return other_dtype in (np.dtype('int32'), np.dtype('int64')) - @staticmethod - def convert_to(other_dtype, bint_dtype) -> bool: - return other_dtype in (np.dtype('int32'), np.dtype('int64')) +core.bint._rules = BIntRules -core.bint._rules = BIntRules +def optimization_barrier(operand, /): + """Prevents the compiler from moving operations across the barrier. + + Optimization barriers have a number of possible uses: + + * An optimization barrier ensures that all inputs are evaluated before any + operators that depend on the barrier's outputs. This can be used to enforce + a particular order of operations. + * An optimization barrier prevents common subexpression elimination. This is + used by JAX to implement rematerialization. + * Optimization barriers prevent compiler fusions. That is, operations before + the barrier may not be fused into the same kernel as operations after the + barrier by the compiler. + + JAX does not define derivative or batching rules for an optimization barrier. + + Optimization barriers have no effect outside a compiled function. 
+ + Args: + operand: a pytree of JAX values. + + Returns: + A pytree of JAX values, with the same structure and contents as ``operand``. + + Examples: + Prevents common-subexpression elimination between the two calls to `sin`: + + >>> def f(x): + ... return jax.lax.optimization_barrier(jax.lax.sin(x)) + jax.lax.sin(x) + >>> jax.jit(f)(0.) + Array(0., dtype=float32, weak_type=True) + """ + flat_args, treedef = tree_util.tree_flatten(operand) + return tree_util.tree_unflatten( + treedef, optimization_barrier_p.bind(*flat_args)) + + +def _optimization_barrier_abstract_eval(*args): + return args + +def _optimization_barrier_lowering_rule(ctx, *args): + barrier_types = map(mlir.aval_to_ir_type, ctx.avals_in) + flat_args = mlir.flatten_ir_values(args) + barrier_op = hlo.OptimizationBarrierOp(flat_args) + return mlir.unflatten_ir_values_like_types(barrier_op.results, barrier_types) + + +optimization_barrier_p = core.Primitive('optimization_barrier') +optimization_barrier_p.multiple_results = True +optimization_barrier_p.def_impl( + partial(dispatch.apply_primitive, optimization_barrier_p)) +optimization_barrier_p.def_abstract_eval(_optimization_barrier_abstract_eval) +mlir.register_lowering(optimization_barrier_p, + _optimization_barrier_lowering_rule) diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 3bd7e37e54ca..ef6a5a11a56e 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -22,15 +22,18 @@ import numpy as np -import jax from jax import lax from jax._src import ad_util from jax._src import api +from jax._src import config +from jax._src import core from jax._src import dispatch from jax._src import dtypes +from jax._src import util from jax._src.core import ( Primitive, ShapedArray, raise_to_shaped, is_constant_dim, is_constant_shape) +from jax._src.extend import ffi from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -41,25 +44,50 @@ from jax._src.lax.lax import ( standard_primitive, standard_unop, naryop_dtype_rule, _float, _complex, _input_dtype) -from jax._src.lib import gpu_linalg from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack from jax._src.lib import version as jaxlib_version -from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy import lax_numpy as jnp -from jax._src.numpy import reductions -from jax._src.numpy import ufuncs -from jax._src.numpy.vectorize import vectorize from jax._src.typing import Array, ArrayLike -xops = xla_client.ops +# The following import is unused but needed to register the custom_call targets +# in the gpu_linalg module. 
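A second, slightly larger sketch of ``optimization_barrier`` to complement the docstring example above: the argument can be an arbitrary pytree, and the barrier is typically used inside ``jit`` to pin an evaluation order. The computed values are unchanged; only the compiler's freedom to reorder or fuse is restricted:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def f(x, y):
      # Compute both intermediates before either is consumed below; without the
      # barrier the compiler may interleave or fuse them with the final product.
      a = jnp.sin(x)
      b = jnp.cos(y)
      a, b = jax.lax.optimization_barrier((a, b))
      return a * b

    out = f(jnp.float32(1.0), jnp.float32(2.0))  # same value as sin(1.0) * cos(2.0)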
+from jax._src.lib import gpu_linalg # noqa: F401 TFun = TypeVar('TFun', bound=Callable[..., Any]) +def _broadcasted_iotas(*sizes): + ones = (1,) * (len(sizes) - 1) + shapes = (util.tuple_insert(ones, i, s) for i, s in enumerate(sizes)) + return [lax.broadcasted_iota('int32', shape, i) for i, shape in enumerate(shapes)] + +def _tril(m: Array, k:int = 0) -> Array: + *_, N, M = m.shape + mask = lax_internal._tri(bool, (N, M), k) + return lax.select(lax.broadcast(mask, m.shape[:-2]), m, lax.zeros_like_array(m)) + +def _triu(m: Array, k:int = 0) -> Array: + *_, N, M = m.shape + mask = lax_internal._tri(bool, (N, M), k - 1) + return lax.select(lax.broadcast(mask, m.shape[:-2]), lax.zeros_like_array(m), m) + +def _construct_diagonal(s: Array) -> Array: + """Construct a (batched) diagonal matrix""" + i = lax.iota('int32', s.shape[-1]) + return lax.full((*s.shape, s.shape[-1]), 0, s.dtype).at[..., i, i].set(s) + +def _extract_diagonal(s: Array) -> Array: + """Extract the diagonal from a batched matrix""" + i = lax.iota('int32', min(s.shape[-2], s.shape[-1])) + return s[..., i, i] + +def _broadcast_to(x: Array, shape: tuple[int, ...]) -> Array: + assert x.ndim <= len(shape) + return lax.broadcast_in_dim(x, shape, range(len(shape) - x.ndim, len(shape))) + # traceables def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: @@ -89,7 +117,7 @@ def cholesky(x: Array, *, symmetrize_input: bool = True) -> Array: """ if symmetrize_input: x = symmetrize(x) - return jnp.tril(cholesky_p.bind(x)) + return _tril(cholesky_p.bind(x)) def eig(x: ArrayLike, *, compute_left_eigenvectors: bool = True, @@ -184,6 +212,20 @@ def cholesky_update(r_matrix: ArrayLike, w_vector: ArrayLike) -> Array: return cholesky_update_p.bind(r_matrix, w_vector) +def symmetric_product( + a_matrix: ArrayLike, c_matrix: ArrayLike, + alpha: float = 1., beta: float = 0., + symmetrize_output=False): + """Computes C = alpha * A @ A.T + beta * C (where C is symmetric).""" + result = symmetric_product_p.bind(a_matrix, c_matrix, alpha=alpha, beta=beta) + if symmetrize_output: + upper_half = lax.transpose( + _tril(result, k=-1), + (*range(result.ndim - 2), result.ndim - 1, result.ndim - 2)) + result = _tril(result, k=0) + upper_half + return result + + def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array: """Converts the pivots (row swaps) returned by LU to a permutation. @@ -198,7 +240,7 @@ def lu_pivots_to_permutation(pivots: ArrayLike, permutation_size: int) -> Array: An int32 array of shape (..., permutation_size). """ permutation = lu_pivots_to_permutation_p.bind( - pivots, permutation_size=int(permutation_size)) + pivots, permutation_size=permutation_size) return permutation @@ -366,10 +408,10 @@ def triangular_solve(a: ArrayLike, b: ArrayLike, *, Returns: A batch of matrices the same shape and dtype as ``b``. 
""" - conjugate_a = conjugate_a and jnp.issubdtype(lax.dtype(a), jnp.complexfloating) - singleton = jnp.ndim(b) == jnp.ndim(a) - 1 + conjugate_a = conjugate_a and dtypes.issubdtype(lax.dtype(a), np.complexfloating) + singleton = np.ndim(b) == np.ndim(a) - 1 if singleton: - b = jnp.expand_dims(b, -1 if left_side else -2) + b = lax.expand_dims(b, (-1 if left_side else -2,)) out = triangular_solve_p.bind( a, b, left_side=left_side, lower=lower, transpose_a=transpose_a, conjugate_a=conjugate_a, unit_diagonal=unit_diagonal) @@ -379,9 +421,17 @@ def triangular_solve(a: ArrayLike, b: ArrayLike, *, # utilities -@partial(vectorize, signature='(n,m),(m)->(n)') -def _matvec_multiply(a: Array, b: Array) -> Array: - return lax.dot(a, b, precision=lax.Precision.HIGHEST) +def _broadcasted_matvec(a: Array, b: Array) -> Array: + # This is a broadcasted dot_general with signature (...,n,m),(...,m)->(...,n) + assert a.ndim >= 2 + assert b.ndim >= 1 + batch_shape = lax.broadcast_shapes(a.shape[:-2], b.shape[:-1]) + n_batch = len(batch_shape) + a = _broadcast_to(a, (*batch_shape, *a.shape[-2:])) + b = _broadcast_to(b, (*batch_shape, b.shape[-1])) + + dimension_numbers = (([a.ndim - 1], [b.ndim - 1]), (list(range(n_batch)), list(range(n_batch)))) + return lax.dot_general(a, b, dimension_numbers=dimension_numbers, precision=lax.Precision.HIGHEST) def _check_solve_shapes(a: Array, b: Array): if not (a.ndim >= 2 and b.ndim in [a.ndim, a.ndim - 1] and @@ -397,14 +447,14 @@ def _solve(a: Array, b: Array) -> Array: # custom_linear_solve. out_shape = tuple(d_a if d_b == 1 else d_b for d_a, d_b in zip(a.shape[:-1] + (1,), b.shape)) - b = jnp.broadcast_to(b, out_shape) + b = lax.broadcast_in_dim(b, out_shape, range(b.ndim)) # With custom_linear_solve, we can reuse the same factorization when # computing sensitivities. This is considerably faster. lu_, _, permutation = lu(lax.stop_gradient(a)) custom_solve = partial( lax.custom_linear_solve, - lambda x: _matvec_multiply(a, x), + lambda x: _broadcasted_matvec(a, x), solve=lambda _, x: lu_solve(lu_, permutation, x, trans=0), transpose_solve=lambda _, x: lu_solve(lu_, permutation, x, trans=1)) if a.ndim == b.ndim + 1: @@ -414,8 +464,10 @@ def _solve(a: Array, b: Array) -> Array: # b.shape == [..., m, k] return api.vmap(custom_solve, b.ndim - 1, max(a.ndim, b.ndim) - 1)(b) -def _T(x: Array) -> Array: return jnp.swapaxes(x, -1, -2) -def _H(x: Array) -> Array: return ufuncs.conj(_T(x)) +def _T(x: Array) -> Array: + return lax.transpose(x, (*range(x.ndim - 2), x.ndim - 1, x.ndim - 2)) +def _H(x: Array) -> Array: + return _T(x).conj() def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 # primitives @@ -428,13 +480,13 @@ def symmetrize(x: Array) -> Array: return (x + _H(x)) / 2 def _cholesky_jvp_rule(primals, tangents): x, = primals sigma_dot, = tangents - L = jnp.tril(cholesky_p.bind(x)) + L = _tril(cholesky_p.bind(x)) # Forward-mode rule from https://arxiv.org/pdf/1602.07527.pdf def phi(X): - l = jnp.tril(X) + l = _tril(X) return l / lax.expand_dims( - lax_internal._const(X, 1) + jnp.eye(X.shape[-1], dtype=X.dtype), + lax_internal._const(X, 1) + lax_internal._eye(X.dtype, (X.shape[-1], X.shape[-1])), range(l.ndim - 2)) tmp = triangular_solve(L, sigma_dot, left_side=False, transpose_a=True, @@ -464,11 +516,7 @@ def _cholesky_cpu_lowering(ctx, operand): out_aval, = ctx.avals_out batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the check after the compatibility period. 
- if jaxlib_version < (0, 4, 31): - ctx_arg = () - else: - ctx_arg = (ctx,) + ctx_arg = (ctx,) result, info = lapack.potrf_hlo(*ctx_arg, operand_aval.dtype, operand, lower=True, a_shape_vals=op_shape_vals) @@ -504,42 +552,33 @@ def _cholesky_update_abstract_eval(r_matrix, w_vector): r_matrix.shape, w_vector.shape)) return ShapedArray(r_matrix.shape, r_matrix.dtype) -def _cholesky_update_cuda_lowering_rule(ctx, r_matrix, w_vector): - r_matrix_aval, _ = ctx.avals_in - try: - [platform] = ctx.module_context.platforms - except ValueError: - raise ValueError( - "Can only lower cholesky_update on a single platform." - ) from None - if platform != "cuda": - raise NotImplementedError( - "Can only lower fast cholesky_update on CUDA." - ) - return gpu_linalg.cuda_cholesky_update( - r_matrix, w_vector, r_matrix_aval.dtype) +def _cholesky_update_gpu_lowering_rule(target_name_prefix, ctx, r_matrix, w_vector): + rule = ffi.ffi_lowering(f"{target_name_prefix}_cholesky_update_ffi", + operand_output_aliases={0: 0, 1: 1}) + sub_ctx = ctx.replace(avals_out=ctx.avals_in) + return rule(sub_ctx, r_matrix, w_vector)[:1] def _cholesky_update_jax_fn(R, z): def _drotg(x, y): """Get coefs for Givens rotation in a numerically stable way.""" def _drotg_nonzero(x, y): - abs_x = jax.numpy.abs(x) - abs_y = jax.numpy.abs(y) - denominator = jnp.where(abs_x > abs_y, abs_x, abs_y) + abs_x = abs(x) + abs_y = abs(y) + denominator = lax.select(abs_x > abs_y, abs_x, abs_y) x /= denominator y /= denominator - rh = 1 / jax.numpy.sqrt(x ** 2 + y ** 2) + rh = 1 / lax.sqrt(x ** 2 + y ** 2) return x * rh, -y * rh one_and_zero = ( - jnp.array(1., dtype=x.dtype), - jnp.array(0., dtype=x.dtype), + np.array(1., dtype=x.dtype), + np.array(0., dtype=x.dtype), ) - return jax.lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) + return lax.cond(y == 0, lambda x, y: one_and_zero, _drotg_nonzero, x, y) def _drot( - first_vector: jax.Array, second_vector: jax.Array, - c_coef: float, s_coef: float) -> tuple[jax.Array, jax.Array]: + first_vector: Array, second_vector: Array, + c_coef: float, s_coef: float) -> tuple[Array, Array]: return ( c_coef * first_vector - s_coef * second_vector, c_coef * second_vector + s_coef * first_vector) @@ -550,18 +589,81 @@ def _drot( R = R.at[k, :].set(row_k) return R + cholesky_update_p = Primitive('cholesky_update') cholesky_update_p.multiple_results = False cholesky_update_p.def_abstract_eval(_cholesky_update_abstract_eval) cholesky_update_p.def_impl(partial(dispatch.apply_primitive, cholesky_update_p)) mlir.register_lowering( - cholesky_update_p, _cholesky_update_cuda_lowering_rule, platform='cuda') - + cholesky_update_p, partial(_cholesky_update_gpu_lowering_rule, "cu"), + platform='cuda') mlir.register_lowering( cholesky_update_p, mlir.lower_fun(_cholesky_update_jax_fn, multiple_results=False)) +# symmetric_update + +def _symmetric_product_abstract_eval(a, c, *, alpha, beta): + a_dtype = dtypes.canonicalize_dtype(a.dtype) + c_dtype = dtypes.canonicalize_dtype(c.dtype) + if not (a_dtype == c_dtype and a_dtype in (np.float32, np.float64)): + raise NotImplementedError( + "Symmetric update is only implemented for float32 and float64.") + if not (a.ndim >= 2 and c.ndim >= 2 + and a.shape[-2] == c.shape[-1] + and c.shape[-1] == c.shape[-2]): + raise ValueError( + "Symmetric update takes (maybe batched) matrices of matching shapes. 
" + "Got shapes {}, {} instead".format(a.shape, c.shape)) + return ShapedArray(c.shape, c.dtype) + + +def _symmetric_product_batching_rule(batched_args, batch_dims, *, alpha, beta): + a_tensor, c_tensor = batched_args + a_bd, c_bd = batch_dims + a_tensor = batching.moveaxis(a_tensor, a_bd, 0) + c_tensor = batching.moveaxis(c_tensor, c_bd, 0) + return ( + symmetric_product_p.bind(a_tensor, c_tensor, alpha=alpha, beta=beta), 0) + +symmetric_product_p = Primitive('symmetric_update') +symmetric_product_p.multiple_results = False +symmetric_product_p.def_abstract_eval(_symmetric_product_abstract_eval) +symmetric_product_p.def_impl( + partial(dispatch.apply_primitive, symmetric_product_p)) +batching.primitive_batchers[ + symmetric_product_p] = _symmetric_product_batching_rule + + +def _symmetric_product_gpu_lowering( + platform, ctx, a_tensor, c_tensor, alpha, beta): + a_aval, c_aval = ctx.avals_in[:2] + dtype = a_aval.dtype + alpha_aval = beta_aval = ShapedArray((), dtype) + + alpha_array = mlir.full_like_aval(ctx, alpha, alpha_aval) + beta_array = mlir.full_like_aval(ctx, beta, beta_aval) + + rule = ffi.ffi_lowering(f"{platform}solver_syrk_ffi", + operand_output_aliases={1: 0}) + ctx = ctx.replace(avals_in=[a_aval, c_aval, alpha_aval, beta_aval]) + return rule(ctx, a_tensor, c_tensor, alpha_array, beta_array, transpose=False) + + +def _symmetric_product_jax_fn(a, c, *, alpha, beta): + a_T = lax.transpose(a, (*range(a.ndim - 2), a.ndim - 1, a.ndim - 2)) + return alpha * lax.batch_matmul( + a, a_T, precision=lax.Precision.HIGHEST) + beta * c + + +mlir.register_lowering( + symmetric_product_p, + partial(_symmetric_product_gpu_lowering, 'cu'), platform='cuda') +mlir.register_lowering( + symmetric_product_p, + mlir.lower_fun(_symmetric_product_jax_fn, multiple_results=False)) + # Asymmetric eigendecomposition def eig_impl(operand, *, compute_left_eigenvectors, compute_right_eigenvectors): @@ -607,7 +709,8 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand, + ctx_args = (ctx,) + w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, jobvr=compute_right_eigenvectors) @@ -665,13 +768,13 @@ def eig_jvp_rule(primals, tangents, *, compute_left_eigenvectors, raise NotImplementedError( 'The derivatives of eigenvectors are not implemented, only ' 'eigenvalues. See ' - 'https://github.com/google/jax/issues/2748 for discussion.') + 'https://github.com/jax-ml/jax/issues/2748 for discussion.') # Formula for derivative of eigenvalues w.r.t. a is eqn 4.60 in # https://arxiv.org/abs/1701.00392 a, = primals da, = tangents l, v = eig(a, compute_left_eigenvectors=False) - return [l], [reductions.sum(_solve(v, da.astype(v.dtype)) * _T(v), -1)] + return [l], [(_solve(v, da.astype(v.dtype)) * _T(v)).sum(-1)] eig_p = Primitive('eig') eig_p.multiple_results = True @@ -793,7 +896,8 @@ def _eigh_abstract_eval(operand, *, lower, sort_eigenvalues, subset_by_index): def _eigh_cpu_gpu_lowering( - syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index + syevd_impl, ctx, operand, *, lower, sort_eigenvalues, subset_by_index, + platform=None ): del sort_eigenvalues # The CPU/GPU implementations always sort. 
operand_aval, = ctx.avals_in @@ -813,7 +917,11 @@ def _eigh_cpu_gpu_lowering( raise NotImplementedError("subset_by_index not implemented for CPU and GPU") op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - v, w, info = syevd_impl(operand_aval.dtype, operand, + cpu_args = [] + if platform == "cpu": + ctx_args = (ctx,) + cpu_args.extend(ctx_args) + v, w, info = syevd_impl(*cpu_args, operand_aval.dtype, operand, a_shape_vals=op_shape_vals, lower=lower) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) @@ -859,16 +967,17 @@ def eigh_qdwh(x): # We should only look at elements from the lower/upper triangle. Reflects # that triangle into the other triangle to form a Hermitian matrix. if lower: - mask = jnp.tri(n, k=0, dtype=bool) + mask = lax_internal._tri(bool, (n, n), 0) else: - mask = ufuncs.logical_not(jnp.tri(n, k=-1, dtype=bool)) - if dtypes.issubdtype(x.dtype, jnp.complexfloating): + mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), -1)) + if dtypes.issubdtype(x.dtype, np.complexfloating): re = lax.select(mask, lax.real(x), _T(lax.real(x))) if lower: - im_mask = jnp.tri(n, k=-1, dtype=bool) + im_mask = lax_internal._tri(bool, (n, n), -1) else: - im_mask = ufuncs.logical_not(jnp.tri(n, k=0, dtype=bool)) - im = lax.select(im_mask, lax.imag(x), jnp.zeros_like(lax.imag(x))) + im_mask = lax.bitwise_not(lax_internal._tri(bool, (n, n), 0)) + im = lax.imag(x) + im = lax.select(im_mask, im, lax.full_like(im, 0)) im = lax.select(mask, im, -_T(im)) x = lax.complex(re, im) else: @@ -913,15 +1022,15 @@ def _eigh_jvp_rule( # for complex numbers we need eigenvalues to be full dtype of v, a: w = w_real.astype(a.dtype) - eye_n = jnp.eye(n, dtype=a.dtype) + eye_n = lax_internal._eye(a.dtype, (n, n)) # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs. - Fmat = ufuncs.reciprocal(eye_n + w[..., jnp.newaxis, :] - w[..., jnp.newaxis]) - eye_n + Fmat = lax.integer_pow(eye_n + w[..., np.newaxis, :] - w[..., np.newaxis], -1) - eye_n # eigh impl doesn't support batch dims, but future-proof the grad. 
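The ``Fmat`` construction above implements the standard first-order perturbation formulas for a Hermitian eigendecomposition ``A = V diag(w) V^H``. A small numerical sketch restating the formula with public APIs and checking it against ``jax.jvp``; the example matrix is chosen to have well-separated eigenvalues:

    import jax
    import jax.numpy as jnp
    import numpy as np

    a = jnp.array([[2.0, 1.0], [1.0, 3.0]], dtype=jnp.float32)
    da = jnp.array([[0.1, 0.2], [0.2, -0.3]], dtype=jnp.float32)

    w, v = jnp.linalg.eigh(a)  # eigenvalues w, eigenvectors v (columns)

    # dW = diag(V^H dA V).real and dV = V (F * (V^H dA V)), where
    # F[i, j] = 1 / (w[j] - w[i]) off the diagonal and 0 on it.
    vdav = v.T.conj() @ da @ v
    eye = jnp.eye(2, dtype=a.dtype)
    f = 1.0 / (eye + w[None, :] - w[:, None]) - eye
    dw = jnp.real(jnp.diagonal(vdav))
    dv = v @ (f * vdav)

    (w2, v2), (dw2, dv2) = jax.jvp(jnp.linalg.eigh, (a,), (da,))
    np.testing.assert_allclose(dw, dw2, rtol=1e-4, atol=1e-5)
    np.testing.assert_allclose(dv, dv2, rtol=1e-4, atol=1e-5)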
dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul, precision=lax.Precision.HIGHEST) vdag_adot_v = dot(dot(_H(v), a_dot), v) - dv = dot(v, ufuncs.multiply(Fmat, vdag_adot_v)) - dw = ufuncs.real(jnp.diagonal(vdag_adot_v, axis1=-2, axis2=-1)) + dv = dot(v, Fmat * vdag_adot_v) + dw = _extract_diagonal(vdag_adot_v.real) return (v, w_real), (dv, dw) @@ -947,15 +1056,17 @@ def _eigh_batching_rule( batching.primitive_batchers[eigh_p] = _eigh_batching_rule mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo), + eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo, platform='cpu'), platform='cpu') if gpu_solver is not None: mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd), + eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.cuda_syevd, + platform='cuda'), platform='cuda') mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd), + eigh_p, partial(_eigh_cpu_gpu_lowering, gpu_solver.rocm_syevd, + platform='rocm'), platform='rocm') mlir.register_lowering( @@ -993,10 +1104,10 @@ def _triangular_solve_jvp_rule_a( unit_diagonal): m, n = b.shape[-2:] k = 1 if unit_diagonal else 0 - g_a = jnp.tril(g_a, k=-k) if lower else jnp.triu(g_a, k=k) + g_a = _tril(g_a, k=-k) if lower else _triu(g_a, k=k) g_a = lax.neg(g_a) - g_a = jnp.swapaxes(g_a, -1, -2) if transpose_a else g_a - g_a = ufuncs.conj(g_a) if conjugate_a else g_a + g_a = _T(g_a) if transpose_a else g_a + g_a = g_a.conj() if conjugate_a else g_a dot = partial(lax.dot if g_a.ndim == 2 else lax.batch_matmul, precision=lax.Precision.HIGHEST) @@ -1131,11 +1242,11 @@ def _lu_pivots_body_fn(i, permutation_and_swaps): permutation, swaps = permutation_and_swaps batch_dims = swaps.shape[:-1] j = swaps[..., i] - iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims)) + iotas = _broadcasted_iotas(*batch_dims) x = permutation[..., i] - y = permutation[iotas + (j,)] + y = permutation[(*iotas, j)] permutation = permutation.at[..., i].set(y) - return permutation.at[iotas + (j,)].set(x), swaps + return permutation.at[(*iotas, j)].set(x), swaps def _generic_lu_pivots_to_permutation(swaps, permutation_size): @@ -1155,12 +1266,13 @@ def _generic_lu_pivots_to_permutation(swaps, permutation_size): k = swaps.shape[-1] m = permutation_size - permutation = lax.broadcasted_iota(jnp.int32, batch_dims + (m,), + permutation = lax.broadcasted_iota(np.int32, batch_dims + (m,), len(batch_dims)) - if m == 0: + if m == 0 or k == 0: return permutation - result, _ = lax.fori_loop(np.array(0, np.int32), np.array(k, np.int32), - _lu_pivots_body_fn, (permutation, swaps)) + upper = np.array(k, np.int32) if is_constant_dim(k) else k + result, _ = lax.fori_loop(np.array(0, np.int32), upper, _lu_pivots_body_fn, + (permutation, swaps)) return result @@ -1171,18 +1283,14 @@ def _lu_pivots_to_permutation_abstract_eval(pivots, *, permutation_size): raise ValueError( 'Argument to lu_pivots_to_permutation must have rank >= 1 and dtype ' 'int32. Got shape={} and dtype={}'.format(pivots.shape, pivots.dtype)) - - if permutation_size < pivots.shape[-1]: + pivots_size = pivots.shape[-1] + if not permutation_size >= pivots_size: raise ValueError( 'Output permutation size {} has to exceed the trailing dimension of ' - 'the pivots. Got shape {}'.format(permutation_size, pivots.shape)) - - batch_dims = pivots.shape[:-1] - permutations = pivots.update(shape=batch_dims + (permutation_size,)) + 'the pivots. 
Got pivots size {}'.format(permutation_size, pivots_size)) + return pivots.update(shape=(*pivots.shape[:-1], permutation_size)) else: - permutations = pivots - - return permutations + return pivots def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, @@ -1193,9 +1301,15 @@ def _lu_pivots_to_permutation_batching_rule(batched_args, batch_dims, *, return lu_pivots_to_permutation_p.bind( x, permutation_size=permutation_size), 0 -def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *, +def _lu_pivots_to_permutation_gpu_lowering(platform, ctx, pivots, *, permutation_size): - return lowering(pivots, permutation_size=permutation_size) + rule = ffi.ffi_lowering(f"{platform}_lu_pivots_to_permutation") + # TODO(b/358275922): remove unused once jaxlib v0.4.32 is the minimum version. + if ctx.is_forward_compat() or jaxlib_version < (0, 4, 32): + kwargs = dict(permutation_size=np.int32(permutation_size)) + else: + kwargs = {} + return rule(ctx, pivots, **kwargs) lu_pivots_to_permutation_p = Primitive('lu_pivots_to_permutation') @@ -1211,13 +1325,11 @@ def _lu_pivots_to_permutation_gpu_lowering(lowering, ctx, pivots, *, mlir.lower_fun(_generic_lu_pivots_to_permutation, multiple_results=False)) mlir.register_lowering( lu_pivots_to_permutation_p, - partial(_lu_pivots_to_permutation_gpu_lowering, - gpu_linalg.cuda_lu_pivots_to_permutation), + partial(_lu_pivots_to_permutation_gpu_lowering, "cu"), platform='cuda') mlir.register_lowering( lu_pivots_to_permutation_p, - partial(_lu_pivots_to_permutation_gpu_lowering, - gpu_linalg.hip_lu_pivots_to_permutation), + partial(_lu_pivots_to_permutation_gpu_lowering, "hip"), platform='rocm') # LU decomposition @@ -1231,30 +1343,32 @@ def _lu_unblocked(a): m, n = a.shape def body(k, state): pivot, perm, a = state - m_idx = jnp.arange(m) - n_idx = jnp.arange(n) + m_idx = lax.iota('int32', m) + n_idx = lax.iota('int32', n) - if jnp.issubdtype(a.dtype, jnp.complexfloating): + if dtypes.issubdtype(a.dtype, np.complexfloating): t = a[:, k] - magnitude = ufuncs.abs(ufuncs.real(t)) + ufuncs.abs(ufuncs.imag(t)) + magnitude = abs(t.real) + abs(t.imag) else: - magnitude = ufuncs.abs(a[:, k]) - i = jnp.argmax(jnp.where(m_idx >= k, magnitude, -jnp.inf)) - pivot = pivot.at[k].set(i.astype(pivot.dtype)) + magnitude = abs(a[:, k]) + i = lax.argmax(lax.select(m_idx >= k, magnitude, lax.full_like(magnitude, -np.inf)), + axis=0, index_dtype=pivot.dtype) + pivot = pivot.at[k].set(i) a = a.at[[k, i],].set(a[[i, k],]) perm = perm.at[[i, k],].set(perm[[k, i],]) # a[k+1:, k] /= a[k, k], adapted for loop-invariant shapes x = a[k, k] - a = a.at[:, k].set(jnp.where((m_idx > k) & (x != 0), a[:, k] / x, a[:, k])) + a = a.at[:, k].set(lax.select((m_idx > k) & (x != 0), a[:, k] / x, a[:, k])) # a[k+1:, k+1:] -= jnp.outer(a[k+1:, k], a[k, k+1:]) - a = a - jnp.where((m_idx[:, None] > k) & (n_idx[None, :] > k), - jnp.outer(a[:, k], a[k, :]), jnp.array(0, dtype=a.dtype)) + a_outer = a[:, k, None] * a[k, None] + a = a - lax.select((m_idx[:, None] > k) & (n_idx[None, :] > k), + a_outer, lax_internal._zeros(a_outer)) return pivot, perm, a - pivot = jnp.zeros((min(m, n),), dtype=jnp.int32) - perm = jnp.arange(m, dtype=jnp.int32) + pivot = lax.full((min(m, n),), 0, dtype=np.int32) + perm = lax.iota('int32', m) if m == 0 and n == 0: # If the array is empty, the loop body never executes but tracing it to a # jaxpr fails because the indexing cannot succeed. 
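The generic fallback above turns LAPACK-style pivot swaps into an explicit permutation; a tiny NumPy restatement of the same sequential-swap semantics (the values are hypothetical, and the real implementation runs this as a batched ``fori_loop`` with a GPU FFI fast path):

    import numpy as np

    def pivots_to_permutation(pivots, permutation_size):
      # pivots[i] is the row swapped with row i during LU; replay the swaps.
      perm = np.arange(permutation_size, dtype=np.int32)
      for i, j in enumerate(pivots):
        perm[i], perm[j] = perm[j], perm[i]
      return perm

    print(pivots_to_permutation(np.array([2, 2, 3, 3]), 4))  # -> [2 0 3 1]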
@@ -1266,8 +1380,8 @@ def _lu_blocked(a, block_size=128): """Blocked LU decomposition, as an unrolled loop.""" m, n = a.shape r = min(m, n) - pivot = jnp.zeros((r,), dtype=jnp.int32) - perm = jnp.arange(m, dtype=jnp.int32) + pivot = lax.full((r,), 0, dtype=np.int32) + perm = lax.iota('int32', m) for k in range(0, r, block_size): b = min(r - k, block_size) block_pivot, block_perm, lu_block = _lu_unblocked(a[k:, k:k+b]) @@ -1307,8 +1421,9 @@ def _lu_abstract_eval(operand): batch_dims = operand.shape[:-2] m = operand.shape[-2] n = operand.shape[-1] - pivot = operand.update(shape=batch_dims + (min(m, n),), dtype=jnp.int32) - perm = operand.update(shape=batch_dims + (m,), dtype=jnp.int32) + pivot = operand.update(shape=batch_dims + (core.min_dim(m, n),), + dtype=np.int32) + perm = operand.update(shape=batch_dims + (m,), dtype=np.int32) else: pivot = operand perm = operand @@ -1319,14 +1434,14 @@ def _lu_jvp_rule(primals, tangents): a_dot, = tangents lu, pivots, permutation = lu_p.bind(a) - a_shape = jnp.shape(a) + a_shape = np.shape(a) m, n = a_shape[-2:] dtype = lax.dtype(a) k = min(m, n) batch_dims = a_shape[:-2] - iotas = jnp.ix_(*(lax.iota(jnp.int32, b) for b in batch_dims + (1,))) - x = a_dot[iotas[:-1] + (permutation, slice(None))] + iotas = _broadcasted_iotas(*batch_dims, 1) + x = a_dot[(*iotas[:-1], permutation, slice(None))] # Differentiation of Matrix Functionals Using Triangular Factorization # F. R. De Hoog, R. S. Anderssen, and M. A. Lukas @@ -1341,14 +1456,13 @@ def _lu_jvp_rule(primals, tangents): l_padding = [(0, 0, 0)] * ndims l_padding[-1] = (0, m - k, 0) zero = lax_internal._const(lu, 0) - l = lax.pad(jnp.tril(lu[..., :, :k], -1), zero, l_padding) - l = l + lax.expand_dims(jnp.eye(m, m, dtype=dtype), range(l.ndim - 2)) - - u_eye = lax.pad(jnp.eye(n - k, n - k, dtype=dtype), zero, + l = lax.pad(_tril(lu[..., :, :k], -1), zero, l_padding) + l = l + lax.expand_dims(lax_internal._eye(dtype, (m, m)), range(l.ndim - 2)) + u_eye = lax.pad(lax_internal._eye(dtype, (n - k, n - k)), zero, ((k, 0, 0), (k, 0, 0))) u_padding = [(0, 0, 0)] * ndims u_padding[-2] = (0, n - k, 0) - u = (lax.pad(jnp.triu(lu[..., :k, :]), zero, u_padding) + + u = (lax.pad(_triu(lu[..., :k, :]), zero, u_padding) + lax.expand_dims(u_eye, range(lu.ndim - 2))) la = triangular_solve(l, x, left_side=True, transpose_a=False, lower=True, @@ -1356,11 +1470,12 @@ def _lu_jvp_rule(primals, tangents): lau = triangular_solve(u, la, left_side=False, transpose_a=False, lower=False) - l_dot = jnp.matmul(l, jnp.tril(lau, -1), precision=lax.Precision.HIGHEST) - u_dot = jnp.matmul(jnp.triu(lau), u, precision=lax.Precision.HIGHEST) + with config.default_matmul_precision("highest"): + l_dot = l @ _tril(lau, -1) + u_dot = _triu(lau) @ u lu_dot = l_dot + u_dot - return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_value(pivots), - ad_util.Zero.from_value(permutation)) + return (lu, pivots, permutation), (lu_dot, ad_util.Zero.from_primal_value(pivots), + ad_util.Zero.from_primal_value(permutation)) def _lu_batching_rule(batched_args, batch_dims): @@ -1369,39 +1484,51 @@ def _lu_batching_rule(batched_args, batch_dims): x = batching.moveaxis(x, bd, 0) return lu_p.bind(x), (0, 0, 0) -def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, - platform: str): +def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand, *, platform: str, + target_name_prefix: str): operand_aval, = ctx.avals_in - # It should be possible to support fully-dynamic shapes, but since - # the last two dimensions (m, n) are used in more involved ways, we only - # 
support dynamic dimensions for the batch size for now. - if not is_constant_shape(operand_aval.shape[-2:]): - raise NotImplementedError( - "Shape polymorphism for native lowering for lu on CPU and GPU is " - f"implemented only for the batch dimensions: {operand_aval.shape}") - - # TODO(b/357034884): Remove once jaxlib 0.4.32 is the minimum version. - ctx_arg = (ctx,) if jaxlib_version >= (0, 4, 32) else () - out_aval, pivot_aval, perm_aval = ctx.avals_out batch_dims = operand_aval.shape[:-2] + info_aval = ShapedArray(batch_dims, np.dtype(np.int32)) m = operand_aval.shape[-2] - if platform in ["cuda", "rocm"]: - # TODO(necula): remove the platform kwarg when we implement GPU support. - if not is_constant_shape(operand_aval.shape): + + # TODO(b/357034884): Remove version gate on the forward compat flag after the + # 3 week compatibility window. + if ctx.is_forward_compat(): + if not is_constant_shape(operand_aval.shape[-2:]): raise NotImplementedError( - "Shape polymorphism for native serialization for lu on GPU is not " - f"implemented; b/261671778; {operand_aval.shape}") - lu, pivot, info = getrf_impl(*ctx_arg, operand_aval.dtype, operand) + "Shape polymorphism for native lowering for lu on CPU and GPU is " + f"implemented only for the batch dimensions: {operand_aval.shape}") + if platform in ["cuda", "rocm"]: + if not is_constant_shape(operand_aval.shape): + raise NotImplementedError( + "Shape polymorphism for native serialization for lu on GPU is not " + f"implemented; b/261671778; {operand_aval.shape}") + lu, pivot, info = getrf_impl(operand_aval.dtype, operand) + else: + op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) + lu, pivot, info = getrf_impl( + operand_aval.dtype, operand, a_shape_vals=op_shape_vals) else: - op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - # TODO(b/344892332): Remove the conditional after the compatibility period. - lu, pivot, info = getrf_impl( - *ctx_arg, operand_aval.dtype, operand, a_shape_vals=op_shape_vals) + if target_name_prefix == "cpu": + target_name = lapack.prepare_lapack_call("getrf_ffi", operand_aval.dtype) + else: + target_name = f"{target_name_prefix}solver_getrf_ffi" + # We manually construct the layouts because the input and output are + # expected to be in Fortran order. + nb = len(batch_dims) + layout = (nb, nb + 1) + tuple(range(nb - 1, -1, -1)) + result_layouts = [layout, tuple(range(nb, -1, -1)), + tuple(range(nb - 1, -1, -1))] + rule = ffi.ffi_lowering(target_name, operand_layouts=[layout], + result_layouts=result_layouts, + operand_output_aliases={0: 0}) + sub_ctx = ctx.replace(avals_out=[out_aval, pivot_aval, info_aval]) + lu, pivot, info = rule(sub_ctx, operand) + # Subtract 1 from the pivot to get 0-based indices. 
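The manually constructed layouts above are minor-to-major orderings that keep the matrix block in Fortran (column-major) order, as the comment notes the getrf FFI kernel expects, with the batch dimensions outermost. A standalone sketch of the tuple being built, assuming `nb` leading batch dimensions (illustration only):

```python
def fortran_batched_layout(nb: int) -> tuple[int, ...]:
    # Row dimension is minor-most (column-major matrix), batch dims are major.
    return (nb, nb + 1) + tuple(range(nb - 1, -1, -1))

print(fortran_batched_layout(0))  # (0, 1)
print(fortran_batched_layout(2))  # (2, 3, 1, 0)
```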
pivot = hlo.subtract(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)) - ok = mlir.compare_hlo( - info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), + ok = mlir.compare_hlo(info, mlir.full_like_aval(ctx, 0, info_aval), "GE", "SIGNED") select_lu_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) lu = _broadcasting_select_hlo( @@ -1446,25 +1573,24 @@ def _lu_tpu_lowering_rule(ctx, operand): mlir.register_lowering(lu_p, partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo, - platform='cpu'), + platform='cpu', target_name_prefix="cpu"), platform='cpu') mlir.register_lowering( lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf, - platform='cuda'), + platform='cuda', target_name_prefix="cu"), platform='cuda') mlir.register_lowering( lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.rocm_getrf, - platform='rocm'), + platform='rocm', target_name_prefix="hip"), platform='rocm') mlir.register_lowering(lu_p, _lu_tpu_lowering_rule, platform='tpu') -@partial(vectorize, excluded={3}, signature='(n,n),(n),(n,k)->(n,k)') def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array: m = lu.shape[0] - x = jnp.reshape(b, (m, math.prod(b.shape[1:]))) + x = lax.reshape(b, (m, math.prod(b.shape[1:]))) if trans == 0: x = x[permutation, :] x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True) @@ -1475,7 +1601,8 @@ def _lu_solve_core(lu: Array, permutation: Array, b: Array, trans: int) -> Array conjugate_a=conj) x = triangular_solve(lu, x, left_side=True, lower=True, unit_diagonal=True, transpose_a=True, conjugate_a=conj) - x = x[jnp.argsort(permutation), :] + _, ind = lax.sort_key_val(permutation, lax.iota('int32', len(permutation))) + x = x[ind, :] else: raise ValueError(f"'trans' value must be 0, 1, or 2, got {trans}") return lax.reshape(x, b.shape) @@ -1499,7 +1626,7 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array: "number of dimensions, last axis of LU decomposition " "matrix (shape {}) and b array (shape {}) must match" .format(lu.shape, b.shape)) - b = b[..., jnp.newaxis] + b = b[..., np.newaxis] else: if b.shape[-2] != lu.shape[-1]: raise ValueError("When LU decomposition matrix and b different " @@ -1507,7 +1634,15 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array: "matrix (shape {}) and second to last axis of b array " "(shape {}) must match" .format(lu.shape, b.shape)) - x = _lu_solve_core(lu, permutation, b, trans) + + batch_shape = lax.broadcast_shapes(lu.shape[:-2], permutation.shape[:-1], b.shape[:-2]) + lu = _broadcast_to(lu, (*batch_shape, *lu.shape[-2:])) + permutation = _broadcast_to(permutation, (*batch_shape, permutation.shape[-1])) + b = _broadcast_to(b, (*batch_shape, *b.shape[-2:])) + fn = _lu_solve_core + for _ in batch_shape: + fn = api.vmap(fn, in_axes=(0, 0, 0, None)) + x = fn(lu, permutation, b, trans) return x[..., 0] if rhs_vector else x @@ -1603,8 +1738,18 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a, *, a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) - a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a, - a_shape_vals=a_shape_vals) + ctx_args = ( + (ctx,) if platform == "cpu" else () + ) + a_out, taus, *maybe_info_geqrf = geqrf_impl( + *ctx_args, a_aval.dtype, a, a_shape_vals=a_shape_vals + ) + if not ctx.is_forward_compat(): + # Skip the info parameter verification for the FFI kernel. 
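In `_lu_solve_core` above, the transposed solve now inverts the permutation with `lax.sort_key_val(permutation, iota)` instead of `argsort`. A NumPy sketch of why the two agree (illustration only):

```python
import numpy as np

perm = np.array([2, 0, 3, 1])
# Sorting (key, value) pairs by the permutation keys reorders the iota values
# into the inverse permutation, which is exactly argsort(perm).
inv_via_sort = np.array([v for _, v in sorted(zip(perm, range(len(perm))))])
assert np.array_equal(inv_via_sort, np.argsort(perm))  # both are [1 3 0 2]
```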
+ return a_out, taus + # TODO(b/344892332): This parameter will no longer be needed after + # the forward compatibility period + info_geqrf = maybe_info_geqrf[0] zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED") select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) @@ -1716,11 +1861,20 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus, *, f"on GPU is not implemented; b/261671778; {a_aval.shape}") a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) else: + ctx_args = ( + (ctx,) if platform == "cpu" else () + ) a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, a_aval.shape) tau_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, taus_aval.shape) - a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus, - a_shape_vals=a_shape_vals, - tau_shape_vals=tau_shape_vals) + a, *maybe_info_orgqr = orgqr_impl(*ctx_args, a_aval.dtype, a, taus, + a_shape_vals=a_shape_vals, + tau_shape_vals=tau_shape_vals) + if not ctx.is_forward_compat(): + # Skip the info parameter verification for the FFI kernel. + return [a] + # TODO(b/344892332): This parameter will no longer be needed after + # the forward compatibility period + info_orgqr = maybe_info_orgqr[0] zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED") select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) @@ -1780,14 +1934,14 @@ def qr_jvp_rule(primals, tangents, *, full_matrices): raise NotImplementedError( "Unimplemented case of QR decomposition derivative") dx_rinv = triangular_solve(r, dx) # Right side solve by default - qt_dx_rinv = jnp.matmul(_H(q), dx_rinv) - qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1) + qt_dx_rinv = _H(q) @ dx_rinv + qt_dx_rinv_lower = _tril(qt_dx_rinv, -1) do = qt_dx_rinv_lower - _H(qt_dx_rinv_lower) # This is skew-symmetric # The following correction is necessary for complex inputs - I = lax.expand_dims(jnp.eye(n, dtype=do.dtype), range(qt_dx_rinv.ndim - 2)) + I = lax.expand_dims(lax_internal._eye(do.dtype, (n, n)), range(qt_dx_rinv.ndim - 2)) do = do + I * (qt_dx_rinv - qt_dx_rinv.real.astype(qt_dx_rinv.dtype)) - dq = jnp.matmul(q, do - qt_dx_rinv) + dx_rinv - dr = jnp.matmul(qt_dx_rinv - do, r) + dq = q @ (do - qt_dx_rinv) + dx_rinv + dr = (qt_dx_rinv - do) @ r return (q, r), (dq, dr) def _qr_batching_rule(batched_args, batch_dims, *, full_matrices): @@ -1800,8 +1954,10 @@ def _qr_lowering(a, *, full_matrices): *batch_dims, m, n = a.shape if m == 0 or n == 0: k = m if full_matrices else min(m, n) - q = jnp.broadcast_to(jnp.eye(m, k, dtype=a.dtype), (*batch_dims, m, k)) - r = jnp.empty((*batch_dims, k, n), dtype=a.dtype) + q = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (m, k)), + (*batch_dims, m, k), + (len(batch_dims), len(batch_dims) + 1)) + r = lax.full((*batch_dims, k, n), 0, dtype=a.dtype) return q, r r, taus = geqrf(a) @@ -1814,7 +1970,7 @@ def _qr_lowering(a, *, full_matrices): else: q = householder_product(r, taus) r = r[..., :n, :n] - r = jnp.triu(r) + r = _triu(r) return q, r @@ -1865,7 +2021,7 @@ def _svd_abstract_eval(operand, *, full_matrices, compute_uv, subset_by_index): raise NotImplementedError -@jax.default_matmul_precision("float32") +@config.default_matmul_precision("float32") def _svd_jvp_rule( primals, tangents, *, full_matrices, compute_uv, subset_by_index ): @@ -1883,13 +2039,13 @@ def _svd_jvp_rule( Ut, V = _H(U), _H(Vt) s_dim = s[..., None, :] dS = Ut @ dA @ V - ds = 
ufuncs.real(jnp.diagonal(dS, 0, -2, -1)) + ds = _extract_diagonal(dS.real) if not compute_uv: return (s,), (ds,) s_diffs = (s_dim + _T(s_dim)) * (s_dim - _T(s_dim)) - s_diffs_zeros = jnp.eye(s.shape[-1], dtype=s.dtype) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else + s_diffs_zeros = lax_internal._eye(s.dtype, (s.shape[-1], s.shape[-1])) # jnp.ones((), dtype=A.dtype) * (s_diffs == 0.) # is 1. where s_diffs is 0. and is 0. everywhere else s_diffs_zeros = lax.expand_dims(s_diffs_zeros, range(s_diffs.ndim - 2)) F = 1 / (s_diffs + s_diffs_zeros) - s_diffs_zeros dSS = s_dim.astype(A.dtype) * dS # dS.dot(jnp.diag(s)) @@ -1897,7 +2053,7 @@ def _svd_jvp_rule( s_zeros = (s == 0).astype(s.dtype) s_inv = 1 / (s + s_zeros) - s_zeros - s_inv_mat = jnp.vectorize(jnp.diag, signature='(k)->(k,k)')(s_inv) + s_inv_mat = _construct_diagonal(s_inv) dUdV_diag = .5 * (dS - _H(dS)) * s_inv_mat.astype(A.dtype) dU = U @ (F.astype(A.dtype) * (dSS + _H(dSS)) + dUdV_diag) dV = V @ (F.astype(A.dtype) * (SdS + _H(SdS))) @@ -1916,15 +2072,17 @@ def _svd_jvp_rule( def _empty_svd(a, *, full_matrices, compute_uv): batch_shape = a.shape[:-2] m, n = a.shape[-2:] - s = jnp.empty(batch_shape + (0,), dtype=lax_internal._complex_basetype(a.dtype)) + s = lax.full(batch_shape + (0,), 0, dtype=lax_internal._complex_basetype(a.dtype)) if not compute_uv: return (s,) if full_matrices: size = max(m, n) - u = jnp.broadcast_to(jnp.eye(size, dtype=a.dtype), batch_shape + (size, size)) + u = lax.broadcast_in_dim(lax_internal._eye(a.dtype, (size, size)), + (*batch_shape, size, size), + (len(batch_shape), len(batch_shape) + 1)) else: - u = jnp.empty(batch_shape + (m, n), dtype=a.dtype) - v = jnp.empty(batch_shape + (0, 0), dtype=a.dtype) + u = lax.full(batch_shape + (m, n), 0, dtype=a.dtype) + v = lax.full(batch_shape + (0, 0), 0, dtype=a.dtype) if m < n: u, v = v, u return s, u, v @@ -1973,7 +2131,8 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - s, u, vt, info = gesvd_impl(operand_aval.dtype, operand, + ctx_args = (ctx,) + s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand, full_matrices=full_matrices, compute_uv=compute_uv, a_shape_vals=a_shape_vals) @@ -2162,24 +2321,25 @@ def _tridiagonal_solve_batching_rule( def _tridiagonal_solve_jax(dl, d, du, b, **kw): """Pure JAX implementation of `tridiagonal_solve`.""" def prepend_zero(x): - return jnp.append( - jnp.zeros((1,) + x.shape[1:], dtype=x.dtype), - x[:-1], axis=0) + return lax.concatenate( + [lax.full((1,) + x.shape[1:], 0, dtype=x.dtype), x[:-1]], dimension=0) fwd1 = lambda tu_, x: x[1] / (x[0] - x[2] * tu_) def fwd2(b_, x): - return (x[0] - x[3][jnp.newaxis, ...] * b_) / ( - x[1] - x[3] * x[2])[jnp.newaxis, ...] + return (x[0] - x[3][np.newaxis, ...] * b_) / ( + x[1] - x[3] * x[2])[np.newaxis, ...] - bwd1 = lambda x_, x: x[0] - x[1][jnp.newaxis, ...] * x_ + bwd1 = lambda x_, x: x[0] - x[1][np.newaxis, ...] * x_ double = lambda f, args: (f(*args), f(*args)) # Move relevant dimensions to the front for the scan. - dl = jnp.moveaxis(dl, -1, 0) - d = jnp.moveaxis(d, -1, 0) - du = jnp.moveaxis(du, -1, 0) - b = jnp.moveaxis(b, -1, 0) - b = jnp.moveaxis(b, -1, 0) + moveaxis_fwd = lambda x: lax.transpose(x, (x.ndim - 1, *range(x.ndim - 1))) + moveaxis_bwd = lambda x: lax.transpose(x, (*range(1, x.ndim), 0)) + dl = moveaxis_fwd(dl) + d = moveaxis_fwd(d) + du = moveaxis_fwd(du) + b = moveaxis_fwd(b) + b = moveaxis_fwd(b) # Forward pass. 
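The pair of scans below implements the classic Thomas algorithm: a forward elimination pass that rescales the super-diagonal and right-hand side, followed by back substitution. A plain NumPy reference for a single system, assuming `dl[0]` and `du[-1]` are unused padding entries as in `tridiagonal_solve` (illustration only):

```python
import numpy as np

def thomas_solve(dl, d, du, b):
    n = len(d)
    tu, y, x = np.zeros(n), np.zeros(n), np.zeros(n)
    tu[0], y[0] = du[0] / d[0], b[0] / d[0]
    for i in range(1, n):                    # forward pass
        denom = d[i] - dl[i] * tu[i - 1]
        tu[i] = du[i] / denom
        y[i] = (b[i] - dl[i] * y[i - 1]) / denom
    x[-1] = y[-1]
    for i in range(n - 2, -1, -1):           # backward pass
        x[i] = y[i] - tu[i] * x[i + 1]
    return x

dl, d, du = np.array([0., 1, 1]), np.array([2., 2, 2]), np.array([1., 1, 0])
print(thomas_solve(dl, d, du, np.array([1., 2, 3])))  # [0.5 0.  1.5]
```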
_, tu_ = lax.scan(lambda tu_, x: double(fwd1, (tu_, x)), @@ -2199,8 +2359,8 @@ def fwd2(b_, x): unroll=32) result = x_[::-1] - result = jnp.moveaxis(result, 0, -1) - result = jnp.moveaxis(result, 0, -1) + result = moveaxis_bwd(result) + result = moveaxis_bwd(result) return result @@ -2388,7 +2548,7 @@ def hessenberg(a: ArrayLike) -> tuple[Array, Array]: return hessenberg_p.bind(a) def _hessenberg_abstract_eval(a): - if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128): + if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128): raise TypeError("hessenberg requires a.dtype to be float32, float64, " f"complex64, or complex128, got {a.dtype}.") if a.ndim < 2: @@ -2464,19 +2624,20 @@ def tridiagonal(a: ArrayLike, *, lower=True first superdiagonal. ``taus`` contains the scalar factors of the elementary Householder reflectors. """ - arr, d, e, taus, info = tridiagonal_p.bind(jnp.asarray(a), lower=lower) - nan = arr.dtype.type(jnp.nan) - if jnp.issubdtype(arr.dtype, np.complexfloating): - nan = nan + arr.dtype.type(jnp.nan * 1j) - arr = jnp.where((info == 0)[..., None, None], arr, nan) - real_type = jnp.finfo(arr.dtype).dtype.type - d = jnp.where((info == 0)[..., None], d, real_type(jnp.nan)) - e = jnp.where((info == 0)[..., None], e, real_type(jnp.nan)) - taus = jnp.where((info == 0)[..., None], taus, nan) + arr, d, e, taus, info = tridiagonal_p.bind(lax_internal.asarray(a), lower=lower) + def nans_like(arr): + if dtypes.issubdtype(arr.dtype, np.complexfloating): + return lax.full_like(arr, np.nan + 1j * np.nan) + return lax.full_like(arr, np.nan) + mask = lambda x: lax.broadcast_in_dim(info == 0, x.shape, range(info.ndim)) + arr = lax.select(mask(arr), arr, nans_like(arr)) + d = lax.select(mask(d), d, nans_like(d)) + e = lax.select(mask(e), e, nans_like(e)) + taus = lax.select(mask(taus), taus, nans_like(taus)) return arr, d, e, taus def _tridiagonal_abstract_eval(a, *, lower): - if a.dtype not in (jnp.float32, jnp.float64, jnp.complex64, jnp.complex128): + if a.dtype not in (np.float32, np.float64, np.complex64, np.complex128): raise TypeError("tridiagonal requires a.dtype to be float32, float64, " f"complex64, or complex128, got {a.dtype}.") if a.ndim < 2: @@ -2488,7 +2649,7 @@ def _tridiagonal_abstract_eval(a, *, lower): if a.shape[-1] == 0: raise TypeError("tridiagonal requires the last two dimensions of a to be " f"non-zero, got a.shape of {a.shape}.") - real_dtype = jnp.finfo(a.dtype).dtype + real_dtype = dtypes.finfo(a.dtype).dtype return [ a, ShapedArray(a.shape[:-2] + (a.shape[-1],), real_dtype), @@ -2528,7 +2689,7 @@ def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower): # Utilities def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: - if jnp.issubdtype(aval.dtype, np.complexfloating): + if dtypes.issubdtype(aval.dtype, np.complexfloating): return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval) else: return mlir.full_like_aval(ctx, np.nan, aval) diff --git a/jax/_src/lax/other.py b/jax/_src/lax/other.py index 7bdfabb92df8..67f274e829ff 100644 --- a/jax/_src/lax/other.py +++ b/jax/_src/lax/other.py @@ -18,15 +18,18 @@ import math from typing import Any -import jax -from jax._src.numpy import lax_numpy as jnp +from jax._src.custom_derivatives import custom_jvp +from jax._src import dtypes from jax._src.lax import lax from jax._src.lax import convolution +from jax._src import util +from jax._src.typing import Array, ArrayLike +import numpy as np DType = Any def conv_general_dilated_patches( - lhs: jax.typing.ArrayLike, + 
lhs: ArrayLike, filter_shape: Sequence[int], window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], @@ -35,7 +38,7 @@ def conv_general_dilated_patches( dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None, precision: lax.Precision | None = None, preferred_element_type: DType | None = None, -) -> jax.Array: +) -> Array: """Extract patches subject to the receptive field of `conv_general_dilated`. Runs the input through a convolution with given parameters. The kernel of the @@ -88,7 +91,7 @@ def conv_general_dilated_patches( (`np.prod(filter_shape) * lhs.shape[lhs_spec.index('C')]`). """ - lhs_array = jnp.asarray(lhs) + lhs_array = lax.asarray(lhs) filter_shape = tuple(filter_shape) dimension_numbers = convolution.conv_dimension_numbers( lhs_array.shape, (1, 1) + filter_shape, dimension_numbers) @@ -99,11 +102,10 @@ def conv_general_dilated_patches( n_channels = lhs_array.shape[lhs_spec[1]] # Move separate `lhs` spatial locations into separate `rhs` channels. - rhs = jnp.eye(spatial_size, dtype=lhs_array.dtype).reshape(filter_shape * 2) - - rhs = rhs.reshape((spatial_size, 1) + filter_shape) - rhs = jnp.tile(rhs, (n_channels,) + (1,) * (rhs.ndim - 1)) - rhs = jnp.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1])) + rhs = lax._eye(lhs_array.dtype, shape=(spatial_size, spatial_size)) + rhs = lax.broadcast_in_dim(rhs, (n_channels, spatial_size, spatial_size), (1, 2)) + rhs = lax.reshape(rhs, (n_channels * spatial_size, 1, *filter_shape)) + rhs = util.moveaxis(rhs, (0, 1), (rhs_spec[0], rhs_spec[1])) out = convolution.conv_general_dilated( lhs=lhs_array, @@ -122,8 +124,8 @@ def conv_general_dilated_patches( def conv_general_dilated_local( - lhs: jax.typing.ArrayLike, - rhs: jax.typing.ArrayLike, + lhs: ArrayLike, + rhs: ArrayLike, window_strides: Sequence[int], padding: str | Sequence[tuple[int, int]], filter_shape: Sequence[int], @@ -131,7 +133,7 @@ def conv_general_dilated_local( rhs_dilation: Sequence[int] | None = None, dimension_numbers: convolution.ConvGeneralDilatedDimensionNumbers | None = None, precision: lax.PrecisionLike = None -) -> jax.Array: +) -> Array: """General n-dimensional unshared convolution operator with optional dilation. Also known as locally connected layer, the operation is equivalent to @@ -185,7 +187,7 @@ def conv_general_dilated_local( - the input and output feature dimensions in rhs with the characters 'I' and 'O' respectively, and - spatial dimension correspondences between `lhs`, `rhs`, and the output using - any distinct characters. + any distinct characters. The examples below use 'W' and 'H'. For example, to indicate dimension numbers consistent with the `conv` function with two spatial dimensions, one could use `('NCHW', 'OIHW', 'NCHW')`. As @@ -200,7 +202,7 @@ def conv_general_dilated_local( If `dimension_numbers` is `None`, the default is `('NCHW', 'OIHW', 'NCHW')` (for a 2D convolution). 
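A minimal usage sketch of `conv_general_dilated_patches` with the default `('NCHW', 'OIHW', 'NCHW')` layout; the output channel count is `np.prod(filter_shape) * C` as described in the docstring above (the shape in the comment is my expectation, not asserted by the patch):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 2, 5, 5))  # NCHW: batch=1, C=2, H=W=5
patches = lax.conv_general_dilated_patches(
    x, filter_shape=(3, 3), window_strides=(1, 1), padding="VALID")
# 2 channels * 9 filter taps = 18 output channels over a 3x3 spatial grid.
print(patches.shape)  # expected: (1, 18, 3, 3)
```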
""" - lhs_array = jnp.asarray(lhs) + lhs_array = lax.asarray(lhs) c_precision = lax.canonicalize_precision(precision) lhs_precision = ( @@ -234,5 +236,52 @@ def conv_general_dilated_local( dn = ((lhs_c_dims, rhs_c_dims), (lhs_b_dims, rhs_b_dims)) out = lax.dot_general(patches, rhs, dimension_numbers=dn, precision=precision) - out = jnp.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1])) + out = util.moveaxis(out, (-2, -1), (out_spec[0], out_spec[1])) return out + + +def _wrap_between(x, _a): + """Wraps `x` between `[-a, a]`.""" + a = lax._const(x, _a) + two_a = lax._const(x, 2 * _a) + zero = lax._const(x, 0) + rem = lax.rem(lax.add(x, a), two_a) + rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem) + return lax.sub(rem, a) + + +def _replace_inf(x: Array) -> Array: + re_x = lax.real(x) if dtypes.issubdtype(x.dtype, np.complexfloating) else x + inf = lax._const(re_x, float('inf')) + return lax.select(lax.eq(re_x, inf), lax._zeros(x), x) + + +@custom_jvp +def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute log(exp(x1) + exp(x2)) avoiding overflow.""" + x1_arr = lax.asarray(x1) + x2_arr = lax.asarray(x2) + assert x1_arr.dtype == x2_arr.dtype + + amax = lax.max(x1_arr, x2_arr) + if dtypes.isdtype(x1_arr.dtype, "real floating"): + delta = lax.sub(x1_arr, x2_arr) + return lax.select(lax._isnan(delta), + lax.add(x1_arr, x2_arr), # NaNs or infinities of the same sign. + lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) + elif dtypes.isdtype(x1_arr.dtype, "complex floating"): + delta = lax.sub(lax.add(x1, x2), lax.mul(amax, lax._const(amax, 2))) + out = lax.add(amax, lax.log1p(lax.exp(delta))) + return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) + else: + raise ValueError(f"logaddexp requires floating-point or complex inputs; got {x1_arr.dtype}") + + +@logaddexp.defjvp +def _logaddexp_jvp(primals, tangents): + x1, x2 = primals + t1, t2 = tangents + primal_out = logaddexp(x1, x2) + tangent_out = lax.add(lax.mul(t1, lax.exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), + lax.mul(t2, lax.exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) + return primal_out, tangent_out diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index 4faa0bdd390b..c9a07072ddc7 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -35,7 +35,6 @@ from jax._src.lax import slicing from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy import lax_numpy from jax._src.util import (canonicalize_axis, moveaxis, safe_map, safe_zip, unzip2) import numpy as np @@ -231,7 +230,10 @@ def pargmax(x, axis_name): def _axis_index_of_val(x, val, axis_name): idx = axis_index(axis_name) - validx = lax_numpy.where(val == x, idx, dtypes.iinfo(dtypes.dtype(idx)).max) + mask = (val == x) + validx = lax.select(mask, + lax.full(mask.shape, idx), + lax.full(mask.shape, dtypes.iinfo(dtypes.dtype(idx)).max, dtype=idx.dtype)) return pmin(validx, axis_name) def _validate_reduce_axis_index_groups(axis_index_groups): @@ -779,7 +781,7 @@ def _ppermute_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, per perm_indices = np.zeros(axis_size, dtype=int) for src, dst in perm: perm_indices[dst] = src - return lax_numpy.take(v, perm_indices, d), d + return v.take(perm_indices, d), d def _collective_batcher(prim, args, dims, **params): return prim.bind(*args, **params), dims if prim.multiple_results else dims[0] @@ -795,7 +797,7 @@ def _collective_batcher(prim, args, dims, **params): def 
_pbroadcast_transpose_rule(t, x, source, axis_name): is_source = axis_index(axis_name) == source tsum = psum(t, axis_name) - return [lax_numpy.where(is_source, tsum, lax_numpy.zeros_like(t))] + return [lax.select(is_source, lax.full_like(t, tsum), lax.full_like(t, 0))] def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, source): (v,), (d,) = vals_in, dims_in @@ -810,7 +812,7 @@ def _pbroadcast_batcher(axis_size, frame_name, _, vals_in, dims_in, axis_name, s return pbroadcast_p.bind(v, source=source, axis_name=remaining_axes), d if d is batching.not_mapped: return v, d - return lax_numpy.take(v, [source] * axis_size, d), d + return v.take([source] * axis_size, d), d def _pbroadcast_lowering(ctx, x, *, axis_name, source): replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, None) diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 206d52ba5ebd..60dfa0e1b3d2 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -24,9 +24,8 @@ import numpy as np -import jax - from jax._src import ad_util +from jax._src import api from jax._src import config from jax._src import core from jax._src import dispatch @@ -1363,7 +1362,7 @@ def _dynamic_update_slice_jvp(primals, tangents): g_operand, g_update = tangents[:2] val_out = dynamic_update_slice_p.bind(operand, update, *start_indices) if type(g_operand) is ad_util.Zero and type(g_update) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_update = ad.instantiate_zeros(g_update) @@ -1401,7 +1400,7 @@ def _dynamic_update_slice_batching_rule(batched_args, batch_dims): inserted_window_dims=(), scatter_dims_to_operand_dims=dims) index, index_bdim = _batch_dynamic_slice_indices(start_idx, start_idx_bd) - return jax.vmap( + return api.vmap( partial(scatter, dimension_numbers=dnums, indices_are_sorted=True, unique_indices=True, mode=GatherScatterMode.CLIP), @@ -1784,7 +1783,7 @@ def _gather_lower_opaque(ctx, operand, indices, *, indices_are_sorted, mode, fill_value) -> ir.Value: aval_x, aval_indices = ctx.avals_in aval_y, = ctx.avals_out - elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape + elt_shape = core.physical_element_aval(aval_x.dtype).shape trailing_offset_dims = [aval_y.ndim + i for i in range(len(elt_shape))] dimension_numbers = dimension_numbers._replace( offset_dims=(*dimension_numbers.offset_dims, *trailing_offset_dims)) @@ -2001,7 +2000,7 @@ def _scatter_add_jvp(primals, tangents, *, update_jaxpr, update_consts, indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2181,7 +2180,7 @@ def _scatter_extremal_jvp(scatter_op, primals, tangents, update_jaxpr, unique_indices=unique_indices, mode=mode) if type(g_operand) is ad_util.Zero and type(g_updates) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2295,7 +2294,7 @@ def _scatter_jvp(primals, tangents, *, update_jaxpr, update_consts, update_consts=update_consts, dimension_numbers=dnums, 
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, mode=mode) - return val_out, ad_util.Zero.from_value(val_out) + return val_out, ad_util.Zero.from_primal_value(val_out) g_operand = ad.instantiate_zeros(g_operand) g_updates = ad.instantiate_zeros(g_updates) @@ -2385,7 +2384,7 @@ def _scatter_transpose_rule(t, operand, indices, updates, *, update_jaxpr, update_consts, dimension_numbers, indices_are_sorted, unique_indices, mode): if not unique_indices: - raise NotImplementedError("scatter transpose is only implemented where" + raise NotImplementedError("scatter transpose is only implemented where " "unique_indices=True") assert not ad.is_undefined_primal(indices) if ad.is_undefined_primal(updates): @@ -2437,7 +2436,7 @@ def _scatter_lower_opaque(ctx, operand, indices, updates, *, unique_indices, indices_are_sorted, mode): aval_x, aval_indices, aval_updates = ctx.avals_in aval_y, = ctx.avals_out - elt_shape = aval_x.dtype._rules.physical_element_aval(aval_x.dtype).shape + elt_shape = core.physical_element_aval(aval_x.dtype).shape trailing_window_dims = [aval_updates.ndim + i for i in range(len(elt_shape))] dimension_numbers = dimension_numbers._replace( update_window_dims=(*dimension_numbers.update_window_dims, diff --git a/jax/_src/lax/utils.py b/jax/_src/lax/utils.py index 01301db1a9a0..5e3e9bcd8df2 100644 --- a/jax/_src/lax/utils.py +++ b/jax/_src/lax/utils.py @@ -20,6 +20,7 @@ from jax._src import core from jax._src import dispatch +from jax._src import config from jax._src import dtypes from jax._src.util import safe_zip from jax._src.lib import xla_client @@ -37,19 +38,19 @@ def _argnum_weak_type(*argnums): return lambda *args, **_: all(args[i].weak_type for i in argnums) def standard_primitive(shape_rule, dtype_rule, name, - weak_type_rule=None): + weak_type_rule=None, sharding_rule=None): weak_type_rule = weak_type_rule or _standard_weak_type_rule prim = core.Primitive(name) prim.def_impl(partial(dispatch.apply_primitive, prim)) prim.def_abstract_eval( partial(standard_abstract_eval, prim, shape_rule, dtype_rule, - weak_type_rule)) + weak_type_rule, sharding_rule)) return prim def _get_array_abstraction_level(a): return a.array_abstraction_level def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, - *avals, **kwargs): + sharding_rule, *avals, **kwargs): assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals assert not prim.multiple_results weak_type = weak_type_rule(*avals, **kwargs) @@ -58,8 +59,11 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule, out = prim.impl(*[x.val for x in avals], **kwargs) return core.ConcreteArray(out.dtype, out, weak_type=weak_type) elif least_specialized is core.ShapedArray: + out_sharding = (sharding_rule(*avals, **kwargs) + if config.sharding_in_types.value else None) return core.ShapedArray(shape_rule(*avals, **kwargs), - dtype_rule(*avals, **kwargs), weak_type=weak_type) + dtype_rule(*avals, **kwargs), weak_type=weak_type, + sharding=out_sharding) elif least_specialized is core.DShapedArray: shape = shape_rule(*avals, **kwargs) ty = (core.ShapedArray if all(type(d) is int for d in shape) diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 5d6eddad0e4d..089a77de2949 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -30,9 +30,9 @@ from jax._src.lax import convolution from jax._src.lax import lax from jax._src.lax import slicing +from jax._src.lax.other import logaddexp from 
jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo -from jax._src.numpy.ufuncs import logaddexp from jax._src.typing import Array import numpy as np from jax._src.core import ClosedJaxpr @@ -707,7 +707,7 @@ def _select_and_scatter_add_jvp( padding) del g_operand if type(g_source) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: tangent_out = _select_and_scatter_add( g_source, operand, select_prim, window_dimensions, @@ -952,7 +952,7 @@ def _select_and_gather_add_jvp( padding, base_dilation, window_dilation) del g_operand if type(g_source) is ad_util.Zero: - tangent_out = ad_util.Zero.from_value(val_out) + tangent_out = ad_util.Zero.from_primal_value(val_out) else: tangent_out = _select_and_gather_add( g_source, operand, select_prim, window_dimensions, diff --git a/jax/_src/layout.py b/jax/_src/layout.py index 84708555041f..64bbd3268b16 100644 --- a/jax/_src/layout.py +++ b/jax/_src/layout.py @@ -69,7 +69,7 @@ def __eq__(self, other): self._tiling == other._tiling and self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits) - def _to_xla_layout(self, dtype) -> str: + def _to_xla_layout(self, dtype) -> xc.Layout: if self._tiling is None: xla_layout = xc.Layout(self.major_to_minor[::-1]) else: @@ -81,7 +81,7 @@ def _to_xla_layout(self, dtype) -> str: sub_byte_size = 0 xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, sub_byte_size) - return str(xla_layout) + return xla_layout def check_compatible_aval(self, aval_shape: Shape): if len(self.major_to_minor) != len(aval_shape): diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index cf6e68e49c81..14822bff3eff 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -16,6 +16,7 @@ from collections.abc import Callable, Sequence import importlib +import sys from typing import Any @@ -26,17 +27,27 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[ ]: """Lazily loads submodules of a package. - Example use: - ``` - __getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"]) - ``` + Returns: + A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` -- + a list of available global names, which can be used to replace the + corresponding definitions in the package. + + Raises: + RuntimeError: If the ``__name__`` of the caller cannot be determined. """ + owner_name = sys._getframe(1).f_globals.get("__name__") + if owner_name is None: + raise RuntimeError("Cannot determine the ``__name__`` of the caller.") - __all__: list[str] = list(submodules) + __all__ = list(submodules) def __getattr__(name: str) -> Any: if name in submodules: - return importlib.import_module(f"{package_name}.{name}") + value = importlib.import_module(f"{package_name}.{name}") + # Update module-level globals to avoid calling ``__getattr__`` again + # for this ``name``. 
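The usage example removed from the `attach` docstring still reflects the intended call pattern; a package adopting the lazy loader would typically look like this (hypothetical `sub1`/`sub2` submodules, mirroring the `jaxlib.mlir.dialects` rewrite below):

```python
# some_package/__init__.py
from jax._src import lazy_loader as _lazy

# Submodules are imported on first attribute access; the setattr above then
# caches them in the module globals so __getattr__ is not hit again.
__getattr__, __dir__, __all__ = _lazy.attach(__name__, ["sub1", "sub2"])
del _lazy
```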
+ setattr(sys.modules[owner_name], name, value) + return value raise AttributeError(f"module '{package_name}' has no attribute '{name}") def __dir__() -> list[str]: diff --git a/jax/_src/lib/BUILD b/jax/_src/lib/BUILD index 09cc3a81c2c2..7068c0ef6732 100644 --- a/jax/_src/lib/BUILD +++ b/jax/_src/lib/BUILD @@ -22,7 +22,7 @@ load( package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) py_library_providing_imports_info( diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index b2bcc53a53f8..e8fcb433438a 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -27,7 +27,7 @@ except ModuleNotFoundError as err: raise ModuleNotFoundError( 'jax requires jaxlib to be installed. See ' - 'https://github.com/google/jax#installation for installation instructions.' + 'https://github.com/jax-ml/jax#installation for installation instructions.' ) from err import jax.version @@ -92,7 +92,7 @@ def _parse_version(v: str) -> tuple[int, ...]: jax_jit = xla_client._xla.jax_jit pmap_lib = xla_client._xla.pmap_lib -# XLA garbage collection: see https://github.com/google/jax/issues/14882 +# XLA garbage collection: see https://github.com/jax-ml/jax/issues/14882 def _xla_gc_callback(*args): xla_client._xla.collect_garbage() gc.callbacks.append(_xla_gc_callback) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 01dc7e2725b5..a9bae8821db5 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -13,35 +13,49 @@ # limitations under the License. # ruff: noqa: F401 -from typing import Any -import jaxlib.mlir.dialects.arith as arith -import jaxlib.mlir.dialects.builtin as builtin -import jaxlib.mlir.dialects.chlo as chlo -import jaxlib.mlir.dialects.func as func -import jaxlib.mlir.dialects.math as math -import jaxlib.mlir.dialects.memref as memref -import jaxlib.mlir.dialects.mhlo as mhlo -import jaxlib.mlir.dialects.scf as scf +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from jaxlib.mlir.dialects import arith as arith + from jaxlib.mlir.dialects import builtin as builtin + from jaxlib.mlir.dialects import chlo as chlo + from jaxlib.mlir.dialects import func as func + from jaxlib.mlir.dialects import gpu as gpu + from jaxlib.mlir.dialects import llvm as llvm + from jaxlib.mlir.dialects import math as math + from jaxlib.mlir.dialects import memref as memref + from jaxlib.mlir.dialects import mhlo as mhlo + from jaxlib.mlir.dialects import nvgpu as nvgpu + from jaxlib.mlir.dialects import nvvm as nvvm + from jaxlib.mlir.dialects import scf as scf + from jaxlib.mlir.dialects import sparse_tensor as sparse_tensor + from jaxlib.mlir.dialects import vector as vector +else: + from jax._src import lazy_loader as _lazy + __getattr__, __dir__, __all__ = _lazy.attach("jaxlib.mlir.dialects", [ + "arith", + "builtin", + "chlo", + "func", + "gpu", + "llvm", + "math", + "memref", + "mhlo", + "nvgpu", + "nvvm", + "scf", + "sparse_tensor", + "vector", + ]) + del _lazy + # TODO(bartchr): Once JAX is released with SDY, remove the try/except. 
try: - import jaxlib.mlir.dialects.sdy as sdy + from jaxlib.mlir.dialects import sdy as sdy except ImportError: sdy: Any = None # type: ignore[no-redef] -import jaxlib.mlir.dialects.sparse_tensor as sparse_tensor -import jaxlib.mlir.dialects.vector as vector -try: - # pytype: disable=import-error - import jaxlib.mlir.dialects.gpu as gpu - import jaxlib.mlir.dialects.nvgpu as nvgpu - import jaxlib.mlir.dialects.nvvm as nvvm - import jaxlib.mlir.dialects.llvm as llvm - # pytype: enable=import-error -except ImportError: - pass - -from jax._src import lib - # Alias that is set up to abstract away the transition from MHLO to StableHLO. -import jaxlib.mlir.dialects.stablehlo as hlo +from jaxlib.mlir.dialects import stablehlo as hlo diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py index 39663f7d711a..08c8bfcb3a29 100644 --- a/jax/_src/mesh.py +++ b/jax/_src/mesh.py @@ -217,6 +217,8 @@ def __setattr__(self, name, value): super().__setattr__(name, value) def __enter__(self): + if jax_config.disallow_mesh_context_manager.value: + raise RuntimeError("Mesh context manager is disabled.") new_env = thread_resources.stack[-1].with_mesh(self) thread_resources.stack.append(new_env) thread_resources.env = new_env @@ -247,11 +249,11 @@ def shape_tuple(self): @property def size(self): - return math.prod(self.shape.values()) + return math.prod(self.shape.values()) if self.devices.ndim else 0 @property def empty(self): - return self.devices.ndim == 0 + return self.size == 0 @functools.cached_property def is_multi_process(self): @@ -308,6 +310,10 @@ def local_devices(self): return [d for d in self.devices.flat if d.process_index == d.client.process_index()] + @functools.cached_property + def abstract_mesh(self): + return AbstractMesh(self.shape_tuple) + EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ())) @@ -318,3 +324,91 @@ def __init__(self): self.env = self.stack[-1] thread_resources = _ThreadResourcesLocalState() + + +class AbstractMesh: + """AbstractMesh contains only axis names and axis sizes. + + It does not contain concrete devices compared to `jax.sharding.Mesh`. You + should use this as an input to the sharding passed to with_sharding_constraint + and mesh passed to shard_map to avoid tracing and lowering cache misses when + your mesh shape and names stay the same but the devices change. + See the description of https://github.com/jax-ml/jax/pull/23022 for more + details. 
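A usage sketch of what this docstring describes, using the `abstract_mesh` property added to `Mesh` in this change (assumes at least one local device; illustration only):

```python
import numpy as np
import jax
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1), ("x", "y"))
abstract = mesh.abstract_mesh          # axis names and sizes only, no devices
assert abstract.shape_tuple == (("x", 1), ("y", 1))
assert abstract.size == mesh.size      # hashing/equality use shape_tuple alone
```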
+ """ + + def __init__(self, shape_tuple: tuple[tuple[str, int], ...]): + self.shape_tuple = shape_tuple + if self.shape_tuple: + self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple)) + else: + self._axis_names, self._axis_sizes = (), () + + def __hash__(self): + return hash(self.shape_tuple) + + def __eq__(self, other): + if not isinstance(other, AbstractMesh): + return False + if id(self) == id(other): + return True + return self.shape_tuple == other.shape_tuple + + def __repr__(self): + return f"AbstractMesh({self.shape_tuple})" + + @property + def axis_names(self): + return self._axis_names + + @functools.cached_property + def size(self): + return math.prod(self._axis_sizes) if self._axis_sizes else 0 + + @functools.cached_property + def shape(self): + return collections.OrderedDict(self.shape_tuple) + + @property + def _is_jax_device_mesh(self): + return False + + @property + def _internal_device_list(self): + return None + + @property + def empty(self): + return self.size == 0 + + @property + def devices(self): + _raise_value_error("devices") + + @property + def device_ids(self): + _raise_value_error("device_ids") + + @property + def is_multi_process(self): + _raise_value_error("is_multi_process") + + @property + def local_devices(self): + _raise_value_error("local_devices") + + @property + def local_mesh(self): + _raise_value_error("local_mesh") + + def __enter__(self): + raise RuntimeError("AbstractMesh is not a context manager") + + def __exit__(self, exc_type, exc_value, traceback): + raise RuntimeError("AbstractMesh is not a context manager") + + +# Create this indirection because pytype fails to recognize a property if a +# property raises an exception unconditionally. Remove this once that is fixed. +def _raise_value_error(name): + raise ValueError(f"AbstractMesh does not implement {name}") diff --git a/jax/_src/mesh_utils.py b/jax/_src/mesh_utils.py new file mode 100644 index 000000000000..da3b54058fbb --- /dev/null +++ b/jax/_src/mesh_utils.py @@ -0,0 +1,818 @@ +# Copyright 2021 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utils for building a device mesh.""" + +from __future__ import annotations + +import collections +from collections.abc import Callable, Generator, MutableMapping, Sequence +import itertools +import logging +import math +from typing import Any + +from jax._src import xla_bridge as xb +import numpy as np + +logger = logging.getLogger(__name__) + +_TPU_V2 = 'TPU v2' +_TPU_V3 = 'TPU v3' +_TPU_V4 = 'TPU v4' +_TPU_V5_LITE = "TPU v5 lite" + +# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's +# famous contiguous mesh trick. +# +# The trick only works for certain topologies and mesh shapes. Trivial dims of +# size 1 can be added to the shapes listed, and they are also supported. 
+_TRANSPOSE_TRICKS: dict[ + tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]] +] = { + (2, 2, 1): { + (2, 2): (0, 1, 2), + }, + (2, 2, 4): { + (4, 4): (0, 1, 2), + }, + (4, 4, 4): { + (16, 4): (0, 2, 1), + }, + (4, 8, 8): { + (64, 4): (0, 2, 1), + (4, 64): (0, 2, 1), + }, + (8, 8, 8): { + (64, 8): (0, 2, 1), + }, + (8, 16, 16): { + (256, 8): (0, 2, 1), + (8, 256): (0, 2, 1), + }, +} + +# Physical ordering of core IDs in a tray that creates a ring +_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) +_TRAY_2x2_RING_ORDER = (0, 1, 3, 2) +_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) +_V5E_TRAY_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 4) + +def _tpu_v2_v3_create_device_mesh( + mesh_shape: Sequence[int], + devices: Sequence[Any], + **unused_kwargs, +) -> np.ndarray: + if len(devices) == 8: + logger.info( + 'Reordering mesh to physical ring order on single-tray TPU v2/v3.' + ) + device_mesh = np.asarray(devices) + device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + elif mesh_shape[-1] == 8: + device_mesh = np.asarray(devices).reshape(mesh_shape) + logger.info( + 'Reordering mesh to physical ring order on each TPU v2/v3 tray.' + ) + perm = np.array(_TRAY_RING_ORDER) + device_mesh = device_mesh[..., perm] + return device_mesh + else: + # TODO(skye): implement 2D mesh_shape logic here: + # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187 + # (possibly replaces above mesh_shape[-1] == 8 case) + return np.asarray(devices).reshape(mesh_shape) + + +def _v5e_create_device_mesh( + mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs +) -> np.ndarray | None: + """Creates rotated pincer device assignment for selected topologies. + + Args: + mesh_shape: Logical mesh shape used by the model. + devices: TPU devices. + **unused_kwargs: ... + + Returns: + None or reordered devices reshaped as `mesh_shape`. + """ + max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) + bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 + # Our ring re-ordering makes sense only if the passed-in devices are + # sequential, which may not always be the case. reversed() changes z-minor to + # x-minor. + sequential_devices = sorted( + devices, + key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) + + if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + if len(devices) == 8: + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_V5E_TRAY_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # v5e4x4 + # Only uses ring order if the whole mesh is a replica group. + if max(mesh_shape) == len(devices): + device_mesh = np.asarray(sequential_devices) + device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)] + device_mesh = device_mesh.reshape(mesh_shape) + return device_mesh + + return None + + +# Registers functions to create device mesh for specific device kinds. Takes +# precedence over the more general logic in create_device_mesh(). Handler may +# return None; in that case, it will fall back to using the default logic. 
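A NumPy sketch of the single-tray reordering performed by `_tpu_v2_v3_create_device_mesh` above, with integers standing in for the eight devices (illustration only):

```python
import numpy as np

_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5)  # as defined above
devices = np.arange(8)                       # stand-ins for 8 TPU v2/v3 devices
ring = devices[np.array(_TRAY_RING_ORDER)]   # reorder into the physical ring
print(ring.reshape(2, 4))                    # [[0 1 2 3]
                                             #  [6 7 4 5]]  -> (2, 4) mesh
```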
+device_kind_handler_dict: dict[ + str, + Callable[..., np.ndarray | None], +] = { + _TPU_V2: _tpu_v2_v3_create_device_mesh, + _TPU_V3: _tpu_v2_v3_create_device_mesh, + _TPU_V5_LITE: _v5e_create_device_mesh, +} + + +def _create_device_mesh_for_nd_torus( + physical_mesh: np.ndarray, + mesh_shape: Sequence[int], + *, + allow_split_physical_axes: bool = False, +) -> tuple[np.ndarray, np.ndarray]: + """Assigns logical parallelism axes to physical axes of an N-D torus network. + + Given logical parallelism axes with sizes in `mesh_shape` and devices in an + N-dimensional torus network represented by `physical_mesh`, maps each logical + axis to one or more physical axes. Prefer to map more-performance-sensitive + logical axes to larger numbers of physical axes to maximize the bandwidth + available to them. Also prefer to assign logical axes to multiple physical + axes of the same size (e.g., a 2D square) rather than multiple physical axes + of different sizes when possible. + + If allow_split_physical_axes = False (default), this routine will error out + instead of splitting a physical axis over more than one logical axis (which + would reduce total usable bandwidth). + + Let's use a concrete example to explain the concepts and considerations. + + As an example, suppose the logical mesh is [data, model], for data and model + parallelism respectively. Also suppose that data parallelism is less + performance sensitive than model parallelism. Consider a 3D TPU pod slice of + shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). + + A TPU pod slice has equal bandwidth along all axes with wraparound links, but + a 2D plane of size 4x4 may have faster XLA collective implementations than a + non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want + the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. + + Args: + physical_mesh: a np.ndarray of devices in the shape of the N-D torus + physical topology. + mesh_shape: shape of the logical mesh (size of the various logical + parallelism axes), with axes ordered by increasing network intensity. + allow_split_physical_axes: If True, we would split physical axes if + necessary to fit the desired mesh shape. + + Returns: + An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with + each logical parallelism axis mapped to one or more physical mesh axes. + The axis assignment matrix, which is a 2-d array mapping from + (physical_axis, logical_axis) to the size assigned, with the invariant + np.prod(assignment, axis=1) = physical_mesh_shape, and + np.prod(assignment, axis=0) = mesh_shape. + """ + # Remaining physical axes to be assigned to logical axes. + assignable_physical_mesh = list(physical_mesh.shape) + # Map each logical axis to a subset of physical axes. + assignment: list[tuple[int, ...]] = [() for _ in mesh_shape] + + # Assign logical axes from highest network intensity to lowest. + # `mesh_shape` is assumed to ordered by lowest network intensity first, so + # reverse it first. + for logical_axis_index, logical_axis_size in reversed( + list(enumerate(mesh_shape)) + ): + # Preferentially map to more physical axes first for higher bandwidth. + for num_axes in range(3, 0, -1): + # Try assign to any subset of size num_axes. Generate all candidates. 
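A concrete instance of the invariants stated in the docstring above, for its 4x4x16 example where the more performance-sensitive logical axis is mapped to the 4x4 plane (numbers are illustrative):

```python
import numpy as np

physical_mesh_shape = (4, 4, 16)
mesh_shape = (16, 16)               # (data, model); model is more sensitive
assignment = np.array([[1, 4],      # physical axis 0 -> model
                       [1, 4],      # physical axis 1 -> model
                       [16, 1]])    # physical axis 2 -> data
assert tuple(np.prod(assignment, axis=1)) == physical_mesh_shape
assert tuple(np.prod(assignment, axis=0)) == mesh_shape
```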
+ indices_and_axes = itertools.combinations( + enumerate(assignable_physical_mesh), num_axes + ) + for elem in indices_and_axes: + c_indices, c_axes = zip(*elem) + # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only + # implemented for square 2D plane. Mapping a physical axis to two + # logical axes might be slower for non-square 2D plane, e.g., map 32 to + # 4x8 or a single axis. If XLA 2D collectives support non-square plane + # soon, we can continue to preferentially map to 2D plane in general, + # otherwise, we should treat non-square 2D plane and 1D submesh equally. + if np.prod(c_axes) == logical_axis_size: + assignment[logical_axis_index] = c_indices + # Zero the assigned physical axes. + assignable_physical_mesh = [ + 0 if i in c_indices else v + for i, v in enumerate(assignable_physical_mesh) + ] + break + if assignment[logical_axis_index]: + # We already found an assignment from one candidate above. + break + else: + # If the num_axes for loop did not break, i.e. none of the candidates work + # goto here with this while-else construct. + if logical_axis_size > 1: + if not allow_split_physical_axes: + # Although this is now implemented, there are downstream tasks + # counting on this being a NotImplementedError. + raise NotImplementedError( + 'Failed to find assignment for logical_axis_index' + f' {logical_axis_index} of size {logical_axis_size} with' + f' remaining assignable mesh {assignable_physical_mesh}. The size' + ' of each axis in your logical mesh must be equal to the product' + ' of some subset of the physical mesh axis sizes. E.g. logical' + ' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4' + ' and 16=4x4. If you want to split physical axes, set ' + ' allow_split_physical_axes to True.' + ) + else: + # We will try finding an assignment, even if that means splitting the + # physical axes, which requires a more sophisticated implementation. + return _create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh, mesh_shape + ) + + # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1). + transpose: list[int] = [] + assignment_array = np.ones( + [len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64 + ) + for i, x in enumerate(assignment): + for y in x: + physical_mesh_axis = int(y) + assignment_array[physical_mesh_axis, i] = physical_mesh.shape[ + physical_mesh_axis + ] + transpose.append(physical_mesh_axis) + return ( + physical_mesh.transpose(transpose).reshape(mesh_shape), + assignment_array, + ) + + +def _create_device_mesh_for_nd_torus_splitting_axes( + physical_mesh: np.ndarray, + mesh_shape: Sequence[int], +) -> tuple[np.ndarray, np.ndarray]: + """Assigns logical parallelism axes to physical axes of an N-D torus network. + + This implementation allows creating meshes that requires splitting physical + axes, and thus one could produce logical mesh of any shape, as long as the + number of devices matches, e.g., + + - Creating 2x2x4 from 4x4; + + - Creating 2x2x16 from 8x8; + + Args: + physical_mesh: a np.ndarray of devices in the shape of the N-D torus + physical topology. + mesh_shape: shape of the logical mesh (size of the various logical + parallelism axes), with axes ordered by increasing network intensity. + + Returns: + An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with + each logical parallelism axis mapped to one or more physical mesh axes. 
+ The axis assignment matrix, which is a 2-d array mapping from + (physical_axis, logical_axis) to the size assigned, with the invariant + np.prod(assignment, axis=1) = physical_mesh_shape, and + np.prod(assignment, axis=0) = mesh_shape. + """ + if np.prod(physical_mesh.shape) != np.prod(mesh_shape): + raise ValueError( + 'The number of devices in physical mesh' + f' {physical_mesh.shape} does not match the number of devices' + f' in logical mesh {mesh_shape}.' + ) + + physical_mesh_shape = physical_mesh.shape + logical_mesh_shape = tuple(mesh_shape) + + # (Partial) assignment map as an 2-d array [p_axis, l_axis] -> size. + assignment = np.ones( + [len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64 + ) + + # Process logical axes from highest network intensity to lowest. + # `mesh_shape` is assumed to ordered by lowest network intensity first, so + # reverse it. + for logical_axis, logical_axis_size in reversed( + list(enumerate(logical_mesh_shape)) + ): + # Go over all the possible assignment for the logical axis, including the + # one that splits multiple physical axes. + best_logical_axis_assignment = None + for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments( + physical_mesh_shape, assignment, logical_axis_size + ): + # TODO(rosun): Instead of using heuristics, replace this with a proper + # scoring function reflecting the underlying hardware properties. + if ( + best_logical_axis_assignment is None + or _prefer_first_logical_axis_assignment( + logical_axis_assignment, + best_logical_axis_assignment, + physical_mesh_shape=physical_mesh_shape, + assignment=assignment, + ) + ): + best_logical_axis_assignment = logical_axis_assignment + assignment[:, logical_axis] = best_logical_axis_assignment + + # Read out the assignment. + logical_mesh = _generate_logical_mesh( + physical_mesh, logical_mesh_shape, assignment + ) + + return logical_mesh, assignment + + +def _get_prime_factors(x: int) -> list[int]: + """Returns a sorted list of prime factors for the given number.""" + assert x > 0 + factors = [] + for p in range(2, math.isqrt(x) + 2): + while x % p == 0: + factors.append(p) + x //= p + if x == 1: + return factors + else: + return [x] # x is a prime number. + + +def _enumerate_feasible_logical_axis_assignments( + physical_mesh_shape: Sequence[int], + assignment: np.ndarray, + logical_axis_size: int, +) -> Generator[np.ndarray, None, None]: + """Yields feasible assignments for a single logical axis. + + For a physical mesh of shape [x_1, ..., x_n], and the product of all previous + assignments on each physical axes [y_1, ..., y_n], this function yields all + possible assignments for the axis as 1-d arrays [z_1, ..., z_n], so that: + + - prod(z_1, ..., z_n) = logical_axis_size + + - x_i % (z_i * y_i) = 0 + + Args: + physical_mesh_shape: Physical mesh shape. + assignment: Existing assignment matrix. + logical_axis_size: Size of the logical axis to assign. + + Yields: + All valid assignments for the logical axis. Each assignment is represented + as an integer array of length len(physical_mesh_shape). + """ + logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int) + for factor in _get_prime_factors(logical_axis_size): + logical_axis_factors[factor] += 1 + + available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod( + assignment, axis=-1 + ) + + # To enable efficient enumerations, we first index physical axes by their + # prime factors. 
Since we know the prime factorization of the logical axis + # size, we could simply enumerate by picking the correct count for each + # prime factor. + physical_axes_by_factor: MutableMapping[int, list[int]] = ( + collections.defaultdict(list) + ) + for physical_axis, physical_axis_size in enumerate( + available_physical_mesh_shape + ): + for factor in _get_prime_factors(physical_axis_size): + if factor not in logical_axis_factors: + continue + physical_axes_by_factor[factor].append(physical_axis) + + factors = [] + assignments_by_factor = [] + for factor, multiplicity in logical_axis_factors.items(): + factors.append(factor) + assignments_by_factor.append( + set( + itertools.combinations( + physical_axes_by_factor[factor], multiplicity + ) + ) + ) + + for axis_assignment in itertools.product(*assignments_by_factor): + result = np.ones([len(physical_mesh_shape)], dtype=np.int64) + for factor_index, per_factor_assignment in enumerate(axis_assignment): + for physical_axis in per_factor_assignment: + result[physical_axis] *= factors[factor_index] + yield result + + +def _prefer_first_logical_axis_assignment( + x: np.ndarray, + y: np.ndarray, + *, + physical_mesh_shape: Sequence[int], + assignment: np.ndarray, +) -> bool: + """Returns True if the first axis assignment is preferred over the second. + + For now, this is implemented with some very simple heuristics. However, + it is possible to introduce e.g., a value function here based on a more + precise model of the underlying hardware. + + TODO(rosun): Use a proxy of network capacity to select the partitions. + + Args: + x: Logical axis assignment as [len(physical_mesh_shape)] array. + y: Logical axis assignment as [len(physical_mesh_shape)] array. + physical_mesh_shape: Physical mesh shape. + assignment: Assignment matrix. + + Returns: + True if x is preferred over y. + """ + # Prefer occupying complete physical axes. I don't have a good reason for + # this, except that it is compatible with the existing behavior. + # + # E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then + # over [2, 2, 4]. + x_whole_axis_size = np.prod( + [s for i, s in enumerate(x) if s == physical_mesh_shape[i]] + ) + y_whole_axis_size = np.prod( + [s for i, s in enumerate(y) if s == physical_mesh_shape[i]] + ) + + if x_whole_axis_size != y_whole_axis_size: + return x_whole_axis_size > y_whole_axis_size + + # Prefer occupying more whole physical axes for better bandwidth. + # + # This is consistent with existing logic, i.e., 2 x 2 is preferred over 4. + x_num_whole_axes = len( + [1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1] + ) + y_num_whole_axes = len( + [1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1] + ) + + if x_num_whole_axes != y_num_whole_axes: + return x_num_whole_axes > y_num_whole_axes + + # Prefer taking physical axes that are not taken by logical axes of higher + # network intensity. E.g., for a 4 x 4 x 4, suppose that the previous + # assignments are 1 x 2 x 4, and we want to place a new logical axis of size + # 2, we will go for [2, 1, 1] instead of [1, 2, 1], as the latter choice will + # tap into bandwidth already taken by the higher intensity axis. 
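Illustrative sketch (editor's aside, not part of the patch): the first heuristic in `_prefer_first_logical_axis_assignment` scores a candidate by the product of the sizes it places on fully occupied physical axes. Evaluating that score on the 4 x 4 x 8 example from the comment, with hypothetical candidate placements for a logical axis of size 16, reproduces the stated preference order.

import numpy as np

physical_mesh_shape = (4, 4, 8)
# Hypothetical candidate placements for a logical axis of size 16.
candidates = ([4, 4, 1], [4, 1, 4], [2, 2, 4])
for cand in candidates:
  # Product of the sizes placed on physical axes the candidate fills
  # completely; np.prod of an empty list is 1.
  whole_axis_size = np.prod(
      [s for i, s in enumerate(cand) if s == physical_mesh_shape[i]]
  )
  print(cand, int(whole_axis_size))
# -> scores 16, 4 and 1 respectively, so [4, 4, 1] is preferred over
#    [4, 1, 4], which in turn is preferred over [2, 2, 4], matching the
#    ordering described in the comment.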
+ assigned_physical_mesh_shape = np.prod(assignment, axis=-1) + + x_non_overlapping_axis_size = np.prod( + [s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1] + ) + y_non_overlapping_axis_size = np.prod( + [s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1] + ) + + if x_non_overlapping_axis_size != y_non_overlapping_axis_size: + return x_non_overlapping_axis_size > y_non_overlapping_axis_size + + # Otherwise sort by reverse lexical graphical order, to be consistent with + # existing behavior. + return tuple(x) > tuple(y) + + +def _generate_logical_mesh( + physical_mesh: np.ndarray, + logical_mesh_shape: Sequence[int], + assignment: np.ndarray, +) -> np.ndarray: + """Compute the logical mesh from assignment map. + + Args: + physical_mesh: Physical device mesh. + logical_mesh_shape: Logical mesh shape. + assignment: 2-d assignment matrix shape [physical_dims, logical_dims]. + + Returns: + Logical mesh reshaped from physical mesh. + """ + physical_indices = np.broadcast_to( + np.expand_dims( + np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1 + ), + assignment.shape, + ).reshape([-1]) + + logical_indices = np.broadcast_to( + np.expand_dims( + np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0 + ), + assignment.shape, + ).reshape([-1]) + + # Axes of logical mesh is ordered by (physical_axis, logical_axis). + # + # Note that we sort for each physical_axis the logical_axis, so that higher + # intensity logical axes are replicated at inner (minor) dimensions. + # + # E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4 + # is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the + # 1-d case, this will allow more connections between the higher intensity + # axes. + logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1])) + + # We will then group by l_axis as this is what is expected from output. + _, _, transpose_axes = zip( + *sorted( + zip(logical_indices, physical_indices, range(len(logical_indices))) + ) + ) + logical_mesh = np.transpose(logical_mesh, transpose_axes) + + # Reshape to add the trivial dimensions back. + logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) + + return logical_mesh + + +def _bounds_from_last_device(last_device) -> Sequence[int]: + """Gets the bound from the given last device.""" + # Must be passed the device at the highest-coordinate corner of the + # relevant mesh, which is a requirement we know is satisfied by the last + # device in jax.devices(). + assert hasattr(last_device, 'coords'), 'Only TPU supported' + x, y, z = last_device.coords + return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 + + +def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: + r"""Rearrange TPU devices in a slice into a physical mesh. + + Args: + jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x, + core order, e.g. from jax.devices(). The coordinates of these devices + should constitute a cuboid with no holes; e.g., the coordinates can be + {(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)} (a 1x2x2 cuboid); passing + only 3 of these devices would result in a "hole" in that cuboid, which is + an error. As in our example, the cuboid is not required to include the + point (0, 0, 0). + + Returns: + A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On + v2 and v3, global_z is instead cores_per_chip (i.e., 2). 
+ """ + device_kind = jax_devices[0].device_kind + device_coords = [d.coords for d in jax_devices] + coord_size = len(device_coords[0]) + # Position-wise max and min coordinates: + max_coords = tuple( + max(dc[i] for dc in device_coords) for i in range(coord_size) + ) + min_coords = tuple( + min(dc[i] for dc in device_coords) for i in range(coord_size) + ) + dims = tuple(h - l + 1 for (h, l) in zip(max_coords, min_coords)) + + max_cores_per_chip = max(d.core_on_chip for d in jax_devices) + min_cores_per_chip = min(d.core_on_chip for d in jax_devices) + cores_per_chip = max_cores_per_chip - min_cores_per_chip + 1 + + assert len(dims) == 3, dims + assert ( + len(jax_devices) == np.prod(dims) * cores_per_chip + ), f'{jax_devices=} {dims=} {cores_per_chip=}' + + if device_kind in (_TPU_V2, _TPU_V3): + out = np.empty(dims[:2] + (cores_per_chip,), dtype=object) + for d in jax_devices: + coords = d.coords + assert coords[2] == 0, d + out[ + coords[0] - min_coords[0], + coords[1] - min_coords[1], + d.core_on_chip - min_cores_per_chip, + ] = d + else: + out = np.empty(dims, dtype=object) + for d in jax_devices: + coords = d.coords + if d.core_on_chip != 0: + raise AssertionError( + 'Creating meshes for TPU >v3 requires one device per chip' + f' ("megacore" mode). Got device id {d.core_on_chip} for a device' + f' of kind {device_kind}: {d}.' + ) + out[ + coords[0] - min_coords[0], + coords[1] - min_coords[1], + coords[2] - min_coords[2], + ] = d + + # Check there is no "hole" in the mesh we constructed. + if (out == None).any(): # pylint: disable=singleton-comparison + raise AssertionError( + 'Constructed mesh contains a "hole"; probable cause: coordinates ' + f'of jax_devices are not a contiguous cuboid: {jax_devices}' + ) + return out + + +# jekbradbury's famous trick for creating contiguous submeshes (where available) +def _transpose_trick( + physical_mesh: np.ndarray, mesh_shape: Sequence[int] +) -> np.ndarray: + mesh_shape = tuple(mesh_shape) + topology = physical_mesh.shape + if topology not in _TRANSPOSE_TRICKS: + raise ValueError( + 'create_device_mesh cannot create contiguous submeshes for ' + f'physical mesh topology {topology}' + ) + + mesh_shape_no_trivial_dims: tuple[int, ...] = () + for dim_size in mesh_shape: + if dim_size != 1: + mesh_shape_no_trivial_dims += (dim_size,) + + if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]: + raise ValueError( + 'create_device_mesh cannot create contiguous submeshes for ' + f'mesh_shape {mesh_shape} and physical mesh topology {topology}. ' + f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}' + ) + + return physical_mesh.transpose( + *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] + ) + + +def create_device_mesh( + mesh_shape: Sequence[int], + devices: Sequence[Any] | None = None, + *, + contiguous_submeshes: bool = False, + allow_split_physical_axes: bool = False, +) -> np.ndarray: + """Creates a performant device mesh for jax.sharding.Mesh. + + Args: + mesh_shape: shape of logical mesh, ordered by increasing network-intensity + e.g. [replica, data, mdl] where mdl has the most network communication + requirements. + devices: optionally, the devices to construct a mesh for. Defaults to + jax.devices(). + contiguous_submeshes: if True, this function will attempt to create a mesh + where each process's local devices form a contiguous submesh. A ValueError + will be raised if this function can't produce a suitable mesh. 
This + setting was sometimes necessary before the introduction of jax.Array to + ensure non-ragged local arrays; if using jax.Arrays, it's better to keep + this set to False. + allow_split_physical_axes: If True, we will split physical axes if necessary + to produce the desired device mesh. + + Raises: + ValueError: if the number of devices doesn't equal the product of + `mesh_shape`. + + Returns: + A np.ndarray of JAX devices with mesh_shape as its shape that can be fed + into jax.sharding.Mesh with good collective performance. + """ + if devices is None: + devices = xb.devices() + if np.prod(mesh_shape) != len(devices): + raise ValueError( + f'Number of devices {len(devices)} must equal the product ' + f'of mesh_shape {mesh_shape}' + ) + last_device = devices[-1] + + handler = device_kind_handler_dict.get(last_device.device_kind, None) + if handler is not None: + result = handler( + mesh_shape, devices, contiguous_submeshes=contiguous_submeshes + ) + if result is not None: + return result + + if last_device.platform == 'tpu': + physical_mesh = _get_physical_tpu_mesh(devices) + if contiguous_submeshes: + physical_mesh = _transpose_trick(physical_mesh, mesh_shape) + device_mesh, _ = _create_device_mesh_for_nd_torus( + physical_mesh, + mesh_shape, + allow_split_physical_axes=allow_split_physical_axes, + ) + return device_mesh + else: + device_mesh = np.asarray(devices).reshape(mesh_shape) + return device_mesh + + +def create_hybrid_device_mesh( + mesh_shape: Sequence[int], + dcn_mesh_shape: Sequence[int], + devices: Sequence[Any] | None = None, + *, + process_is_granule: bool = False, + should_sort_granules_by_key: bool = True, + allow_split_physical_axes: bool = False, +) -> np.ndarray: + """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. + + Args: + mesh_shape: shape of the logical mesh for the faster/inner network, ordered + by increasing network intensity, e.g. [replica, data, mdl] where mdl has + the most network communication requirements. + dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in + the same order as mesh_shape. + devices: optionally, the devices to construct a mesh for. Defaults to + jax.devices(). + process_is_granule: if True, this function will treat processes as the units + of the slower/outer network. Otherwise it will look for slice_index + attributes on devices and use slices as the units. Enabling this is meant + as a fallback for platforms that don't set slice_index. + should_sort_granules_by_key: Whether device granules should be sorted by the + granule key, either slice or process index, depending on + process_is_granule. + allow_split_physical_axes: If True, we will split physical axes if necessary + to produce the desired device mesh. + + Raises: + ValueError: if the number of slices to which the `devices` belong doesn't + equal the product of `dcn_mesh_shape`, or if the number of devices + belonging to any single slice does not equal the product of `mesh_shape`. + + Returns: + A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape + that can be fed into jax.sharding.Mesh for hybrid parallelism. + """ + if devices is None: + devices = xb.devices() + attr = 'process_index' if process_is_granule else 'slice_index' + if not hasattr(devices[0], attr): + raise ValueError( + f'Device {devices[0]} does not have attribute {attr}. See' + ' `process_is_granule` option.' 
+ ) + granule_dict = collections.defaultdict(list) + for dev in devices: + granule_dict[getattr(dev, attr)].append(dev) + granules = ( + [granule_dict[key] for key in sorted(granule_dict.keys())] + if should_sort_granules_by_key + else granule_dict.values() + ) + if np.prod(dcn_mesh_shape) != len(granules): + raise ValueError( + f'Number of slices {len(granules)} must equal the product of ' + f'dcn_mesh_shape {dcn_mesh_shape}' + ) + per_granule_meshes = [ + create_device_mesh( + mesh_shape, + granule, + allow_split_physical_axes=allow_split_physical_axes, + ) + for granule in granules + ] + # TODO(jekbradbury): handle non-uniform DCN topologies + granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) + blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])( + granule_mesh + ) + device_mesh = np.block(blocks.tolist()) + return device_mesh diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 32d543a27966..c81d51ea054b 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -35,7 +35,7 @@ from jax._src.cudnn.fused_attention_stablehlo import ( dot_product_attention as cudnn_dot_product_attention, MaskType) from jax._src.numpy import util as numpy_util -from jax._src.typing import Array, ArrayLike +from jax._src.typing import Array, ArrayLike, DType from jax._src.ops.special import logsumexp as _logsumexp @@ -430,8 +430,8 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: If ``approximate=False``, computes the element-wise function: .. math:: - \mathrm{gelu}(x) = \frac{x}{2} \left(1 + \mathrm{erf} \left( - \frac{x}{\sqrt{2}} \right) \right) + \mathrm{gelu}(x) = \frac{x}{2} \left(\mathrm{erfc} \left( + \frac{-x}{\sqrt{2}} \right) \right) If ``approximate=True``, uses the approximate formulation of GELU: @@ -443,7 +443,7 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: `_, section 2. Args: - x : input array + x: input array approximate: whether to use the approximate or exact formulation. """ [x_arr] = numpy_util.promote_args_inexact("gelu", x) @@ -453,8 +453,10 @@ def gelu(x: ArrayLike, approximate: bool = True) -> Array: cdf = 0.5 * (1.0 + jnp.tanh(sqrt_2_over_pi * (x_arr + 0.044715 * (x_arr ** 3)))) return x_arr * cdf else: - sqrt_2 = np.sqrt(2).astype(x_arr.dtype) - return jnp.array(x_arr * (lax.erf(x_arr / sqrt_2) + 1) / 2, dtype=x_arr.dtype) + sqrt_half = np.sqrt(0.5).astype(x_arr.dtype) + return jnp.array( + 0.5 * x_arr * (lax.erfc(-x_arr * sqrt_half)), dtype=x_arr.dtype + ) @partial(jax.jit, static_argnames=("axis",)) def glu(x: ArrayLike, axis: int = -1) -> Array: @@ -541,7 +543,7 @@ def log_softmax(x: ArrayLike, # TODO(phawkins): this jit was found to change numerics in a test. Debug this. -#@partial(jax.jit, static_argnames=("axis",)) +# @partial(jax.jit, static_argnames=("axis",)) def softmax(x: ArrayLike, axis: int | tuple[int, ...] 
| None = -1, where: ArrayLike | None = None, @@ -781,13 +783,65 @@ def _get_large_negative(dtype): dtype_max = jnp.finfo(dtype).max return jnp.asarray(-0.7 * dtype_max, dtype=dtype) -def _get_causal_mask(T, S, dtype): - pred = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) - mask = jnp.where(pred, jnp.asarray(0.0, dtype), _get_large_negative(dtype)) - return mask +def _get_causal_mask(T, S): + mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) + return mask[None, None, :, :] + +def _get_window_mask(T: int, S: int, local_window_size: tuple[int, int]): + query_pos = jnp.array(range(T)) + key_pos = jnp.array(range(S)) + left_window, right_window = local_window_size + left_mask = query_pos[..., None] <= key_pos[..., None, :] + left_window + right_mask = query_pos[..., None] >= key_pos[..., None, :] - right_window + return jnp.logical_and(right_mask, left_mask)[None, None, :, :] + +def _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen): + q_mask = True + kv_mask = True + if q_seqlen is not None: + q_indices = jnp.arange(0, T)[None, :, None] + q_mask = q_indices < q_seqlen[:, None, None] + if kv_seqlen is not None: + kv_indices = jnp.arange(0, S)[None, None, :] + kv_mask = kv_indices < kv_seqlen[:, None, None] + mask = jnp.logical_and(q_mask, kv_mask) + return mask[:, None, :, :] + +def _get_padding_mask_encoded(T, q_seqlen): + q_indices = jnp.arange(0, T)[None, :] + mask = q_indices < q_seqlen[:, None] + return mask[:, :, None, None] + +def _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, + local_window_size): + if mask is None and not is_causal and q_seqlen is None and kv_seqlen is None: + return logits + + combined_mask = jnp.ones_like(logits, dtype=jnp.bool_) + if mask is not None: + assert mask.dtype == jnp.bool_ + combined_mask = jnp.logical_and(combined_mask, mask) + + T, S = logits.shape[2], logits.shape[3] + + if is_causal: + mask = _get_causal_mask(T, S) + combined_mask = jnp.logical_and(combined_mask, mask) + + if local_window_size is not None: + mask = _get_window_mask(T, S, local_window_size) + combined_mask = jnp.logical_and(combined_mask, mask) + + if q_seqlen is not None or kv_seqlen is not None: + mask = _get_padding_mask_logits(T, S, q_seqlen, kv_seqlen) + combined_mask = jnp.logical_and(combined_mask, mask) + + large_negative_number = _get_large_negative(logits.dtype) + padded_logits = jnp.where(combined_mask, logits, large_negative_number) + return padded_logits def _dot_product_attention_core(query, key, value, bias, mask, is_causal, - scale): + scale, q_seqlen, kv_seqlen, local_window_size): logits_dtype = jnp.promote_types(query.dtype, jnp.float32) logits = jnp.einsum('BTNH,BSNH->BNTS', query, key, preferred_element_type=logits_dtype) @@ -797,24 +851,17 @@ def _dot_product_attention_core(query, key, value, bias, mask, is_causal, if bias is not None: logits = (logits + bias).astype(logits.dtype) - if mask is not None: - assert mask.dtype == jnp.bool_ - large_negative_number = _get_large_negative(logits.dtype) - padded_logits = jnp.where(mask, logits, large_negative_number) - else: - padded_logits = logits - - if is_causal: - T, S = query.shape[1], key.shape[1] - mask = jnp.broadcast_to(_get_causal_mask(T, S, logits.dtype), - padded_logits.shape) - padded_logits = padded_logits + mask + padded_logits = _apply_masks(logits, mask, is_causal, q_seqlen, kv_seqlen, + local_window_size) # Softmax and it is always carried out in fp32. 
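Illustrative sketch (editor's aside, not part of the patch): how the sliding-window mask built by `_get_window_mask` restricts attention, for a hypothetical sequence length of 6 and `local_window_size == (2, 1)`.

import jax.numpy as jnp

T = S = 6
left_window, right_window = 2, 1
query_pos = jnp.arange(T)
key_pos = jnp.arange(S)
# Same construction as _get_window_mask, without the leading batch/head dims.
left_mask = query_pos[:, None] <= key_pos[None, :] + left_window
right_mask = query_pos[:, None] >= key_pos[None, :] - right_window
window_mask = jnp.logical_and(right_mask, left_mask)

# Query position 3 may attend to keys in [3 - left_window, 3 + right_window].
print(jnp.where(window_mask[3])[0])  # -> [1 2 3 4]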
padded_logits = padded_logits.astype(jnp.float32) probs = jax.nn.softmax(padded_logits, axis=-1).astype(key.dtype) encoded = jnp.einsum('BNTS,BSNH->BTNH', probs, value) + if q_seqlen is not None and kv_seqlen is not None: + mask = _get_padding_mask_encoded(encoded.shape[1], q_seqlen) + encoded *= mask.astype(encoded.dtype) return encoded def _dot_product_attention_xla( @@ -824,7 +871,10 @@ def _dot_product_attention_xla( bias: Array | None, mask: Array | None, is_causal: bool, - scale: float): + scale: float, + q_seqlen: Array | None, + kv_seqlen: Array | None, + local_window_size: tuple[int, int] | None): B, T, N, H = query.shape _, S, K, _ = key.shape @@ -842,10 +892,13 @@ def _reshape_to_grouped(t): return t bias = _reshape_to_grouped(bias) mask = _reshape_to_grouped(mask) - vmapped_fn = jax.vmap(_dot_product_attention_core, - in_axes=(3, None, None, 2, 2, None, None), - out_axes=3) - encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale) + vmapped_fn = jax.vmap( + _dot_product_attention_core, + in_axes=(3, None, None, 2, 2, None, None, None, None, None), + out_axes=3, + ) + encoded = vmapped_fn(query, key, value, bias, mask, is_causal, scale, + q_seqlen, kv_seqlen, local_window_size) encoded = jnp.reshape(encoded, (B, T, N, H)) return encoded @@ -853,11 +906,14 @@ def dot_product_attention( query: ArrayLike, key: ArrayLike, value: ArrayLike, - *, bias: ArrayLike | None = None, mask: ArrayLike | None = None, + *, scale: float | None = None, is_causal: bool = False, + query_seq_lengths: ArrayLike | None = None, + key_value_seq_lengths: ArrayLike | None = None, + local_window_size: int | tuple[int, int] | None = None, implementation: Literal['xla', 'cudnn'] | None = None) -> Array: r"""Scaled dot product attention function. @@ -882,20 +938,20 @@ def dot_product_attention( G = number of groups, which equals to N // K Args: - query: query array; shape :code:`(BTNH)` - key: key array: shape :code:`(BSKH)`. When `K` equals `N`, multi-headed - attention (MHA: https://arxiv.org/abs/1706.03762) is performed. Otherwise, - grouped query attention (GQA: https://arxiv.org/abs/2305.13245) is performed - if `N` is a multiple of `K`, and multi-query attention (MQA: - https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case - of GQA). + query: query array; shape :code:`(BTNH|TNH)` + key: key array: shape :code:`(BSKH|SKH)`. When `K` equals `N`, multi-headed + attention (MHA https://arxiv.org/abs/1706.03762) is performed. Otherwise, + grouped query attention (GQA https://arxiv.org/abs/2305.13245) is + performed if `N` is a multiple of `K`, and multi-query attention (MQA + https://arxiv.org/abs/1911.02150) is performed if `K == 1` (a special case + of GQA). value: value array, should have the same shape as the `key` array. bias: optional, bias array to be added to logits; The shape must be 4D and - be broadcastable to :code:`(BNTS)`. + be broadcastable to :code:`(BNTS|NTS)`. mask: optional, mask array used to filter out logits. It is a boolean mask where `True` indicates the element should take part in attention. For an additive mask, users should pass it to `bias`. The shape must be 4D and be - broadcastable to :code:`(BNTS)`. + broadcastable to :code:`(BNTS|NTS)`. scale: scale for the logits. If None, the scale will be set to 1 divided by the square root of query's head dimension (i.e. H). is_causal: If true, causal attention will be applied. 
Note, some @@ -903,6 +959,16 @@ def dot_product_attention( logits to mask out the non-causal parts of the attention matrix, but other implementations like `cudnn` will avoid computing the non-causal regions, providing speedups. + query_seq_lengths: `int32` array of sequence lengths for query; shape + :code:`(B)` + key_value_seq_lengths: `int32` array of sequence lengths for key and value; + shape :code:`(B)` + local_window_size: Window sizes to make self attention to attend to each + token's local window. If set, this specifies the (left_window_size, + right_window_size) for each token. E.g., if local_window_size == (3, 2) + and the sequence is [0, 1, 2, 3, 4, 5, c, 7, 8, 9], token `c` can attend + to [3, 4, 5, c, 7, 8]. If a single int is given, it will be intepreted as + a symmetric window (window_size, window_size). implementation: A string to control which implementation backend to use. Supported strings are `xla`, `cudnn` (cuDNN flash attention). It defaults to `None`, which will automatically select the best available backend. @@ -912,51 +978,105 @@ def dot_product_attention( Returns: An array of the attention output with the same shape as :code:`query`. """ - def _check_has_shape(t: Array, shape: Sequence[int], name: str) -> None: + output_shape = jnp.asarray(query).shape + def _ensure_4d(t): + t = jnp.asarray(t) + dims_to_add = 4 - t.ndim + if dims_to_add > 0: + return jnp.expand_dims(t, axis=tuple(range(dims_to_add))) + return t + + query_arr = _ensure_4d(query) + key_arr = _ensure_4d(key) + value_arr = _ensure_4d(value) + bias = _ensure_4d(bias) if bias is not None else None + mask = _ensure_4d(mask) if mask is not None else None + if query_seq_lengths is not None: + query_seq_lengths = jnp.asarray(query_seq_lengths) + if key_value_seq_lengths is not None: + key_value_seq_lengths = jnp.asarray(key_value_seq_lengths) + if isinstance(local_window_size, int): + local_window_size = (local_window_size, local_window_size) + + def _check_shape_and_dtype(t: Array | None, shape: Sequence[int], + dtype: DType | None, name: str) -> None: + if t is None: + return if t.ndim != len(shape): raise ValueError(f"{name} ndim should be {len(shape)}, but got {t.ndim}") + if dtype is not None and t.dtype != dtype: + raise ValueError(f"{name} dtype should be {dtype}, but got {t.dtype}") for i in range(t.ndim): if shape[i] != -1 and t.shape[i] != shape[i]: raise ValueError(f"{name} shape should be {shape}: but got {t.shape}") - query = jnp.asarray(query) - key = jnp.asarray(key) - value = jnp.asarray(value) - bias = bias if bias is None else jnp.asarray(bias) - mask = mask if mask is None else jnp.asarray(mask) - - B, S, K, H = key.shape - _check_has_shape(value, [B, S, K, H], 'value') - _check_has_shape(query, [B, -1, -1, H], 'query') - if query.shape[-2] % K != 0: - raise ValueError(f"The number of query heads must to a multiple of " - f"key/value heads, but got {query.shape[-2]} vs {K}") - if not (query.dtype == key.dtype == value.dtype): - raise ValueError(f"query/key/value should have the same shape, but got " - f"{query.shape} vs {key.shape} vs {value.shape}.") - if mask is not None and mask.dtype != jnp.bool_ and mask.ndim != 4: - raise ValueError(f"Mask must be a 4D boolean tensor, but got " - f"rank={mask.ndim}, dtype={mask.dtype}.") - if bias is not None and bias.ndim != 4: - raise ValueError(f"Bias must be a 4D tensor, but got rank={bias.ndim}.") + B, S, K, H = key_arr.shape + _check_shape_and_dtype(value_arr, [B, S, K, H], key_arr.dtype, 'value') + _check_shape_and_dtype(query_arr, [B, -1, 
-1, H], key_arr.dtype, 'query') + _check_shape_and_dtype(mask, [-1] * 4, jnp.bool_, 'mask') + _check_shape_and_dtype(bias, [-1] * 4, None, 'bias') + _check_shape_and_dtype(query_seq_lengths, [B], jnp.int32, + 'query_seq_lengths') + _check_shape_and_dtype(key_value_seq_lengths, [B], jnp.int32, + 'key_value_seq_lengths') + if query_arr.shape[-2] % K != 0: + raise ValueError(f"The number of query heads must be a multiple of " + f"key/value heads, but got {query_arr.shape[-2]} vs {K}") scale_val = (1.0 / np.sqrt(H)) if scale is None else scale match implementation: case 'xla': - return _dot_product_attention_xla( - query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, + out = _dot_product_attention_xla( + query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, + scale=scale_val, q_seqlen=query_seq_lengths, + kv_seqlen=key_value_seq_lengths, + local_window_size=local_window_size, ) case 'cudnn': - mask_type = MaskType.CAUSAL if is_causal else MaskType.NO_MASK - return cudnn_dot_product_attention( - query, key, value, bias, mask, scale=scale_val, mask_type=mask_type + use_padding = ( + query_seq_lengths is not None or key_value_seq_lengths is not None + ) + if use_padding: + if query_seq_lengths is None: + T = query_arr.shape[1] + query_seq_lengths = jnp.full((B,), T, dtype=jnp.int32) + if key_value_seq_lengths is None: + key_value_seq_lengths = jnp.full((B,), S, dtype=jnp.int32) + + mask_type = MaskType.NO_MASK + if use_padding and is_causal: + mask_type = MaskType.PADDING_CAUSAL + elif is_causal: + mask_type = MaskType.CAUSAL + elif use_padding: + mask_type = MaskType.PADDING + # CuDNN supports only the left window with an exclusive boundary when + # causal mask is enabled. + sliding_window = None + if local_window_size is not None: + l_window, r_window = local_window_size + if r_window == 0 or mask_type == MaskType.CAUSAL: + sliding_window = l_window + 1 + else: + raise ValueError(f"cuDNN doesn't support right window: {r_window} " + "when causal mask is not used.") + + out = cudnn_dot_product_attention( + query_arr, key_arr, value_arr, bias, mask, query_seq_lengths, + key_value_seq_lengths, scale=scale_val, mask_type=mask_type, + sliding_window_length=sliding_window, ) case None: # TODO(kaixih@nvidia) Defaults to XLA for now. Will automatically select # best backend. 
- return _dot_product_attention_xla( - query, key, value, bias, mask, is_causal=is_causal, scale=scale_val, + out = _dot_product_attention_xla( + query_arr, key_arr, value_arr, bias, mask, is_causal=is_causal, + scale=scale_val, q_seqlen=query_seq_lengths, + kv_seqlen=key_value_seq_lengths, + local_window_size=local_window_size, ) case _: raise ValueError(f"Unsupported implementation option: {implementation}") + + return jnp.reshape(out, output_shape) diff --git a/jax/_src/nn/initializers.py b/jax/_src/nn/initializers.py index 7d228e4beef4..eb1bb1609bbf 100644 --- a/jax/_src/nn/initializers.py +++ b/jax/_src/nn/initializers.py @@ -184,7 +184,7 @@ def truncated_normal(stddev: RealNumeric = 1e-2, >>> import jax, jax.numpy as jnp >>> initializer = jax.nn.initializers.truncated_normal(5.0) - >>> initializer(jax.random.PRNGKey(42), (2, 3), jnp.float32) # doctest: +SKIP + >>> initializer(jax.random.key(42), (2, 3), jnp.float32) # doctest: +SKIP Array([[ 2.9047365, 5.2338114, 5.29852 ], [-3.836303 , -4.192359 , 0.6022964]], dtype=float32) """ diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 87635be37c84..95d681cad8e5 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -26,7 +26,7 @@ import abc from functools import partial, wraps import math -from typing import Any +from typing import Any, Sequence import numpy as np import jax @@ -43,9 +43,8 @@ from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs -from jax._src.numpy import util from jax._src.ops import scatter -from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape +from jax._src.typing import Array, ArrayLike, DimSize, DTypeLike, Shape, StaticScalar from jax._src.util import safe_zip, safe_map map, unsafe_map = safe_map, map @@ -59,7 +58,56 @@ # functions, which can themselves handle instances from any of these classes. -def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: +def _all(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + """Test whether all array elements along a given axis evaluate to True. + + Refer to :func:`jax.numpy.all` for the full documentation. + """ + return reductions.all(self, axis=axis, out=out, keepdims=keepdims, where=where) + +def _any(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: + """Test whether any array elements along a given axis evaluate to True. + + Refer to :func:`jax.numpy.any` for the full documentation. + """ + return reductions.any(self, axis=axis, out=out, keepdims=keepdims, where=where) + +def _argmax(self: Array, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: + """Return the index of the maximum value. + + Refer to :func:`jax.numpy.argmax` for the full documentation. + """ + return lax_numpy.argmax(self, axis=axis, out=out, keepdims=keepdims) + +def _argmin(self: Array, axis: int | None = None, out: None = None, + keepdims: bool | None = None) -> Array: + """Return the index of the minimum value. + + Refer to :func:`jax.numpy.argmin` for the full documentation. + """ + return lax_numpy.argmin(self, axis=axis, out=out, keepdims=keepdims) + +def _argpartition(self: Array, kth: int, axis: int = -1) -> Array: + """Return the indices that partially sort the array. 
+ + Refer to :func:`jax.numpy.argpartition` for the full documentation. + """ + return lax_numpy.argpartition(self, kth=kth, axis=axis) + +def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: None = None, + stable: bool = True, descending: bool = False) -> Array: + """Return the indices that sort the array. + + Refer to :func:`jax.numpy.argsort` for the full documentation. + """ + return lax_numpy.argsort(self, axis=axis, kind=kind, order=order, + stable=stable, descending=descending) + +def _astype(self: Array, dtype: DTypeLike | None, copy: bool = False, + device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -67,42 +115,298 @@ def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = False, device: xc.Dev some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype, copy=copy, device=device) + return lax_numpy.astype(self, dtype, copy=copy, device=device) -def _to_device(arr: ArrayLike, device: xc.Device | Sharding, *, - stream: int | Any | None = None): - if stream is not None: - raise NotImplementedError("stream argument of array.to_device()") - return api.device_put(arr, device) +def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: + """Construct an array choosing from elements of multiple arrays. + Refer to :func:`jax.numpy.choose` for the full documentation. + """ + return lax_numpy.choose(self, choices=choices, out=out, mode=mode) -def _nbytes(arr: ArrayLike) -> int: - """Total bytes consumed by the elements of the array.""" - return np.size(arr) * dtypes.dtype(arr, canonicalize=True).itemsize +def _clip(self: Array, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: + """Return an array whose values are limited to a specified range. + + Refer to :func:`jax.numpy.clip` for full documentation. + """ + return lax_numpy.clip(self, min=min, max=max) + +def _compress(self: Array, condition: ArrayLike, + axis: int | None = None, *, out: None = None, + size: int | None = None, fill_value: ArrayLike = 0) -> Array: + """Return selected slices of this array along given axis. + + Refer to :func:`jax.numpy.compress` for full documentation. + """ + return lax_numpy.compress(condition, self, axis=axis, out=out, + size=size, fill_value=fill_value) + +def _conj(self: Array) -> Array: + """Return the complex conjugate of the array. + + Refer to :func:`jax.numpy.conj` for the full documentation. + """ + return ufuncs.conj(self) + +def _conjugate(self: Array) -> Array: + """Return the complex conjugate of the array. + Refer to :func:`jax.numpy.conjugate` for the full documentation. + """ + return ufuncs.conjugate(self) + +def _copy(self: Array) -> Array: + """Return a copy of the array. + + Refer to :func:`jax.numpy.copy` for the full documentation. + """ + return lax_numpy.copy(self) -def _item(a: Array, *args) -> bool | int | float | complex: +def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: + """Return the cumulative product of the array. + + Refer to :func:`jax.numpy.cumprod` for the full documentation. 
+ """ + return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) + +def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None) -> Array: + """Return the cumulative sum of the array. + + Refer to :func:`jax.numpy.cumsum` for the full documentation. + """ + return reductions.cumsum(self, axis=axis, dtype=dtype, out=out) + +def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: + """Return the specified diagonal from the array. + + Refer to :func:`jax.numpy.diagonal` for the full documentation. + """ + return lax_numpy.diagonal(self, offset=offset, axis1=axis1, axis2=axis2) + +def _dot(self: Array, b: ArrayLike, *, precision: lax_internal.PrecisionLike = None, + preferred_element_type: DTypeLike | None = None) -> Array: + """Compute the dot product of two arrays. + + Refer to :func:`jax.numpy.dot` for the full documentation. + """ + return lax_numpy.dot(self, b, precision=precision, preferred_element_type=preferred_element_type) + +def _flatten(self: Array, order: str = "C") -> Array: + """Flatten array into a 1-dimensional shape. + + Refer to :func:`jax.numpy.ravel` for the full documentation. + """ + return lax_numpy.ravel(self, order=order) + +def _imag_property(self: Array) -> Array: + """Return the imaginary part of the array.""" + return ufuncs.imag(self) + +def _item(self: Array, *args: int) -> bool | int | float | complex: """Copy an element of an array to a standard Python scalar and return it.""" - arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array") - if dtypes.issubdtype(a.dtype, dtypes.extended): - raise TypeError(f"No Python scalar type for {a.dtype=}") + arr = core.concrete_or_error(np.asarray, self, context="This occurred in the item() method of jax.Array") + if dtypes.issubdtype(self.dtype, dtypes.extended): + raise TypeError(f"No Python scalar type for {arr.dtype=}") return arr.item(*args) -def _itemsize(arr: ArrayLike) -> int: +def _itemsize_property(self: Array) -> int: """Length of one array element in bytes.""" - return dtypes.dtype(arr, canonicalize=True).itemsize + return dtypes.dtype(self, canonicalize=True).itemsize +def _matrix_transpose_property(self: Array): + """Compute the (batched) matrix transpose. -def _clip(number: ArrayLike, - min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: - """Return an array whose values are limited to a specified range. + Refer to :func:`jax.numpy.matrix_transpose` for details. + """ + return lax_numpy.matrix_transpose(self) - Refer to :func:`jax.numpy.clip` for full documentation.""" - return lax_numpy.clip(number, min=min, max=max) +def _max(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Return the maximum of array elements along a given axis. + + Refer to :func:`jax.numpy.max` for the full documentation. + """ + return reductions.max(self, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) -def _transpose(a: Array, *args: Any) -> Array: - """Returns a view of the array with axes transposed. +def _mean(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, *, + where: ArrayLike | None = None) -> Array: + """Return the mean of array elements along a given axis. + + Refer to :func:`jax.numpy.mean` for the full documentation. 
+ """ + return reductions.mean(self, axis=axis, dtype=dtype, out=out, + keepdims=keepdims, where=where) + +def _min(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: + """Return the minimum of array elements along a given axis. + + Refer to :func:`jax.numpy.min` for the full documentation. + """ + return reductions.min(self, axis=axis, out=out, keepdims=keepdims, + initial=initial, where=where) + +def _nbytes_property(self: Array) -> int: + """Total bytes consumed by the elements of the array.""" + return np.size(self) * dtypes.dtype(self, canonicalize=True).itemsize + +def _nonzero(self: Array, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, + size: int | None = None) -> tuple[Array, ...]: + """Return indices of nonzero elements of an array. + + Refer to :func:`jax.numpy.nonzero` for the full documentation. + """ + return lax_numpy.nonzero(self, size=size, fill_value=fill_value) + +def _prod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, + initial: ArrayLike | None = None, where: ArrayLike | None = None, + promote_integers: bool = True) -> Array: + """Return product of the array elements over a given axis. + + Refer to :func:`jax.numpy.prod` for the full documentation. + """ + return reductions.prod(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims, + initial=initial, where=where, promote_integers=promote_integers) + +def _ptp(self: Array, axis: reductions.Axis = None, out: None = None, + keepdims: bool = False) -> Array: + """Return the peak-to-peak range along a given axis. + + Refer to :func:`jax.numpy.ptp` for the full documentation. + """ + return reductions.ptp(self, axis=axis, out=out, keepdims=keepdims) + +def _real_property(self: Array) -> Array: + """Return the real part of the array.""" + return ufuncs.real(self) + +def _repeat(self: Array, repeats: ArrayLike, axis: int | None = None, *, + total_repeat_length: int | None = None) -> Array: + """Construct an array from repeated elements. + + Refer to :func:`jax.numpy.repeat` for the full documentation. + """ + return lax_numpy.repeat(self, repeats=repeats, axis=axis, total_repeat_length=total_repeat_length) + +def _reshape(self: Array, *args: Any, order: str = "C") -> Array: + """Returns an array containing the same data with a new shape. + + Refer to :func:`jax.numpy.reshape` for full documentation. + """ + __tracebackhide__ = True + newshape = _compute_newshape(self, args[0] if len(args) == 1 else args) + if order == "C": + return lax.reshape(self, newshape, None) + elif order == "F": + dims = list(range(self.ndim)[::-1]) + return lax.reshape(self, newshape[::-1], dims).T + elif order == "A": + raise NotImplementedError("np.reshape order=A is not implemented.") + else: + raise ValueError(f"Unexpected value for 'order' argument: {order}.") + +def _round(self: Array, decimals: int = 0, out: None = None) -> Array: + """Round array elements to a given decimal. + + Refer to :func:`jax.numpy.round` for full documentation. + """ + return lax_numpy.round(self, decimals=decimals, out=out) + +def _searchsorted(self: Array, v: ArrayLike, side: str = 'left', + sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: + """Perform a binary search within a sorted array. 
+ + Refer to :func:`jax.numpy.searchsorted` for full documentation.""" + return lax_numpy.searchsorted(self, v, side=side, sorter=sorter, method=method) + +def _sort(self: Array, axis: int | None = -1, *, kind: None = None, + order: None = None, stable: bool = True, descending: bool = False) -> Array: + """Return a sorted copy of an array. + + Refer to :func:`jax.numpy.sort` for full documentation. + """ + return lax_numpy.sort(self, axis=axis, kind=kind, order=order, + stable=stable, descending=descending) + +def _squeeze(self: Array, axis: reductions.Axis = None) -> Array: + """Remove one or more length-1 axes from array. + + Refer to :func:`jax.numpy.squeeze` for full documentation. + """ + return lax_numpy.squeeze(self, axis=axis) + +def _std(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + """Compute the standard deviation along a given axis. + + Refer to :func:`jax.numpy.std` for full documentation. + """ + return reductions.std(self, axis=axis, dtype=dtype, out=out, ddof=ddof, keepdims=keepdims, + where=where, correction=correction) + +def _sum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None, promote_integers: bool = True) -> Array: + """Sum of the elements of the array over a given axis. + + Refer to :func:`jax.numpy.sum` for full documentation. + """ + return reductions.sum(self, axis=axis, dtype=dtype, out=out, keepdims=keepdims, + where=where, promote_integers=promote_integers) + +def _swapaxes(self: Array, axis1: int, axis2: int) -> Array: + """Swap two axes of an array. + + Refer to :func:`jax.numpy.swapaxes` for full documentation. + """ + return lax_numpy.swapaxes(self, axis1=axis1, axis2=axis2) + + +def _take(self: Array, indices: ArrayLike, axis: int | None = None, out: None = None, + mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False, + fill_value: StaticScalar | None = None) -> Array: + """Take elements from an array. + + Refer to :func:`jax.numpy.take` for full documentation. + """ + return lax_numpy.take(self, indices, axis=axis, out=out, mode=mode, unique_indices=unique_indices, + indices_are_sorted=indices_are_sorted, fill_value=fill_value) + +def _to_device(self: Array, device: xc.Device | Sharding, *, + stream: int | Any | None = None): + """Return a copy of the array on the specified device + + Args: + device: :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + stream: not implemented, passing a non-None value will lead to an error. + Returns: + copy of array placed on the specified device or devices. + """ + if stream is not None: + raise NotImplementedError("stream argument of array.to_device()") + return api.device_put(self, device) + + +def _trace(self: Array, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Return the sum along the diagonal. + + Refer to :func:`jax.numpy.trace` for full documentation. + """ + return lax_numpy.trace(self, offset=offset, axis1=axis1, axis2=axis2, dtype=dtype, out=out) + +def _transpose(self: Array, *args: Any) -> Array: + """Returns a copy of the array with axes transposed. Refer to :func:`jax.numpy.transpose` for full documentation. 
""" @@ -112,10 +416,27 @@ def _transpose(a: Array, *args: Any) -> Array: axis = args[0] if args[0] is None else _ensure_index_tuple(args[0]) else: axis = _ensure_index_tuple(args) - return lax_numpy.transpose(a, axis) + return lax_numpy.transpose(self, axis) + +def _transpose_property(self: Array): + """Compute the all-axis array transpose. + + Refer to :func:`jax.numpy.transpose` for details. + """ + return lax_numpy.transpose(self) + +def _var(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None, + out: None = None, ddof: int = 0, keepdims: bool = False, *, + where: ArrayLike | None = None, correction: int | float | None = None) -> Array: + """Compute the variance along a given axis. + Refer to :func:`jax.numpy.var` for full documentation. + """ + return reductions.var(self, axis=axis, dtype=dtype, out=out, ddof=ddof, + keepdims=keepdims, where=where, correction=correction) -def _compute_newshape(a: ArrayLike, newshape: DimSize | Shape) -> Shape: + +def _compute_newshape(arr: Array, newshape: DimSize | Shape) -> Shape: """Fixes a -1 value in newshape, if present.""" orig_newshape = newshape # for error messages try: @@ -130,43 +451,24 @@ def _compute_newshape(a: ArrayLike, newshape: DimSize | Shape) -> Shape: if neg1s: i, = neg1s other_sizes = (*newshape[:i], *newshape[i+1:]) - if (all(isinstance(d, int) for d in (*np.shape(a), *other_sizes)) and - np.size(a) % math.prod(other_sizes) != 0): - raise TypeError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) " + if (all(isinstance(d, int) for d in (*arr.shape, *other_sizes)) and + arr.size % math.prod(other_sizes) != 0): + raise TypeError(f"cannot reshape array of shape {arr.shape} (size {arr.size}) " f"into shape {orig_newshape} because the product of " f"specified axis sizes ({math.prod(other_sizes)}) does " - f"not evenly divide {np.size(a)}") - sz = core.cancel_divide_tracers(np.shape(a), other_sizes) + f"not evenly divide {arr.size}") + sz = core.cancel_divide_tracers(arr.shape, other_sizes) if sz is not None: return (*newshape[:i], sz, *newshape[i+1:]) else: - if (all(isinstance(d, int) for d in (*np.shape(a), *newshape)) and - np.size(a) != math.prod(newshape)): - raise TypeError(f"cannot reshape array of shape {np.shape(a)} (size {np.size(a)}) " + if (all(isinstance(d, int) for d in (*arr.shape, *newshape)) and + arr.size != math.prod(newshape)): + raise TypeError(f"cannot reshape array of shape {arr.shape} (size {arr.size}) " f"into shape {orig_newshape} (size {math.prod(newshape)})") - return tuple(-core.divide_shape_sizes(np.shape(a), newshape) + return tuple(-core.divide_shape_sizes(arr.shape, newshape) if core.definitely_equal(d, -1) else d for d in newshape) - -def _reshape(a: Array, *args: Any, order: str = "C") -> Array: - """Returns an array containing the same data with a new shape. - - Refer to :func:`jax.numpy.reshape` for full documentation. 
- """ - __tracebackhide__ = True - newshape = _compute_newshape(a, args[0] if len(args) == 1 else args) - if order == "C": - return lax.reshape(a, newshape, None) - elif order == "F": - dims = list(range(a.ndim)[::-1]) - return lax.reshape(a, newshape[::-1], dims).T - elif order == "A": - raise NotImplementedError("np.reshape order=A is not implemented.") - else: - raise ValueError(f"Unexpected value for 'order' argument: {order}.") - - -def _view(arr: Array, dtype: DTypeLike | None = None, type: None = None) -> Array: +def _view(self: Array, dtype: DTypeLike | None = None, type: None = None) -> Array: """Return a bitwise copy of the array, viewed as a new dtype. This is fuller-featured wrapper around :func:`jax.lax.bitcast_convert_type`. @@ -187,72 +489,70 @@ def _view(arr: Array, dtype: DTypeLike | None = None, type: None = None) -> Arra should only contain 0 or 1 bytes. Otherwise, results may be unpredictable or may change depending on how the result is used. - This conversion is guaranteed and safe: - >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) - Array([ True, False, True], dtype=bool) + This conversion is guaranteed and safe:: + + >>> jnp.array([1, 0, 1], dtype=jnp.int8).view(jnp.bool_) + Array([ True, False, True], dtype=bool) However, there are no guarantees about the results of any expression involving a view such as this: `jnp.array([1, 2, 3], dtype=jnp.int8).view(jnp.bool_)`. In particular, the results may change between JAX releases and depending on the platform. To safely convert such an array to a boolean array, compare it - with `0`: + with `0`:: - >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 - Array([ True, True, False], dtype=bool) + >>> jnp.array([1, 2, 0], dtype=jnp.int8) != 0 + Array([ True, True, False], dtype=bool) """ if type is not None: raise NotImplementedError("`type` argument of array.view() is not supported.") - util.check_arraylike("view", arr) - arr = lax_numpy.asarray(arr) - dtypes.check_user_dtype_supported(dtype, "view") dtype = dtypes.canonicalize_dtype(dtype) - if arr.ndim == 0: - if arr.dtype.itemsize != dtype.itemsize: + if self.ndim == 0: + if self.dtype.itemsize != dtype.itemsize: raise ValueError("view() of a 0d array is only supported if the itemsize is unchanged.") - return _view(lax.expand_dims(arr, (0,)), dtype).squeeze() + return _view(lax.expand_dims(self, (0,)), dtype).squeeze() - if (arr.shape[-1] * arr.dtype.itemsize) % dtype.itemsize != 0: + if (self.shape[-1] * self.dtype.itemsize) % dtype.itemsize != 0: raise ValueError("When changing to a larger dtype, its size must be a divisor " "of the total size in bytes of the last axis of the array.") - if arr.dtype == dtype: - return arr + if self.dtype == dtype: + return self # lax.bitcast_convert_type does not support bool or complex; in these cases we # cast to a compatible type and recursively call _view for simplicity. 
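Illustrative sketch (editor's aside, not part of the patch): the itemsize bookkeeping above means that viewing an array as a wider dtype shrinks the last axis by the itemsize ratio, and viewing back restores the original bit pattern. The input values are arbitrary.

import jax.numpy as jnp

x = jnp.arange(8, dtype=jnp.uint8)
y = x.view(jnp.uint32)
print(x.shape, y.shape)                       # (8,) -> (2,): itemsize ratio is 4
print(bool((y.view(jnp.uint8) == x).all()))   # True: the bytes round-trip exactly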
- if arr.dtype == bool: - return _view(arr.astype('uint8'), dtype) + if self.dtype == bool: + return _view(self.astype('uint8'), dtype) - if lax_numpy.issubdtype(arr.dtype, np.complexfloating): - new_shape = (*arr.shape[:-1], arr.shape[-1] * 2) - new_dtype = lax_numpy.finfo(arr.dtype).dtype - arr = (lax_numpy.zeros(new_shape, new_dtype) - .at[..., 0::2].set(arr.real) - .at[..., 1::2].set(arr.imag)) - return _view(arr, dtype) + if lax_numpy.issubdtype(self.dtype, np.complexfloating): + new_shape = (*self.shape[:-1], self.shape[-1] * 2) + new_dtype = lax_numpy.finfo(self.dtype).dtype + self = (lax_numpy.zeros(new_shape, new_dtype) + .at[..., 0::2].set(self.real) + .at[..., 1::2].set(self.imag)) + return _view(self, dtype) if dtype == bool: - return _view(arr, np.uint8).astype(bool) + return _view(self, np.uint8).astype(bool) if lax_numpy.issubdtype(dtype, np.complexfloating): - out = _view(arr, lax_numpy.finfo(dtype).dtype).astype(dtype) + out = _view(self, lax_numpy.finfo(dtype).dtype).astype(dtype) return out[..., 0::2] + 1j * out[..., 1::2] # lax.bitcast_convert_type adds or subtracts dimensions depending on the # relative bitwidths of the dtypes; we account for that with reshapes. - if arr.dtype.itemsize < dtype.itemsize: - factor = dtype.itemsize // arr.dtype.itemsize - arr = arr.reshape(*arr.shape[:-1], arr.shape[-1] // factor, factor) - return lax.bitcast_convert_type(arr, dtype) + if self.dtype.itemsize < dtype.itemsize: + factor = dtype.itemsize // self.dtype.itemsize + out = self.reshape(*self.shape[:-1], self.shape[-1] // factor, factor) + return lax.bitcast_convert_type(out, dtype) - if arr.dtype.itemsize > dtype.itemsize: - out = lax.bitcast_convert_type(arr, dtype) + if self.dtype.itemsize > dtype.itemsize: + out = lax.bitcast_convert_type(self, dtype) return out.reshape(*out.shape[:-2], out.shape[-2] * out.shape[-1]) - return lax.bitcast_convert_type(arr, dtype) + return lax.bitcast_convert_type(self, dtype) def _notimplemented_flat(self): @@ -291,9 +591,6 @@ def _operator_round(number: ArrayLike, ndigits: int | None = None) -> Array: # If `ndigits` is None, for a builtin float round(7.5) returns an integer. return out.astype(int) if ndigits is None else out -def _copy(self: Array) -> Array: - return self.copy() - def _deepcopy(self: Array, memo: Any) -> Array: del memo # unused return self.copy() @@ -311,19 +608,9 @@ def __array_module__(self, types): return NotImplemented -def _compress_method(a: ArrayLike, condition: ArrayLike, - axis: int | None = None, *, out: None = None, - size: int | None = None, fill_value: ArrayLike = 0) -> Array: - """Return selected slices of this array along given axis. 
- - Refer to :func:`jax.numpy.compress` for full documentation.""" - return lax_numpy.compress(condition, a, axis=axis, out=out, - size=size, fill_value=fill_value) - - @core.stash_axis_env() @partial(jax.jit, static_argnums=(1,2,3)) -def _multi_slice(arr: ArrayLike, +def _multi_slice(self: Array, start_indices: tuple[tuple[int, ...]], limit_indices: tuple[tuple[int, ...]], removed_dims: tuple[tuple[int, ...]]) -> list[Array]: @@ -334,13 +621,13 @@ def _multi_slice(arr: ArrayLike, """ results: list[Array] = [] for starts, limits, removed in zip(start_indices, limit_indices, removed_dims): - sliced = lax.slice(arr, starts, limits) + sliced = lax.slice(self, starts, limits) if removed: sliced = lax.squeeze(sliced, removed) results.append(sliced) return results -# The next two functions are related to iter(device_array), implemented here to +# The next two functions are related to iter(array), implemented here to # avoid circular imports. @jax.jit def _unstack(x: Array) -> list[Array]: @@ -622,15 +909,15 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "setitem": _unimplemented_setitem, "copy": _copy, "deepcopy": _deepcopy, - "neg": ufuncs.negative, - "pos": ufuncs.positive, + "neg": lambda self: ufuncs.negative(self), + "pos": lambda self: ufuncs.positive(self), "eq": _defer_to_unrecognized_arg("==", ufuncs.equal), "ne": _defer_to_unrecognized_arg("!=", ufuncs.not_equal), "lt": _defer_to_unrecognized_arg("<", ufuncs.less), "le": _defer_to_unrecognized_arg("<=", ufuncs.less_equal), "gt": _defer_to_unrecognized_arg(">", ufuncs.greater), "ge": _defer_to_unrecognized_arg(">=", ufuncs.greater_equal), - "abs": ufuncs.abs, + "abs": lambda self: ufuncs.abs(self), "add": _defer_to_unrecognized_arg("+", ufuncs.add), "radd": _defer_to_unrecognized_arg("+", ufuncs.add, swap=True), "sub": _defer_to_unrecognized_arg("-", ufuncs.subtract), @@ -657,7 +944,7 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, "ror": _defer_to_unrecognized_arg("|", ufuncs.bitwise_or, swap=True), "xor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor), "rxor": _defer_to_unrecognized_arg("^", ufuncs.bitwise_xor, swap=True), - "invert": ufuncs.bitwise_not, + "invert": lambda self: ufuncs.bitwise_not(self), "lshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift), "rshift": _defer_to_unrecognized_arg(">>", ufuncs.right_shift), "rlshift": _defer_to_unrecognized_arg("<<", ufuncs.left_shift, swap=True), @@ -667,46 +954,46 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, _array_methods = { "__array_namespace__": array_api_metadata.__array_namespace__, - "all": reductions.all, - "any": reductions.any, - "argmax": lax_numpy.argmax, - "argmin": lax_numpy.argmin, - "argpartition": lax_numpy.argpartition, - "argsort": lax_numpy.argsort, + "all": _all, + "any": _any, + "argmax": _argmax, + "argmin": _argmin, + "argpartition": _argpartition, + "argsort": _argsort, "astype": _astype, - "choose": lax_numpy.choose, + "choose": _choose, "clip": _clip, - "conj": ufuncs.conj, - "conjugate": ufuncs.conjugate, - "compress": _compress_method, - "copy": lax_numpy.copy, - "cumprod": reductions.cumprod, - "cumsum": reductions.cumsum, - "diagonal": lax_numpy.diagonal, - "dot": lax_numpy.dot, - "flatten": lax_numpy.ravel, + "compress": _compress, + "conj": _conj, + "conjugate": _conjugate, + "copy": _copy, + "cumprod": _cumprod, + "cumsum": _cumsum, + "diagonal": _diagonal, + "dot": _dot, + "flatten": _flatten, "item": _item, - "max": reductions.max, - "mean": 
reductions.mean, - "min": reductions.min, - "nonzero": lax_numpy.nonzero, - "prod": reductions.prod, - "ptp": reductions.ptp, - "ravel": lax_numpy.ravel, - "repeat": lax_numpy.repeat, + "max": _max, + "mean": _mean, + "min": _min, + "nonzero": _nonzero, + "prod": _prod, + "ptp": _ptp, + "ravel": _flatten, + "repeat": _repeat, "reshape": _reshape, - "round": lax_numpy.round, - "searchsorted": lax_numpy.searchsorted, - "sort": lax_numpy.sort, - "squeeze": lax_numpy.squeeze, - "std": reductions.std, - "sum": reductions.sum, - "swapaxes": lax_numpy.swapaxes, - "take": lax_numpy.take, + "round": _round, + "searchsorted": _searchsorted, + "sort": _sort, + "squeeze": _squeeze, + "std": _std, + "sum": _sum, + "swapaxes": _swapaxes, + "take": _take, "to_device": _to_device, - "trace": lax_numpy.trace, + "trace": _trace, "transpose": _transpose, - "var": reductions.var, + "var": _var, "view": _view, # Methods exposed in order to avoid circular imports @@ -721,12 +1008,12 @@ def max(self, values, *, indices_are_sorted=False, unique_indices=False, _array_properties = { "flat": _notimplemented_flat, - "T": lax_numpy.transpose, - "mT": lax_numpy.matrix_transpose, - "real": ufuncs.real, - "imag": ufuncs.imag, - "nbytes": _nbytes, - "itemsize": _itemsize, + "T": _transpose_property, + "mT": _matrix_transpose_property, + "real": _real_property, + "imag": _imag_property, + "nbytes": _nbytes_property, + "itemsize": _itemsize_property, "at": _IndexUpdateHelper, } @@ -772,14 +1059,14 @@ def _set_tracer_aval_forwarding(tracer, exclude=()): if prop_name not in exclude: setattr(tracer, prop_name, _forward_property_to_aval(prop_name)) -def _set_array_base_attributes(device_array, include=None, exclude=None): +def _set_array_base_attributes(array_impl, include=None, exclude=None): # Forward operators, methods, and properties on Array to lax_numpy # functions (with no Tracers involved; this forwarding is direct) def maybe_setattr(attr_name, target): if exclude is not None and attr_name in exclude: return if not include or attr_name in include: - setattr(device_array, attr_name, target) + setattr(array_impl, attr_name, target) for operator_name, function in _array_operators.items(): maybe_setattr(f"__{operator_name}__", function) @@ -789,10 +1076,10 @@ def maybe_setattr(attr_name, target): maybe_setattr(prop_name, property(prop)) for name, func in _impl_only_array_methods.items(): - setattr(device_array, name, func) + setattr(array_impl, name, func) -def _set_array_attributes(device_array): - setattr(device_array, "__array_module__", __array_module__) +def _set_array_attributes(array_impl): + setattr(array_impl, "__array_module__", __array_module__) def _make_abstract_method(name, func): @abc.abstractmethod diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index b65f1ee589cc..8b914680fea3 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -252,17 +252,171 @@ def ifftn(a: ArrayLike, s: Shape | None = None, return _fft_core('ifftn', xla_client.FftType.IFFT, a, s, axes, norm) -@implements(np.fft.rfftn) def rfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: + """Compute a multidimensional discrete Fourier transform of a real-valued array. + + JAX implementation of :func:`numpy.fft.rfftn`. + + Args: + a: real-valued input array. + s: optional sequence of integers. Controls the effective size of the input + along each specified axis. If not specified, it will default to the + dimension of input along ``axes``. 
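The tables being rewritten here bind `jax.numpy` functions as `Array` methods, properties, and operators; a small sketch of the resulting user-facing behavior, assuming standard imports:

import jax.numpy as jnp

x = jnp.array([[1, 2, 3],
               [4, 5, 6]])
# Bound methods and properties forward to the corresponding jnp functions.
assert bool((x.T == jnp.transpose(x)).all())
assert int(x.sum()) == int(jnp.sum(x))
assert int(x.argmax()) == int(jnp.argmax(x))
# Operators such as __neg__ and __abs__ forward to the ufuncs.
assert bool((abs(-x) == jnp.abs(x)).all())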
+ axes: optional sequence of integers, default=None. Specifies the axes along + which the transform is computed. If not specified, the transform is computed + along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified, + the transform is computed along all the axes. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + An array containing the multidimensional discrete Fourier transform of ``a`` + having size specified in ``s`` along the axes ``axes`` except along the axis + ``axes[-1]``. The size of the output along the axis ``axes[-1]`` is + ``s[-1]//2+1``. + + See also: + - :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse + discrete Fourier transform. + + Examples: + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfftn(x) + Array([[[ 78.+0.j , -12.+6.93j], + [ -6.+0.j , 0.+0.j ]], + + [[-36.+0.j , 0.+0.j ], + [ 0.+0.j , 0.+0.j ]]], dtype=complex64) + + When ``s=[3, 3, 4]``, size of the transform along ``axes (-3, -2)`` will + be (3, 3), and along ``axis -1`` will be ``4//2+1 = 3`` and size along + other axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfftn(x, s=[3, 3, 4]) + Array([[[ 78. +0.j , -16. -26.j , 26. +0.j ], + [ 15. -36.37j, -16.12 +1.93j, 5. -12.12j], + [ 15. +36.37j, 8.12-11.93j, 5. +12.12j]], + + [[ -7.5 -49.36j, -20.45 +9.43j, -2.5 -16.45j], + [-25.5 -7.79j, -0.6 +11.96j, -8.5 -2.6j ], + [ 19.5 -12.99j, -8.33 -6.5j , 6.5 -4.33j]], + + [[ -7.5 +49.36j, 12.45 -4.43j, -2.5 +16.45j], + [ 19.5 +12.99j, 0.33 -6.5j , 6.5 +4.33j], + [-25.5 +7.79j, 4.6 +5.04j, -8.5 +2.6j ]]], dtype=complex64) + + When ``s=[3, 5]`` and ``axes=(0, 1)``, size of the transform along ``axis 0`` + will be ``3``, along ``axis 1`` will be ``5//2+1 = 3`` and dimension along + other axes will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfftn(x, s=[3, 5], axes=[0, 1]) + Array([[[ 18. +0.j , 26. +0.j , 34. +0.j ], + [ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j], + [ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]], + + [[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j], + [ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j], + [ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]], + + [[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j], + [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], + [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) + + For 1-D input: + + >>> x1 = jnp.array([1, 2, 3, 4]) + >>> jnp.fft.rfftn(x1) + Array([10.+0.j, -2.+2.j, -2.+0.j], dtype=complex64) + """ return _fft_core('rfftn', xla_client.FftType.RFFT, a, s, axes, norm) -@implements(np.fft.irfftn) def irfftn(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] | None = None, norm: str | None = None) -> Array: + """Compute a real-valued multidimensional inverse discrete Fourier transform. + + JAX implementation of :func:`numpy.fft.irfftn`. + + Args: + a: input array. + s: optional sequence of integers. Specifies the size of the output in each + specified axis. 
If not specified, the dimension of output along axis + ``axes[-1]`` is ``2*(m-1)``, ``m`` is the size of input along axis ``axes[-1]`` + and the dimension along other axes will be the same as that of input. + axes: optional sequence of integers, default=None. Specifies the axes along + which the transform is computed. If not specified, the transform is computed + along the last ``len(s)`` axes. If neither ``axes`` nor ``s`` is specified, + the transform is computed along all the axes. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + A real-valued array containing the multidimensional inverse discrete Fourier + transform of ``a`` with size ``s`` along specified ``axes``, and the same as + the input along other axes. + + See also: + - :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier + transform of a real-valued array. + - :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse + discrete Fourier transform. + - :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse + discrete Fourier transform. + + Examples: + ``jnp.fft.irfftn`` computes the transform along all the axes by default. + + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> jnp.fft.irfftn(x) + Array([[[ 6.5, -1. , 0. , -1. ], + [-0.5, 0. , 0. , 0. ]], + + [[-3. , 0. , 0. , 0. ], + [ 0. , 0. , 0. , 0. ]]], dtype=float32) + + When ``s=[3, 4]``, size of the transform along ``axes (-2, -1)`` will be + ``(3, 4)`` and size along other axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfftn(x, s=[3, 4]) + Array([[[ 2.33, -0.67, 0. , -0.67], + [ 0.33, -0.74, 0. , 0.41], + [ 0.33, 0.41, 0. , -0.74]], + + [[ 6.33, -0.67, 0. , -0.67], + [ 1.33, -1.61, 0. , 1.28], + [ 1.33, 1.28, 0. , -1.61]]], dtype=float32) + + When ``s=[3]`` and ``axes=[0]``, size of the transform along ``axes 0`` will + be ``3`` and dimension along other axes will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfftn(x, s=[3], axes=[0]) + Array([[[ 5., 7., 9.], + [ 6., 8., 10.]], + + [[-2., -2., -2.], + [-2., -2., -2.]], + + [[-2., -2., -2.], + [-2., -2., -2.]]], dtype=float32) + """ return _fft_core('irfftn', xla_client.FftType.IRFFT, a, s, axes, norm) @@ -465,12 +619,12 @@ def rfft(a: ArrayLike, n: int | None = None, def irfft(a: ArrayLike, n: int | None = None, axis: int = -1, norm: str | None = None) -> Array: - r"""Compute a one-dimensional inverse discrete Fourier transform for real input. + """Compute a real-valued one-dimensional inverse discrete Fourier transform. JAX implementation of :func:`numpy.fft.irfft`. Args: - a: real-valued input array. + a: input array. n: int. Specifies the dimension of the result along ``axis``. If not specified, ``n = 2*(m-1)``, where ``m`` is the dimension of ``a`` along ``axis``. axis: int, default=-1. Specifies the axis along which the transform is computed. @@ -479,8 +633,8 @@ def irfft(a: ArrayLike, n: int | None = None, supported. Returns: - An array containing the one-dimensional inverse discrete Fourier transform - of ``a``, with a dimension of ``n`` along ``axis``. + A real-valued array containing the one-dimensional inverse discrete Fourier + transform of ``a``, with a dimension of ``n`` along ``axis``. 
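To complement the `rfftn`/`irfftn` docstrings above, a short round-trip sketch; it assumes an even-length last axis so that the default output size of `irfftn` matches the original input:

import jax.numpy as jnp

x = jnp.arange(8.0).reshape(2, 4)   # real input, even-length last axis
y = jnp.fft.rfftn(x)                # shape (2, 4//2 + 1) == (2, 3)
x_back = jnp.fft.irfftn(y)          # last axis defaults to 2*(3 - 1) == 4
print(y.shape, x_back.shape)        # (2, 3) (2, 4)
print(bool(jnp.allclose(x, x_back, atol=1e-5)))  # True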
See also: - :func:`jax.numpy.fft.ifft`: Computes a one-dimensional inverse discrete @@ -826,15 +980,157 @@ def ifft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), return _fft_core_2d('ifft2', xla_client.FftType.IFFT, a, s=s, axes=axes, norm=norm) -@implements(np.fft.rfft2) + def rfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: + """Compute a two-dimensional discrete Fourier transform of a real-valued array. + + JAX implementation of :func:`numpy.fft.rfft2`. + + Args: + a: real-valued input array. Must have ``a.ndim >= 2``. + s: optional length-2 sequence of integers. Specifies the effective size of the + output along each specified axis. If not specified, it will default to the + dimension of input along ``axes``. + axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the + axes along which the transform is computed. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + An array containing the two-dimensional discrete Fourier transform of ``a``. + The size of the output along the axis ``axes[1]`` is ``(s[1]/2)+1``, if ``s[1]`` + is even and ``(s[1]+1)/2``, if ``s[1]`` is odd. The size of the output along + the axis ``axes[0]`` is ``s[0]``. + + See also: + - :func:`jax.numpy.fft.rfft`: Computes a one-dimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.rfftn`: Computes a multidimensional discrete Fourier + transform of real-valued array. + - :func:`jax.numpy.fft.irfft2`: Computes a real-valued two-dimensional inverse + discrete Fourier transform. + + Examples: + ``jnp.fft.rfft2`` computes the transform along the last two axes by default. + + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfft2(x) + Array([[[21.+0.j , -6.+3.46j], + [-3.+0.j , 0.+0.j ]], + + [[57.+0.j , -6.+3.46j], + [-3.+0.j , 0.+0.j ]]], dtype=complex64) + + When ``s=[2, 4]``, dimension of the transform along ``axis -2`` will be + ``2``, along ``axis -1`` will be ``(4/2)+1) = 3`` and dimension along other + axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfft2(x, s=[2, 4]) + Array([[[21. +0.j, -8. -7.j, 7. +0.j], + [-3. +0.j, 0. +1.j, -1. +0.j]], + + [[57. +0.j, -8.-19.j, 19. +0.j], + [-3. +0.j, 0. +1.j, -1. +0.j]]], dtype=complex64) + + When ``s=[3, 5]`` and ``axes=(0, 1)``, shape of the transform along ``axis 0`` + will be ``3``, along ``axis 1`` will be ``(5+1)/2 = 3`` and dimension along + other axes will be same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.rfft2(x, s=[3, 5], axes=(0, 1)) + Array([[[ 18. +0.j , 26. +0.j , 34. 
+0.j ], + [ 11.09 -9.51j, 16.33-13.31j, 21.56-17.12j], + [ -0.09 -5.88j, 0.67 -8.23j, 1.44-10.58j]], + + [[ -4.5 -12.99j, -2.5 -16.45j, -0.5 -19.92j], + [ -9.71 -6.3j , -10.05 -9.52j, -10.38-12.74j], + [ -4.95 +0.72j, -5.78 -0.2j , -6.61 -1.12j]], + + [[ -4.5 +12.99j, -2.5 +16.45j, -0.5 +19.92j], + [ 3.47+10.11j, 6.43+11.42j, 9.38+12.74j], + [ 3.19 +1.63j, 4.4 +1.38j, 5.61 +1.12j]]], dtype=complex64) + """ return _fft_core_2d('rfft2', xla_client.FftType.RFFT, a, s=s, axes=axes, norm=norm) -@implements(np.fft.irfft2) + def irfft2(a: ArrayLike, s: Shape | None = None, axes: Sequence[int] = (-2,-1), norm: str | None = None) -> Array: + """Compute a real-valued two-dimensional inverse discrete Fourier transform. + + JAX implementation of :func:`numpy.fft.irfft2`. + + Args: + a: input array. Must have ``a.ndim >= 2``. + s: optional length-2 sequence of integers. Specifies the size of the output + in each specified axis. If not specified, the dimension of output along + axis ``axes[1]`` is ``2*(m-1)``, ``m`` is the size of input along axis + ``axes[1]`` and the dimension along other axes will be the same as that of + input. + axes: optional length-2 sequence of integers, default=(-2,-1). Specifies the + axes along which the transform is computed. + norm: string, default="backward". The normalization mode. "backward", "ortho" + and "forward" are supported. + + Returns: + A real-valued array containing the two-dimensional inverse discrete Fourier + transform of ``a``. + + See also: + - :func:`jax.numpy.fft.rfft2`: Computes a two-dimensional discrete Fourier + transform of a real-valued array. + - :func:`jax.numpy.fft.irfft`: Computes a real-valued one-dimensional inverse + discrete Fourier transform. + - :func:`jax.numpy.fft.irfftn`: Computes a real-valued multidimensional inverse + discrete Fourier transform. + + Examples: + ``jnp.fft.irfft2`` computes the transform along the last two axes by default. + + >>> x = jnp.array([[[1, 3, 5], + ... [2, 4, 6]], + ... [[7, 9, 11], + ... [8, 10, 12]]]) + >>> jnp.fft.irfft2(x) + Array([[[ 3.5, -1. , 0. , -1. ], + [-0.5, 0. , 0. , 0. ]], + + [[ 9.5, -1. , 0. , -1. ], + [-0.5, 0. , 0. , 0. ]]], dtype=float32) + + When ``s=[3, 3]``, dimension of the transform along ``axes (-2, -1)`` will be + ``(3, 3)`` and dimension along other axes will be the same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfft2(x, s=[3, 3]) + Array([[[ 1.89, -0.44, -0.44], + [ 0.22, -0.78, 0.56], + [ 0.22, 0.56, -0.78]], + + [[ 5.89, -0.44, -0.44], + [ 1.22, -1.78, 1.56], + [ 1.22, 1.56, -1.78]]], dtype=float32) + + When ``s=[2, 3]`` and ``axes=(0, 1)``, shape of the transform along + ``axes (0, 1)`` will be ``(2, 3)`` and dimension along other axes will be + same as that of input. + + >>> with jnp.printoptions(precision=2, suppress=True): + ... jnp.fft.irfft2(x, s=[2, 3], axes=(0, 1)) + Array([[[ 4.67, 6.67, 8.67], + [-0.33, -0.33, -0.33], + [-0.33, -0.33, -0.33]], + + [[-3. , -3. , -3. ], + [ 0. , 0. , 0. ], + [ 0. , 0. , 0. ]]], dtype=float32) + """ return _fft_core_2d('irfft2', xla_client.FftType.IRFFT, a, s=s, axes=axes, norm=norm) @@ -844,7 +1140,7 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, """Return sample frequencies for the discrete Fourier transform. JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate - for use with the outputs of :func:`~jax.numpy.fft` and :func:`~jax.numpy.ifft`. 
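The same round-trip property holds for the two-dimensional variants documented above; a minimal sketch, again assuming an even-length trailing axis:

import jax.numpy as jnp

x = jnp.arange(12.0).reshape(3, 4)
y = jnp.fft.rfft2(x)                # shape (3, 4//2 + 1) == (3, 3)
print(bool(jnp.allclose(jnp.fft.irfft2(y), x, atol=1e-4)))  # True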
+ for use with the outputs of :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`. Args: n: length of the FFT window @@ -858,8 +1154,8 @@ def fftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, Array of sample frequencies, length ``n``. See also: - - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.rfft` - and :func:`~jax.numpy.irfft`. + - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with + :func:`~jax.numpy.fft.rfft` and :func:`~jax.numpy.fft.irfft`. """ dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): @@ -895,7 +1191,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, """Return sample frequencies for the discrete Fourier transform. JAX implementation of :func:`numpy.fft.fftfreq`. Returns frequencies appropriate - for use with the outputs of :func:`~jax.numpy.rfft` and :func:`~jax.numpy.irfft`. + for use with the outputs of :func:`~jax.numpy.fft.rfft` and + :func:`~jax.numpy.fft.irfft`. Args: n: length of the FFT window @@ -909,8 +1206,8 @@ def rfftfreq(n: int, d: ArrayLike = 1.0, *, dtype: DTypeLike | None = None, Array of sample frequencies, length ``n // 2 + 1``. See also: - - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with :func:`~jax.numpy.fft` - and :func:`~jax.numpy.ifft`. + - :func:`jax.numpy.fft.rfftfreq`: frequencies for use with + :func:`~jax.numpy.fft.fft` and :func:`~jax.numpy.fft.ifft`. """ dtype = dtype or dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 3af51e30585d..387b3b2a51a7 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -35,7 +35,6 @@ import types from typing import (overload, Any, Literal, NamedTuple, Protocol, TypeVar, Union) -from textwrap import dedent as _dedent import warnings import numpy as np @@ -99,16 +98,6 @@ def canonicalize_shape(shape: Any, context: str="") -> core.Shape: else: return core.canonicalize_shape(shape, context) -# Common docstring additions: - -_PRECISION_DOC = """\ -In addition to the original NumPy arguments listed below, also supports -``precision`` for extra control over matrix-multiplication precision -on supported devices. ``precision`` may be set to ``None``, which means -default precision for the backend, a :class:`~jax.lax.Precision` enum value -(``Precision.DEFAULT``, ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple -of two :class:`~jax.lax.Precision` enums indicating separate precision for each argument. -""" # Some objects below rewrite their __module__ attribute to this name. _PUBLIC_MODULE_NAME = "jax.numpy" @@ -349,14 +338,103 @@ def load(*args: Any, **kwargs: Any) -> Array: ### implementations of numpy functions in terms of lax -@util.implements(np.fmin, module='numpy') @jit def fmin(x1: ArrayLike, x2: ArrayLike) -> Array: + """Return element-wise minimum of the input arrays. + + JAX implemtentation of :func:`numpy.fmin`. + + Args: + x1: input array or scalar. + x2: input array or scalar. x1 and x2 must either have same shape or be + broadcast compatible. + + Returns: + An array containing the element-wise minimum of x1 and x2. + + Note: + For each pair of elements, ``jnp.fmin`` returns: + - the smaller of the two if both elements are finite numbers. + - finite number if one element is ``nan``. + - ``-inf`` if one element is ``-inf`` and the other is finite or ``nan``. + - ``inf`` if one element is ``inf`` and the other is ``nan``. 
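As a concrete illustration of the `fftfreq`/`rfftfreq` helpers updated above (a sketch, assuming n=4 samples at spacing d=0.5):

import jax.numpy as jnp

# fftfreq returns positive and negative frequencies; rfftfreq returns only
# the non-negative half that pairs with rfft output.
print(jnp.fft.fftfreq(4, d=0.5))    # [ 0.   0.5 -1.  -0.5]
print(jnp.fft.rfftfreq(4, d=0.5))   # [0.  0.5 1. ]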
+ - ``nan`` if both elements are ``nan``. + + Examples: + >>> jnp.fmin(2, 3) + Array(2, dtype=int32, weak_type=True) + >>> jnp.fmin(2, jnp.array([1, 4, 2, -1])) + Array([ 1, 2, 2, -1], dtype=int32) + + >>> x1 = jnp.array([1, 3, 2]) + >>> x2 = jnp.array([2, 1, 4]) + >>> jnp.fmin(x1, x2) + Array([1, 1, 2], dtype=int32) + + >>> x3 = jnp.array([1, 5, 3]) + >>> x4 = jnp.array([[2, 3, 1], + ... [5, 6, 7]]) + >>> jnp.fmin(x3, x4) + Array([[1, 3, 1], + [1, 5, 3]], dtype=int32) + + >>> nan = jnp.nan + >>> x5 = jnp.array([jnp.inf, 5, nan]) + >>> x6 = jnp.array([[2, 3, nan], + ... [nan, 6, 7]]) + >>> jnp.fmin(x5, x6) + Array([[ 2., 3., nan], + [inf, 5., 7.]], dtype=float32) + """ return where(ufuncs.less(x1, x2) | ufuncs.isnan(x2), x1, x2) -@util.implements(np.fmax, module='numpy') + @jit def fmax(x1: ArrayLike, x2: ArrayLike) -> Array: + """Return element-wise maximum of the input arrays. + + JAX implementation of :func:`numpy.fmax`. + + Args: + x1: input array or scalar + x2: input array or scalar. x1 and x1 must either have same shape or be + broadcast compatible. + + Returns: + An array containing the element-wise maximum of x1 and x2. + + Note: + For each pair of elements, ``jnp.fmax`` returns: + - the larger of the two if both elements are finite numbers. + - finite number if one element is ``nan``. + - ``nan`` if both elements are ``nan``. + - ``inf`` if one element is ``inf`` and the other is finite or ``nan``. + - ``-inf`` if one element is ``-inf`` and the other is ``nan``. + + Examples: + >>> jnp.fmax(3, 7) + Array(7, dtype=int32, weak_type=True) + >>> jnp.fmax(5, jnp.array([1, 7, 9, 4])) + Array([5, 7, 9, 5], dtype=int32) + + >>> x1 = jnp.array([1, 3, 7, 8]) + >>> x2 = jnp.array([-1, 4, 6, 9]) + >>> jnp.fmax(x1, x2) + Array([1, 4, 7, 9], dtype=int32) + + >>> x3 = jnp.array([[2, 3, 5, 10], + ... [11, 9, 7, 5]]) + >>> jnp.fmax(x1, x3) + Array([[ 2, 3, 7, 10], + [11, 9, 7, 8]], dtype=int32) + + >>> x4 = jnp.array([jnp.inf, 6, -jnp.inf, nan]) + >>> x5 = jnp.array([[3, 5, 7, nan], + ... [nan, 9, nan, -1]]) + >>> jnp.fmax(x4, x5) + Array([[ inf, 6., 7., nan], + [ inf, 9., -inf, -1.]], dtype=float32) + """ return where(ufuncs.greater(x1, x2) | ufuncs.isnan(x2), x1, x2) @util.implements(np.issubdtype) @@ -376,22 +454,42 @@ def result_type(*args: Any) -> DType: return dtypes.result_type(*args) -@util.implements(np.trunc, module='numpy') @jit def trunc(x: ArrayLike) -> Array: + """Round input to the nearest integer towards zero. + + JAX implementation of :func:`numpy.trunc`. + + Args: + x: input array or scalar. + + Returns: + An array with same shape and dtype as ``x`` containing the rounded values. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest integer towards zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. + + Examples: + >>> key = jax.random.key(42) + >>> x = jax.random.uniform(key, (3, 3), minval=-10, maxval=10) + >>> with jnp.printoptions(precision=2, suppress=True): + ... 
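A compact sketch of the NaN handling spelled out in the `fmin`/`fmax` notes above:

import jax.numpy as jnp

a = jnp.array([1.0, jnp.nan, jnp.inf])
b = jnp.array([2.0, 5.0, jnp.nan])
# A NaN is ignored whenever the other element is a number (including inf).
print(jnp.fmin(a, b))   # [ 1.  5. inf]
print(jnp.fmax(a, b))   # [ 2.  5. inf]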
print(x) + [[ 2.88 -3.55 -6.13] + [ 7.73 4.49 -6.16] + [-3.1 -4.95 2.64]] + >>> jnp.trunc(x) + Array([[ 2., -3., -6.], + [ 7., 4., -6.], + [-3., -4., 2.]], dtype=float32) + """ util.check_arraylike('trunc', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax_internal.asarray(x) return where(lax.lt(x, _lax_const(x, 0)), ufuncs.ceil(x), ufuncs.floor(x)) -_CONV_PREFERRED_ELEMENT_TYPE_DESCRIPTION = """ -preferred_element_type : dtype, optional - If specified, accumulate results and return a result of the given data type. - If not specified, the function instead follows the numpy convention of always - accumulating results and returning an inexact dtype. -""" - @partial(jit, static_argnames=['mode', 'op', 'precision', 'preferred_element_type']) def _conv(x: Array, y: Array, mode: str, op: str, precision: PrecisionLike, preferred_element_type: DTypeLike | None = None) -> Array: @@ -720,11 +818,6 @@ def histogramdd(sample: ArrayLike, bins: ArrayLike | list[ArrayLike] = 10, return hist, bin_edges_by_dim -_ARRAY_VIEW_DOC = """ -The JAX version of this function may in some cases return a copy rather than a -view of the input. -""" - def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: """Return a transposed version of an N-dimensional array. @@ -803,8 +896,31 @@ def transpose(a: ArrayLike, axes: Sequence[int] | None = None) -> Array: return lax.transpose(a, axes_) -@util.implements(getattr(np, "permute_dims", None)) def permute_dims(a: ArrayLike, /, axes: tuple[int, ...]) -> Array: + """Permute the axes/dimensions of an array. + + JAX implementation of :func:`array_api.permute_dims`. + + Args: + a: input array + axes: tuple of integers in range ``[0, a.ndim)`` specifying the + axes permutation. + + Returns: + a copy of ``a`` with axes permuted. + + See also: + - :func:`jax.numpy.transpose` + - :func:`jax.numpy.matrix_transpose` + + Examples: + >>> a = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.permute_dims(a, (1, 0)) + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + """ util.check_arraylike("permute_dims", a) return lax.transpose(a, axes) @@ -862,9 +978,65 @@ def matrix_transpose(x: ArrayLike, /) -> Array: return lax.transpose(x, axes) -@util.implements(np.rot90, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('k', 'axes')) def rot90(m: ArrayLike, k: int = 1, axes: tuple[int, int] = (0, 1)) -> Array: + """Rotate an array by 90 degrees counterclockwise in the plane specified by axes. + + JAX implementation of :func:`numpy.rot90`. + + Args: + m: input array. Must have ``m.ndim >= 2``. + k: int, optional, default=1. Specifies the number of times the array is rotated. + For negative values of ``k``, the array is rotated in clockwise direction. + axes: tuple of 2 integers, optional, default= (0, 1). The axes define the plane + in which the array is rotated. Both the axes must be different. + + Returns: + An array containing the copy of the input, ``m`` rotated by 90 degrees. + + See also: + - :func:`jax.numpy.flip`: reverse the order along the given axis + - :func:`jax.numpy.fliplr`: reverse the order along axis 1 (left/right) + - :func:`jax.numpy.flipud`: reverse the order along axis 0 (up/down) + + Examples: + >>> m = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.rot90(m) + Array([[3, 6], + [2, 5], + [1, 4]], dtype=int32) + >>> jnp.rot90(m, k=2) + Array([[6, 5, 4], + [3, 2, 1]], dtype=int32) + + ``jnp.rot90(m, k=1, axes=(1, 0))`` is equivalent to + ``jnp.rot90(m, k=-1, axes(0,1))``. 
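To make the rounding directions cross-referenced in the `trunc` docstring above concrete, a small comparison sketch:

import jax.numpy as jnp

x = jnp.array([-1.7, -0.2, 0.2, 1.7])
print(jnp.trunc(x))   # [-1. -0.  0.  1.]   toward zero
print(jnp.floor(x))   # [-2. -1.  0.  1.]   toward -inf
print(jnp.ceil(x))    # [-1. -0.  1.  2.]   toward +inf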
+ + >>> jnp.rot90(m, axes=(1, 0)) + Array([[4, 1], + [5, 2], + [6, 3]], dtype=int32) + >>> jnp.rot90(m, k=-1, axes=(0, 1)) + Array([[4, 1], + [5, 2], + [6, 3]], dtype=int32) + + when input array has ``ndim>2``: + + >>> m1 = jnp.array([[[1, 2, 3], + ... [4, 5, 6]], + ... [[7, 8, 9], + ... [10, 11, 12]]]) + >>> jnp.rot90(m1, k=1, axes=(2, 1)) + Array([[[ 4, 1], + [ 5, 2], + [ 6, 3]], + + [[10, 7], + [11, 8], + [12, 9]]], dtype=int32) + """ util.check_arraylike("rot90", m) if np.ndim(m) < 2: raise ValueError("rot90 requires its first argument to have ndim at least " @@ -1105,11 +1277,68 @@ def angle(z: ArrayLike, deg: bool = False) -> Array: return ufuncs.degrees(result) if deg else result -@util.implements(np.diff) @partial(jit, static_argnames=('n', 'axis')) def diff(a: ArrayLike, n: int = 1, axis: int = -1, prepend: ArrayLike | None = None, append: ArrayLike | None = None) -> Array: + """Calculate n-th order difference between array elements along a given axis. + + JAX implementation of :func:`numpy.diff`. + + The first order difference is computed by ``a[i+1] - a[i]``, and the n-th order + difference is computed ``n`` times recursively. + + Args: + a: input array. Must have ``a.ndim >= 1``. + n: int, optional, default=1. Order of the difference. Specifies the number + of times the difference is computed. If n=0, no difference is computed and + input is returned as is. + axis: int, optional, default=-1. Specifies the axis along which the difference + is computed. The difference is computed along ``axis -1`` by default. + prepend: scalar or array, optional, defualt=None. Specifies the values to be + prepended along ``axis`` before computing the difference. + append: scalar or array, optional, defualt=None. Specifies the values to be + appended along ``axis`` before computing the difference. + + Returns: + An array containing the n-th order difference between the elements of ``a``. + + See also: + - :func:`jax.numpy.ediff1d`: Computes the differences between consecutive + elements of an array. + - :func:`jax.numpy.cumsum`: Computes the cumulative sum of the elements of + the array along a given axis. + - :func:`jax.numpy.gradient`: Computes the gradient of an N-dimensional array. + + Examples: + ``jnp.diff`` computes the first order difference along ``axis``, by default. + + >>> a = jnp.array([[1, 5, 2, 9], + ... [3, 8, 7, 4]]) + >>> jnp.diff(a) + Array([[ 4, -3, 7], + [ 5, -1, -3]], dtype=int32) + + When ``n = 2``, second order difference is computed along ``axis``. + + >>> jnp.diff(a, n=2) + Array([[-7, 10], + [-6, -2]], dtype=int32) + + When ``prepend = 2``, it is prepended to ``a`` along ``axis`` before computing + the difference. + + >>> jnp.diff(a, prepend=2) + Array([[-1, 4, -3, 7], + [ 1, 5, -1, -3]], dtype=int32) + + When ``append = jnp.array([[3],[1]])``, it is appended to ``a`` along ``axis`` + before computing the difference. + + >>> jnp.diff(a, append=jnp.array([[3],[1]])) + Array([[ 4, -3, 7, -6], + [ 5, -1, -3, -3]], dtype=int32) + """ util.check_arraylike("diff", a) arr = asarray(a) n = core.concrete_or_error(operator.index, n, "'n' argument of jnp.diff") @@ -1159,16 +1388,58 @@ def diff(a: ArrayLike, n: int = 1, axis: int = -1, return arr -_EDIFF1D_DOC = """\ -Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will not -issue an error if casting ``to_end`` or ``to_begin`` to the type of ``ary`` -loses precision. 
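The recursive definition of the n-th order difference described in the `diff` docstring above can be checked directly:

import jax.numpy as jnp

a = jnp.array([1, 4, 9, 16, 25])
print(jnp.diff(a))        # [3 5 7 9]
print(jnp.diff(a, n=2))   # [2 2 2]
print(bool((jnp.diff(a, n=2) == jnp.diff(jnp.diff(a))).all()))  # True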
-""" -@util.implements(np.ediff1d, lax_description=_EDIFF1D_DOC) @jit def ediff1d(ary: ArrayLike, to_end: ArrayLike | None = None, to_begin: ArrayLike | None = None) -> Array: + """Compute the differences of the elements of the flattened array. + + JAX implementation of :func:`numpy.ediff1d`. + + Args: + ary: input array or scalar. + to_end: scalar or array, optional, default=None. Specifies the numbers to + append to the resulting array. + to_begin: scalar or array, optional, default=None. Specifies the numbers to + prepend to the resulting array. + + Returns: + An array containing the differences between the elements of the input array. + + Note: + Unlike NumPy's implementation of ediff1d, :py:func:`jax.numpy.ediff1d` will + not issue an error if casting ``to_end`` or ``to_begin`` to the type of + ``ary`` loses precision. + + See also: + - :func:`jax.numpy.diff`: Computes the n-th order difference between elements + of the array along a given axis. + - :func:`jax.numpy.cumsum`: Computes the cumulative sum of the elements of + the array along a given axis. + - :func:`jax.numpy.gradient`: Computes the gradient of an N-dimensional array. + + Examples: + >>> a = jnp.array([2, 3, 5, 9, 1, 4]) + >>> jnp.ediff1d(a) + Array([ 1, 2, 4, -8, 3], dtype=int32) + >>> jnp.ediff1d(a, to_begin=-10) + Array([-10, 1, 2, 4, -8, 3], dtype=int32) + >>> jnp.ediff1d(a, to_end=jnp.array([20, 30])) + Array([ 1, 2, 4, -8, 3, 20, 30], dtype=int32) + >>> jnp.ediff1d(a, to_begin=-10, to_end=jnp.array([20, 30])) + Array([-10, 1, 2, 4, -8, 3, 20, 30], dtype=int32) + + For array with ``ndim > 1``, the differences are computed after flattening + the input array. + + >>> a1 = jnp.array([[2, -1, 4, 7], + ... [3, 5, -6, 9]]) + >>> jnp.ediff1d(a1) + Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) + >>> a2 = jnp.array([2, -1, 4, 7, 3, 5, -6, 9]) + >>> jnp.ediff1d(a2) + Array([ -3, 5, 3, -4, 2, -11, 15], dtype=int32) + """ util.check_arraylike("ediff1d", ary) arr = ravel(ary) result = lax.sub(arr[1:], arr[:-1]) @@ -1306,6 +1577,8 @@ def reshape( JAX does not support ``order="A"``. copy: unused by JAX; JAX always returns a copy, though under JIT the compiler may optimize such copies away. + newshape: deprecated alias of the ``shape`` argument. Will result in a + :class:`DeprecationWarning` if used. Returns: reshaped copy of input array with the specified shape. @@ -1370,11 +1643,10 @@ def reshape( "jnp.reshape received both `shape` and `newshape` arguments. Note that " "using `newshape` is deprecated, please only use `shape` instead." ) - warnings.warn( - "The newshape argument of jax.numpy.reshape is deprecated and setting it " - "will soon raise an error. To avoid an error in the future, and to " - "suppress this warning, please use the shape argument instead.", - DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-reshape-newshape", + ("The newshape argument of jax.numpy.reshape is deprecated. " + "Please use the shape argument instead."), stacklevel=2) shape = newshape del newshape elif shape is None: @@ -1595,9 +1867,40 @@ def unravel_index(indices: ArrayLike, shape: Shape) -> tuple[Array, ...]: return tuple(where(oob_pos, s - 1, where(oob_neg, 0, i)) for s, i in safe_zip(shape, out_indices)) -@util.implements(np.resize) + @partial(jit, static_argnames=('new_shape',)) def resize(a: ArrayLike, new_shape: Shape) -> Array: + """Return a new array with specified shape. + + JAX implementation of :func:`numpy.resize`. + + Args: + a: input array or scalar. + new_shape: int or tuple of ints. 
Specifies the shape of the resized array. + + Returns: + A resized array with specified shape. The elements of ``a`` are repeated in + the resized array, if the resized array is larger than the original aray. + + See also: + - :func:`jax.numpy.reshape`: Returns a reshaped copy of an array. + - :func:`jax.numpy.repeat`: Constructs an array from repeated elements. + + Examples: + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> jnp.resize(x, (3, 3)) + Array([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], dtype=int32) + >>> jnp.resize(x, (3, 4)) + Array([[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 1, 2, 3]], dtype=int32) + >>> jnp.resize(4, (3, 2)) + Array([[4, 4], + [4, 4], + [4, 4]], dtype=int32, weak_type=True) + """ util.check_arraylike("resize", a) new_shape = _ensure_index_tuple(new_shape) @@ -2012,15 +2315,58 @@ def _interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, return f -@util.implements(np.interp, - lax_description=_dedent(""" - In addition to constant interpolation supported by NumPy, jnp.interp also - supports left='extrapolate' and right='extrapolate' to indicate linear - extrapolation instead.""")) def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, left: ArrayLike | str | None = None, right: ArrayLike | str | None = None, period: ArrayLike | None = None) -> Array: + """One-dimensional linear interpolation. + + JAX implementation of :func:`numpy.interp`. + + Args: + x: N-dimensional array of x coordinates at which to evaluate the interpolation. + xp: one-dimensional sorted array of points to be interpolated. + fp: array of shape ``xp.shape`` containing the function values associated with ``xp``. + left: specify how to handle points ``x < xp[0]``. Default is to return ``fp[0]``. + If ``left`` is a scalar value, it will return this value. if ``left`` is the string + ``"extrapolate"``, then the value will be determined by linear extrapolation. + ``left`` is ignored if ``period`` is specified. + right: specify how to handle points ``x > xp[-1]``. Default is to return ``fp[-1]``. + If ``right`` is a scalar value, it will return this value. if ``right`` is the string + ``"extrapolate"``, then the value will be determined by linear extrapolation. + ``right`` is ignored if ``period`` is specified. + period: optionally specify the period for the *x* coordinates, for e.g. interpolation + in angular space. + + Returns: + an array of shape ``x.shape`` containing the interpolated function at values ``x``. + + Examples: + >>> xp = jnp.arange(10) + >>> fp = 2 * xp + >>> x = jnp.array([0.5, 2.0, 3.5]) + >>> interp(x, xp, fp) + Array([1., 4., 7.], dtype=float32) + + Unless otherwise specified, extrapolation will be constant: + + >>> x = jnp.array([-10., 10.]) + >>> interp(x, xp, fp) + Array([ 0., 18.], dtype=float32) + + Use ``"extrapolate"`` mode for linear extrapolation: + + >>> interp(x, xp, fp, left='extrapolate', right='extrapolate') + Array([-20., 20.], dtype=float32) + + For periodic interpolation, specify the ``period``: + + >>> xp = jnp.array([0, jnp.pi / 2, jnp.pi, 3 * jnp.pi / 2]) + >>> fp = jnp.sin(xp) + >>> x = 2 * jnp.pi # note: not in input array + >>> jnp.interp(x, xp, fp, period=2 * jnp.pi) + Array(0., dtype=float32) + """ static_argnames = [] if isinstance(left, str) or left is None: static_argnames.append('left') @@ -2079,7 +2425,7 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): Returns: An array of dtype ``jnp.result_type(x, y)`` with values drawn from ``x`` where ``condition`` is True, and from ``y`` where condition is ``False``. 
If ``x`` and ``y`` are ``None``, the - function behaves differently; see `:func:`jax.numpy.nonzero` for a description of the return + function behaves differently; see :func:`jax.numpy.nonzero` for a description of the return type. See Also: @@ -2126,12 +2472,60 @@ def where(condition, x=None, y=None, /, *, size=None, fill_value=None): return util._where(condition, x, y) -@util.implements(np.select) def select( condlist: Sequence[ArrayLike], choicelist: Sequence[ArrayLike], default: ArrayLike = 0, ) -> Array: + """Select values based on a series of conditions. + + JAX implementation of :func:`numpy.select`, implemented in terms + of :func:`jax.lax.select_n` + + Args: + condlist: sequence of array-like conditions. All entries must be mutually + broadcast-compatible. + choicelist: sequence of array-like values to choose. Must have the same length + as ``condlist``, and all entries must be broadcast-compatible with entries + of ``condlist``. + default: value to return when every condition is False (default: 0). + + Returns: + Array of selected values from ``choicelist`` corresponding to the first + ``True`` entry in ``condlist`` at each location. + + See also: + - :func:`jax.numpy.where`: select between two values based on a single condition. + - :func:`jax.lax.select_n`: select between *N* values based on an index. + + Examples: + >>> condlist = [ + ... jnp.array([False, True, False, False]), + ... jnp.array([True, False, False, False]), + ... jnp.array([False, True, True, False]), + ... ] + >>> choicelist = [ + ... jnp.array([1, 2, 3, 4]), + ... jnp.array([10, 20, 30, 40]), + ... jnp.array([100, 200, 300, 400]), + ... ] + >>> jnp.select(condlist, choicelist, default=0) + Array([ 10, 2, 300, 0], dtype=int32) + + This is logically equivalent to the following nested ``where`` statement: + + >>> default = 0 + >>> jnp.where(condlist[0], + ... choicelist[0], + ... jnp.where(condlist[1], + ... choicelist[1], + ... jnp.where(condlist[2], + ... choicelist[2], + ... default))) + Array([ 10, 2, 300, 0], dtype=int32) + + However, for efficiency it is implemented in terms of :func:`jax.lax.select_n`. + """ if len(condlist) != len(choicelist): msg = "condlist must have length equal to choicelist ({} vs {})" raise ValueError(msg.format(len(condlist), len(choicelist))) @@ -2235,25 +2629,119 @@ def broadcast_shapes(*shapes: Sequence[int]) -> tuple[int, ...]: ... def broadcast_shapes(*shapes: Sequence[int | core.Tracer] ) -> tuple[int | core.Tracer, ...]: ... -@util.implements(getattr(np, "broadcast_shapes", None)) def broadcast_shapes(*shapes): + """Broadcast input shapes to a common output shape. + + JAX implementation of :func:`numpy.broadcast_shapes`. JAX uses NumPy-style + broadcasting rules, which you can read more about at `NumPy broadcasting`_. + + Args: + shapes: 0 or more shapes specified as sequences of integers + + Returns: + The broadcasted shape as a tuple of integers. + + See Also: + - :func:`jax.numpy.broadcast_arrays`: broadcast arrays to a common shape. + - :func:`jax.numpy.broadcast_to`: broadcast an array to a specified shape. + + Examples: + Some compatible shapes: + + >>> jnp.broadcast_shapes((1,), (4,)) + (4,) + >>> jnp.broadcast_shapes((3, 1), (4,)) + (3, 4) + >>> jnp.broadcast_shapes((3, 1), (1, 4), (5, 1, 1)) + (5, 3, 4) + + Incompatible shapes: + + >>> jnp.broadcast_shapes((3, 1), (4, 1)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ValueError: Incompatible shapes for broadcasting: shapes=[(3, 1), (4, 1)] + + .. 
_NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ if not shapes: return () shapes = [(shape,) if np.ndim(shape) == 0 else tuple(shape) for shape in shapes] return lax.broadcast_shapes(*shapes) -@util.implements(np.broadcast_arrays, lax_description="""\ -The JAX version does not necessarily return a view of the input. -""") def broadcast_arrays(*args: ArrayLike) -> list[Array]: + """Broadcast arrays to a common shape. + + JAX implementation of :func:`numpy.broadcast_arrays`. JAX uses NumPy-style + broadcasting rules, which you can read more about at `NumPy broadcasting`_. + + Args: + args: zero or more array-like objects to be broadcasted. + + Returns: + a list of arrays containing broadcasted copies of the inputs. + + See also: + - :func:`jax.numpy.broadcast_shapes`: broadcast input shapes to a common shape. + - :func:`jax.numpy.broadcast_to`: broadcast an array to a specified shape. + + Examples: + + >>> x = jnp.arange(3) + >>> y = jnp.int32(1) + >>> jnp.broadcast_arrays(x, y) + [Array([0, 1, 2], dtype=int32), Array([1, 1, 1], dtype=int32)] + + >>> x = jnp.array([[1, 2, 3]]) + >>> y = jnp.array([[10], + ... [20]]) + >>> x2, y2 = jnp.broadcast_arrays(x, y) + >>> x2 + Array([[1, 2, 3], + [1, 2, 3]], dtype=int32) + >>> y2 + Array([[10, 10, 10], + [20, 20, 20]], dtype=int32) + + .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ return util._broadcast_arrays(*args) -@util.implements(np.broadcast_to, lax_description="""\ -The JAX version does not necessarily return a view of the input. -""") def broadcast_to(array: ArrayLike, shape: DimSize | Shape) -> Array: + """Broadcast an array to a specified shape. + + JAX implementation of :func:`numpy.broadcast_to`. JAX uses NumPy-style + broadcasting rules, which you can read more about at `NumPy broadcasting`_. + + Args: + array: array to be broadcast. + shape: shape to which the array will be broadcast. + + Returns: + a copy of array broadcast to the specified shape. + + See also: + - :func:`jax.numpy.broadcast_arrays`: broadcast arrays to a common shape. + - :func:`jax.numpy.broadcast_shapes`: broadcast input shapes to a common shape. + + Examples: + >>> x = jnp.int32(1) + >>> jnp.broadcast_to(x, (1, 4)) + Array([[1, 1, 1, 1]], dtype=int32) + + >>> x = jnp.array([1, 2, 3]) + >>> jnp.broadcast_to(x, (2, 3)) + Array([[1, 2, 3], + [1, 2, 3]], dtype=int32) + + >>> x = jnp.array([[2], [4]]) + >>> jnp.broadcast_to(x, (2, 4)) + Array([[2, 2, 2, 2], + [4, 4, 4, 4]], dtype=int32) + + .. _NumPy broadcasting: https://numpy.org/doc/stable/user/basics.broadcasting.html + """ return util._broadcast_to(array, shape) @@ -2295,89 +2783,266 @@ def _split(op: str, ary: ArrayLike, return [lax.slice(ary, _subval(starts, axis, start), _subval(ends, axis, end)) for start, end in zip(split_indices[:-1], split_indices[1:])] -@util.implements(np.split, lax_description=_ARRAY_VIEW_DOC) + def split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, axis: int = 0) -> list[Array]: - return _split("split", ary, indices_or_sections, axis=axis) + """Split an array into sub-arrays. 
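The three broadcasting helpers documented above agree with one another; a minimal sketch:

import jax.numpy as jnp

shape = jnp.broadcast_shapes((3, 1), (1, 4), (4,))
print(shape)                          # (3, 4)
a = jnp.broadcast_to(jnp.arange(3).reshape(3, 1), shape)
b, c = jnp.broadcast_arrays(jnp.arange(3).reshape(3, 1), jnp.arange(4))
print(a.shape, b.shape, c.shape)      # (3, 4) (3, 4) (3, 4)
print(bool((a == b).all()))           # True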
-def _split_on_axis(op: str, axis: int) -> Callable[[ArrayLike, int | ArrayLike], list[Array]]: - @util.implements(getattr(np, op), update_doc=False) - def f(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: - # for 1-D array, hsplit becomes vsplit - nonlocal axis - util.check_arraylike(op, ary) - a = asarray(ary) - if axis == 1 and len(a.shape) == 1: - axis = 0 - return _split(op, ary, indices_or_sections, axis=axis) - return f + JAX implementation of :func:`numpy.split`. -vsplit = _split_on_axis("vsplit", axis=0) -hsplit = _split_on_axis("hsplit", axis=1) -dsplit = _split_on_axis("dsplit", axis=2) + Args: + ary: N-dimensional array-like object to split + indices_or_sections: either a single integer or a sequence of indices. -@util.implements(np.array_split) -def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, - axis: int = 0) -> list[Array]: - return _split("array_split", ary, indices_or_sections, axis=axis) + - if ``indices_or_sections`` is an integer *N*, then *N* must evenly divide + ``ary.shape[axis]`` and ``ary`` will be divided into *N* equally-sized + chunks along ``axis``. + - if ``indices_or_sections`` is a sequence of integers, then these integers + specify the boundary between unevenly-sized chunks along ``axis``; see + examples below. + axis: the axis along which to split; defaults to 0. -@jit -def clip( - arr: ArrayLike | None = None, - /, - min: ArrayLike | None = None, - max: ArrayLike | None = None, - *, - a: ArrayLike | DeprecatedArg = DeprecatedArg(), - a_min: ArrayLike | None | DeprecatedArg = DeprecatedArg(), - a_max: ArrayLike | None | DeprecatedArg = DeprecatedArg() -) -> Array: - """Clip array values to a specified range. + Returns: + A list of arrays. If ``indices_or_sections`` is an integer *N*, then the list is + of length *N*. If ``indices_or_sections`` is a sequence *seq*, then the list is + is of length *len(seq) + 1*. - JAX implementation of :func:`numpy.clip`. + Examples: + Splitting a 1-dimensional array: - Args: - arr: N-dimensional array to be clipped. - min: optional minimum value of the clipped range; if ``None`` (default) then - result will not be clipped to any minimum value. If specified, it should be - broadcast-compatible with ``arr`` and ``max``. - max: optional maximum value of the clipped range; if ``None`` (default) then - result will not be clipped to any maximum value. If specified, it should be - broadcast-compatible with ``arr`` and ``min``. - a: deprecated alias of the ``arr`` argument. Will result in a - :class:`DeprecationWarning` if used. - a_min: deprecated alias of the ``min`` argument. Will result in a - :class:`DeprecationWarning` if used. - a_max: deprecated alias of the ``max`` argument. Will result in a - :class:`DeprecationWarning` if used. + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) - Returns: - An array containing values from ``arr``, with values smaller than ``min`` set - to ``min``, and values larger than ``max`` set to ``max``. + Split into three equal sections: + + >>> chunks = jnp.split(x, 3) + >>> print(*chunks) + [1 2 3] [4 5 6] [7 8 9] + + Split into sections by index: + + >>> chunks = jnp.split(x, [2, 7]) # [x[0:2], x[2:7], x[7:]] + >>> print(*chunks) + [1 2] [3 4 5 6 7] [8 9] + + Splitting a two-dimensional array along axis 1: + + >>> x = jnp.array([[1, 2, 3, 4], + ... 
[5, 6, 7, 8]]) + >>> x1, x2 = jnp.split(x, 2, axis=1) + >>> print(x1) + [[1 2] + [5 6]] + >>> print(x2) + [[3 4] + [7 8]] See also: - - :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays. - - :func:`jax.numpy.maximum`: Compute the element-wise maximum value of two arrays. + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + """ + return _split("split", ary, indices_or_sections, axis=axis) + + +def vsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: + """Split an array into sub-arrays vertically. + + JAX implementation of :func:`numpy.vsplit`. + + Refer to the documentation of :func:`jax.numpy.split` for details; ``vsplit`` is + equivalent to ``split`` with ``axis=0``. Examples: - >>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7]) - >>> jnp.clip(arr, 2, 5) - Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32) - """ - # TODO(micky774): deprecated 2024-4-2, remove after deprecation expires. - arr = a if not isinstance(a, DeprecatedArg) else arr - if arr is None: - raise ValueError("No input was provided to the clip function.") - min = a_min if not isinstance(a_min, DeprecatedArg) else min - max = a_max if not isinstance(a_max, DeprecatedArg) else max - if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)): - warnings.warn( - "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is " - "deprecated. Please use 'arr', 'min' or 'max' respectively instead.", - DeprecationWarning, - stacklevel=2, - ) + 1D array: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6]) + >>> x1, x2 = jnp.vsplit(x, 2) + >>> print(x1, x2) + [1 2 3] [4 5 6] + + 2D array: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8]]) + >>> x1, x2 = jnp.vsplit(x, 2) + >>> print(x1, x2) + [[1 2 3 4]] [[5 6 7 8]] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + """ + return _split("vsplit", ary, indices_or_sections, axis=0) + + +def hsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: + """Split an array into sub-arrays horizontally. + + JAX implementation of :func:`numpy.hsplit`. + + Refer to the documentation of :func:`jax.numpy.split` for details. ``hsplit`` is + equivalent to ``split`` with ``axis=1``, or ``axis=0`` for one-dimensional arrays. + + Examples: + 1D array: + + >>> x = jnp.array([1, 2, 3, 4, 5, 6]) + >>> x1, x2 = jnp.hsplit(x, 2) + >>> print(x1, x2) + [1 2 3] [4 5 6] + + 2D array: + + >>> x = jnp.array([[1, 2, 3, 4], + ... [5, 6, 7, 8]]) + >>> x1, x2 = jnp.hsplit(x, 2) + >>> print(x1) + [[1 2] + [5 6]] + >>> print(x2) + [[3 4] + [7 8]] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. 
+ """ + util.check_arraylike("hsplit", ary) + a = asarray(ary) + return _split("hsplit", a, indices_or_sections, axis=0 if a.ndim == 1 else 1) + + +def dsplit(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike) -> list[Array]: + """Split an array into sub-arrays depth-wise. + + JAX implementation of :func:`numpy.dsplit`. + + Refer to the documentation of :func:`jax.numpy.split` for details. ``dsplit`` is + equivalent to ``split`` with ``axis=2``. + + Examples: + + >>> x = jnp.arange(12).reshape(3, 1, 4) + >>> print(x) + [[[ 0 1 2 3]] + + [[ 4 5 6 7]] + + [[ 8 9 10 11]]] + >>> x1, x2 = jnp.dsplit(x, 2) + >>> print(x1) + [[[0 1]] + + [[4 5]] + + [[8 9]]] + >>> print(x2) + [[[ 2 3]] + + [[ 6 7]] + + [[10 11]]] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.array_split`: like ``split``, but allows ``indices_or_sections`` + to be an integer that does not evenly divide the size of the array. + """ + return _split("dsplit", ary, indices_or_sections, axis=2) + + +def array_split(ary: ArrayLike, indices_or_sections: int | Sequence[int] | ArrayLike, + axis: int = 0) -> list[Array]: + """Split an array into sub-arrays. + + JAX implementation of :func:`numpy.array_split`. + + Refer to the documentation of :func:`jax.numpy.split` for details; ``array_split`` + is equivalent to ``split``, but allows integer ``indices_or_sections`` which does + not evenly divide the split axis. + + Examples: + >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8, 9]) + >>> chunks = jnp.array_split(x, 4) + >>> print(*chunks) + [1 2 3] [4 5] [6 7] [8 9] + + See also: + - :func:`jax.numpy.split`: split an array along any axis. + - :func:`jax.numpy.vsplit`: split vertically, i.e. along axis=0 + - :func:`jax.numpy.hsplit`: split horizontally, i.e. along axis=1 + - :func:`jax.numpy.dsplit`: split depth-wise, i.e. along axis=2 + """ + return _split("array_split", ary, indices_or_sections, axis=axis) + + +@jit +def clip( + arr: ArrayLike | None = None, + /, + min: ArrayLike | None = None, + max: ArrayLike | None = None, + *, + a: ArrayLike | DeprecatedArg = DeprecatedArg(), + a_min: ArrayLike | None | DeprecatedArg = DeprecatedArg(), + a_max: ArrayLike | None | DeprecatedArg = DeprecatedArg() +) -> Array: + """Clip array values to a specified range. + + JAX implementation of :func:`numpy.clip`. + + Args: + arr: N-dimensional array to be clipped. + min: optional minimum value of the clipped range; if ``None`` (default) then + result will not be clipped to any minimum value. If specified, it should be + broadcast-compatible with ``arr`` and ``max``. + max: optional maximum value of the clipped range; if ``None`` (default) then + result will not be clipped to any maximum value. If specified, it should be + broadcast-compatible with ``arr`` and ``min``. + a: deprecated alias of the ``arr`` argument. Will result in a + :class:`DeprecationWarning` if used. + a_min: deprecated alias of the ``min`` argument. Will result in a + :class:`DeprecationWarning` if used. + a_max: deprecated alias of the ``max`` argument. Will result in a + :class:`DeprecationWarning` if used. + + Returns: + An array containing values from ``arr``, with values smaller than ``min`` set + to ``min``, and values larger than ``max`` set to ``max``. + + See also: + - :func:`jax.numpy.minimum`: Compute the element-wise minimum value of two arrays. 
+ - :func:`jax.numpy.maximum`: Compute the element-wise maximum value of two arrays. + + Examples: + >>> arr = jnp.array([0, 1, 2, 3, 4, 5, 6, 7]) + >>> jnp.clip(arr, 2, 5) + Array([2, 2, 2, 3, 4, 5, 5, 5], dtype=int32) + """ + # TODO(micky774): deprecated 2024-4-2, remove after deprecation expires. + arr = a if not isinstance(a, DeprecatedArg) else arr + if arr is None: + raise ValueError("No input was provided to the clip function.") + min = a_min if not isinstance(a_min, DeprecatedArg) else min + max = a_max if not isinstance(a_max, DeprecatedArg) else max + if any(not isinstance(t, DeprecatedArg) for t in (a, a_min, a_max)): + deprecations.warn( + "jax-numpy-clip-args", + ("Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is " + "deprecated. Please use 'arr', 'min' or 'max' respectively instead."), + stacklevel=2, + ) util.check_arraylike("clip", arr) if any(jax.numpy.iscomplexobj(t) for t in (arr, min, max)): @@ -2392,9 +3057,47 @@ def clip( arr = ufuncs.minimum(max, arr) return asarray(arr) -@util.implements(np.around, skip_params=['out']) + @partial(jit, static_argnames=('decimals',)) def round(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: + """Round input evenly to the given number of decimals. + + JAX implementation of :func:`numpy.round`. + + Args: + a: input array or scalar. + decimals: int, default=0. Number of decimal points to which the input needs + to be rounded. It must be specified statically. Not implemented for + ``decimals < 0``. + out: Unused by JAX. + + Returns: + An array containing the rounded values to the specified ``decimals`` with + same shape and dtype as ``a``. + + Note: + ``jnp.round`` rounds to the nearest even integer for the values exactly halfway + between rounded decimal values. + + See also: + - :func:`jax.numpy.floor`: Rounds the input to the nearest integer downwards. + - :func:`jax.numpy.ceil`: Rounds the input to the nearest integer upwards. + - :func:`jax.numpy.fix` and :func:numpy.trunc`: Rounds the input to the + nearest integer towards zero. + + Examples: + >>> x = jnp.array([1.532, 3.267, 6.149]) + >>> jnp.round(x) + Array([2., 3., 6.], dtype=float32) + >>> jnp.round(x, decimals=2) + Array([1.53, 3.27, 6.15], dtype=float32) + + For values exactly halfway between rounded values: + + >>> x1 = jnp.array([10.5, 21.5, 12.5, 31.5]) + >>> jnp.round(x1) + Array([10., 22., 12., 32.], dtype=float32) + """ util.check_arraylike("round", a) decimals = core.concrete_or_error(operator.index, decimals, "'decimals' argument of jnp.round") if out is not None: @@ -2424,13 +3127,45 @@ def _round_float(x: ArrayLike) -> Array: return lax.complex(_round_float(lax.real(a)), _round_float(lax.imag(a))) else: return _round_float(a) -around = round -round_ = round -@util.implements(np.fix, skip_params=['out']) +@partial(jit, static_argnames=('decimals',)) +def around(a: ArrayLike, decimals: int = 0, out: None = None) -> Array: + """Alias of :func:`jax.numpy.round`""" + return round(a, decimals, out) + + @jit def fix(x: ArrayLike, out: None = None) -> Array: + """Round input to the nearest integer towards zero. + + JAX implementation of :func:`numpy.fix`. + + Args: + x: input array. + out: unused by JAX. + + Returns: + An array with same shape and dtype as ``x`` containing the rounded values. + + See also: + - :func:`jax.numpy.trunc`: Rounds the input to nearest integer towards zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. 
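One-sided clipping, i.e. passing only one of ``min`` or ``max`` in the `clip` signature above, can be sketched as:

import jax.numpy as jnp

arr = jnp.array([-5, -1, 0, 3, 9])
print(jnp.clip(arr, min=0))   # [0 0 0 3 9]        no upper bound
print(jnp.clip(arr, max=2))   # [-5 -1  0  2  2]   no lower bound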
+ + Examples: + >>> key = jax.random.key(0) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[-1.45 1.04 -0.72] + [-2.69 1.74 -0.6 ] + [-2.49 -2.23 2.68]] + >>> jnp.fix(x) + Array([[-1., 1., -0.], + [-2., 1., -0.], + [-2., -2., 2.]], dtype=float32) + """ util.check_arraylike("fix", x) if out is not None: raise NotImplementedError("The 'out' argument to jnp.fix is not supported.") @@ -3015,13 +3750,123 @@ def _pad(array: ArrayLike, pad_width: PadValueLike[int], mode: str, "not implemented modes") -@util.implements(np.pad, lax_description="""\ -Unlike numpy, JAX "function" mode's argument (which is another function) should return -the modified array. This is because Jax arrays are immutable. -(In numpy, "function" mode's argument should modify a rank 1 array in-place.) -""") def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], mode: str | Callable[..., Any] = "constant", **kwargs) -> Array: + """Add padding to an array. + + JAX implementation of :func:`numpy.pad`. + + Args: + array: array to pad. + pad_width: specify the pad width for each dimension of an array. Padding widths + may be separately specified for *before* and *after* the array. Options are: + + - ``int`` or ``(int,)``: pad each array dimension with the same number of values + both before and after. + - ``(before, after)``: pad each array with ``before`` elements before, and ``after`` + elements after + - ``((before_1, after_1), (before_2, after_2), ... (before_N, after_N))``: specify + distinct ``before`` and ``after`` values for each array dimension. + + mode: a string or callable. Supported pad modes are: + + - ``'constant'`` (default): pad with a constant value, which defaults to zero. + - ``'empty'``: pad with empty values (i.e. zero) + - ``'edge'``: pad with the edge values of the array. + - ``'wrap'``: pad by wrapping the array. + - ``'linear_ramp'``: pad with a linear ramp to specified ``end_values``. + - ``'maximum'``: pad with the maximum value. + - ``'mean'``: pad with the mean value. + - ``'median'``: pad with the median value. + - ``'minimum'``: pad with the minimum value. + - ``'reflect'``: pad by reflection. + - ``'symmetric'``: pad by symmetric reflection. + - ````: a callable function. See Notes below. + + constant_values: referenced for ``mode = 'constant'``. Specify the constant value + to pad with. + stat_length: referenced for ``mode in ['maximum', 'mean', 'median', 'minimum']``. + An integer or tuple specifying the number of edge values to use when calculating + the statistic. + end_values: referenced for ``mode = 'linear_ramp'``. Specify the end values to + ramp the padding values to. + reflect_type: referenced for ``mode in ['reflect', 'symmetric']``. Specify whether + to use even or odd reflection. + + Returns: + A padded copy of ``array``. + + Notes: + When ``mode`` is callable, it should have the following signature:: + + def pad_func(row: Array, pad_width: tuple[int, int], + iaxis: int, kwargs: dict) -> Array: + ... + + Here ``row`` is a 1D slice of the padded array along axis ``iaxis``, with the pad + values filled with zeros. ``pad_width`` is a tuple specifying the ``(before, after)`` + padding sizes, and ``kwargs`` are any additional keyword arguments passed to the + :func:`jax.numpy.pad` function. + + Note that while in NumPy, the function should modify ``row`` in-place, in JAX the + function should return the modified ``row``. 
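# A minimal sketch of the difference just described, assuming only public
# numpy / jax.numpy APIs: NumPy's padding callable mutates its 1-D slice in
# place, while the JAX callable must return the updated slice.
import numpy as np
import jax.numpy as jnp

def np_pad_five(row, pad_width, iaxis, kwargs):
    before, after = pad_width
    row[:before] = 5                      # in-place update (NumPy)
    if after:
        row[-after:] = 5

def jnp_pad_five(row, pad_width, iaxis, kwargs):
    before, after = pad_width
    row = row.at[:before].set(5)          # functional update (JAX)
    return row.at[len(row) - after:].set(5)

print(np.pad(np.array([1, 2, 3]), 2, np_pad_five))     # [5 5 1 2 3 5 5]
print(jnp.pad(jnp.array([1, 2, 3]), 2, jnp_pad_five))  # [5 5 1 2 3 5 5]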
In JAX, the custom padding function + will be mapped across the padded axis using the :func:`jax.vmap` transformation. + + See also: + - :func:`jax.numpy.resize`: resize an array + - :func:`jax.numpy.tile`: create a larger array by tiling a smaller array. + - :func:`jax.numpy.repeat`: create a larger array by repeating values of a smaller array. + + Examples: + + Pad a 1-dimensional array with zeros: + + >>> x = jnp.array([10, 20, 30, 40]) + >>> jnp.pad(x, 2) + Array([ 0, 0, 10, 20, 30, 40, 0, 0], dtype=int32) + >>> jnp.pad(x, (2, 4)) + Array([ 0, 0, 10, 20, 30, 40, 0, 0, 0, 0], dtype=int32) + + Pad a 1-dimensional array with specified values: + + >>> jnp.pad(x, 2, constant_values=99) + Array([99, 99, 10, 20, 30, 40, 99, 99], dtype=int32) + + Pad a 1-dimensional array with the mean array value: + + >>> jnp.pad(x, 2, mode='mean') + Array([25, 25, 10, 20, 30, 40, 25, 25], dtype=int32) + + Pad a 1-dimensional array with reflected values: + + >>> jnp.pad(x, 2, mode='reflect') + Array([30, 20, 10, 20, 30, 40, 30, 20], dtype=int32) + + Pad a 2-dimensional array with different paddings in each dimension: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.pad(x, ((1, 2), (3, 0))) + Array([[0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 3], + [0, 0, 0, 4, 5, 6], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], dtype=int32) + + Pad a 1-dimensional array with a custom padding function: + + >>> def custom_pad(row, pad_width, iaxis, kwargs): + ... # row represents a 1D slice of the zero-padded array. + ... before, after = pad_width + ... before_value = kwargs.get('before_value', 0) + ... after_value = kwargs.get('after_value', 0) + ... row = row.at[:before].set(before_value) + ... return row.at[len(row) - after:].set(after_value) + >>> x = jnp.array([2, 3, 4]) + >>> jnp.pad(x, 2, custom_pad, before_value=-10, after_value=10) + Array([-10, -10, 2, 3, 4, 10, 10], dtype=int32) + """ + util.check_arraylike("pad", array) pad_width = _broadcast_to_pairs(pad_width, ndim(array), "pad_width") if pad_width and not all(core.is_dim(p[0]) and core.is_dim(p[1]) @@ -3061,9 +3906,53 @@ def pad(array: ArrayLike, pad_width: PadValueLike[int | Array | np.ndarray], ### Array-creation functions -@util.implements(np.stack, skip_params=['out']) def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int = 0, out: None = None, dtype: DTypeLike | None = None) -> Array: + """Join arrays along a new axis. + + JAX implementation of :func:`numpy.stack`. + + Args: + arrays: a sequence of arrays to stack; each must have the same shape. If a + single array is given it will be treated equivalently to + `arrays = unstack(arrays)`, but the implementation will avoid explicit + unstacking. + axis: specify the axis along which to stack. + out: unused by JAX + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.unstack`: inverse of ``stack``. + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. + - :func:`jax.numpy.column_stack`: stack columns. 
+ + Examples: + >>> x = jnp.array([1, 2, 3]) + >>> y = jnp.array([4, 5, 6]) + >>> jnp.stack([x, y]) + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.stack([x, y], axis=1) + Array([[1, 4], + [2, 5], + [3, 6]], dtype=int32) + + :func:`~jax.numpy.unstack` performs the inverse operation: + + >>> arr = jnp.stack([x, y], axis=1) + >>> x, y = jnp.unstack(arr, axis=1) + >>> x + Array([1, 2, 3], dtype=int32) + >>> y + Array([4, 5, 6], dtype=int32) + """ if not len(arrays): raise ValueError("Need at least one array to stack.") if out is not None: @@ -3082,9 +3971,38 @@ def stack(arrays: np.ndarray | Array | Sequence[ArrayLike], new_arrays.append(expand_dims(a, axis)) return concatenate(new_arrays, axis=axis, dtype=dtype) -@util.implements(getattr(np, 'unstack', None)) + @partial(jit, static_argnames="axis") def unstack(x: ArrayLike, /, *, axis: int = 0) -> tuple[Array, ...]: + """Unstack an array along an axis. + + JAX implementation of :func:`array_api.unstack`. + + Args: + x: array to unstack. Must have ``x.ndim >= 1``. + axis: integer axis along which to unstack. Must satisfy + ``-x.ndim <= axis < x.ndim``. + + Returns: + tuple of unstacked arrays. + + See also: + - :func:`jax.numpy.stack`: inverse of ``unstack`` + - :func:`jax.numpy.split`: split array into batches along an axis. + + Examples: + >>> arr = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> arrs = jnp.unstack(arr) + >>> print(*arrs) + [1 2 3] [4 5 6] + + :func:`~jax.numpy.stack` provides the inverse of this: + + >>> jnp.stack(arrs) + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + """ util.check_arraylike("unstack", x) x = asarray(x) if x.ndim == 0: @@ -3126,9 +4044,46 @@ def _concatenate_array(arr: ArrayLike, axis: int | None, dimensions = [*range(1, axis + 1), 0, *range(axis + 1, arr.ndim)] return lax.reshape(arr, shape, dimensions) -@util.implements(np.concatenate) + def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], axis: int | None = 0, dtype: DTypeLike | None = None) -> Array: + """Join arrays along an existing axis. + + JAX implementation of :func:`numpy.concatenate`. + + Args: + arrays: a sequence of arrays to concatenate; each must have the same shape + except along the specified axis. If a single array is given it will be + treated equivalently to `arrays = unstack(arrays)`, but the implementation + will avoid explicit unstacking. + axis: specify the axis along which to concatenate. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the concatenated result. + + See also: + - :func:`jax.lax.concatenate`: XLA concatenation API. + - :func:`jax.numpy.concat`: Array API version of this function. + - :func:`jax.numpy.stack`: concatenate arrays along a new axis. 
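# A minimal sketch of the single-array behaviour described in the Args above,
# assuming only public jax.numpy APIs: one array argument is handled as if its
# leading axis had been unstacked first.
import jax.numpy as jnp

x = jnp.arange(6).reshape(2, 3)
a = jnp.concatenate(x)                        # single-array form
b = jnp.concatenate(jnp.unstack(x), axis=0)   # equivalent explicit form
assert a.shape == (6,) and bool(jnp.all(a == b))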
+ + Examples: + One-dimensional concatenation: + + >>> x = jnp.arange(3) + >>> y = jnp.zeros(3, dtype=int) + >>> jnp.concatenate([x, y]) + Array([0, 1, 2, 0, 0, 0], dtype=int32) + + Two-dimensional concatenation: + + >>> x = jnp.ones((2, 3)) + >>> y = jnp.zeros((2, 1)) + >>> jnp.concatenate([x, y], axis=1) + Array([[1., 1., 1., 0.], + [1., 1., 1., 0.]], dtype=float32) + """ if isinstance(arrays, (np.ndarray, Array)): return _concatenate_array(arrays, axis, dtype=dtype) util.check_arraylike("concatenate", *arrays) @@ -3145,7 +4100,7 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], arrays_out = [asarray(arr, dtype=dtype) for arr in arrays] # lax.concatenate can be slow to compile for wide concatenations, so form a # tree of concatenations as a workaround especially for op-by-op mode. - # (https://github.com/google/jax/issues/653). + # (https://github.com/jax-ml/jax/issues/653). k = 16 while len(arrays_out) > 1: arrays_out = [lax.concatenate(arrays_out[i:i+k], axis) @@ -3153,15 +4108,96 @@ def concatenate(arrays: np.ndarray | Array | Sequence[ArrayLike], return arrays_out[0] -@util.implements(getattr(np, "concat", None)) def concat(arrays: Sequence[ArrayLike], /, *, axis: int | None = 0) -> Array: + """Join arrays along an existing axis. + + JAX implementation of :func:`array_api.concat`. + + Args: + arrays: a sequence of arrays to concatenate; each must have the same shape + except along the specified axis. If a single array is given it will be + treated equivalently to `arrays = unstack(arrays)`, but the implementation + will avoid explicit unstacking. + axis: specify the axis along which to concatenate. + + Returns: + the concatenated result. + + See also: + - :func:`jax.lax.concatenate`: XLA concatenation API. + - :func:`jax.numpy.concatenate`: NumPy version of this function. + - :func:`jax.numpy.stack`: concatenate arrays along a new axis. + + Examples: + One-dimensional concatenation: + + >>> x = jnp.arange(3) + >>> y = jnp.zeros(3, dtype=int) + >>> jnp.concat([x, y]) + Array([0, 1, 2, 0, 0, 0], dtype=int32) + + Two-dimensional concatenation: + + >>> x = jnp.ones((2, 3)) + >>> y = jnp.zeros((2, 1)) + >>> jnp.concat([x, y], axis=1) + Array([[1., 1., 1., 0.], + [1., 1., 1., 0.]], dtype=float32) + """ util.check_arraylike("concat", *arrays) return jax.numpy.concatenate(arrays, axis=axis) -@util.implements(np.vstack) def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: + """Vertically stack arrays. + + JAX implementation of :func:`numpy.vstack`. + + For arrays of two or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=0``. + + Args: + tup: a sequence of arrays to stack; each must have the same shape along all + but the first axis. If a single array is given it will be treated + equivalently to `tup = unstack(tup)`, but the implementation will avoid + explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. 
+ + Examples: + Scalar values: + + >>> jnp.vstack([1, 2, 3]) + Array([[1], + [2], + [3]], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(4) + >>> y = jnp.ones(4) + >>> jnp.vstack([x, y]) + Array([[0., 1., 2., 3.], + [1., 1., 1., 1.]], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(1, 4) + >>> y = y.reshape(1, 4) + >>> jnp.vstack([x, y]) + Array([[0., 1., 2., 3.], + [1., 1., 1., 1.]], dtype=float32) + """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_2d)(tup) @@ -3172,9 +4208,54 @@ def vstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0, dtype=dtype) -@util.implements(np.hstack) def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: + """Horizontally stack arrays. + + JAX implementation of :func:`numpy.hstack`. + + For arrays of one or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=1``. + + Args: + tup: a sequence of arrays to stack; each must have the same shape along all + but the second axis. Input arrays will be promoted to at least rank 1. + If a single array is given it will be treated equivalently to + `tup = unstack(tup)`, but the implementation will avoid explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.dstack`: stack depth-wise, i.e. along axis 2. + + Examples: + Scalar values: + + >>> jnp.hstack([1, 2, 3]) + Array([1, 2, 3], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(3) + >>> y = jnp.ones(3) + >>> jnp.hstack([x, y]) + Array([0., 1., 2., 1., 1., 1.], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(3, 1) + >>> y = y.reshape(3, 1) + >>> jnp.hstack([x, y]) + Array([[0., 1.], + [1., 1.], + [2., 1.]], dtype=float32) + """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_1d)(tup) @@ -3187,9 +4268,56 @@ def hstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype) -@util.implements(np.dstack) def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], dtype: DTypeLike | None = None) -> Array: + """Stack arrays depth-wise. + + JAX implementation of :func:`numpy.dstack`. + + For arrays of three or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=2``. + + Args: + tup: a sequence of arrays to stack; each must have the same shape along all + but the third axis. Input arrays will be promoted to at least rank 3. If a + single array is given it will be treated equivalently to `tup = unstack(tup)`, + but the implementation will avoid explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. 
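# A minimal sketch of the axis rule visible in the hstack implementation
# above, assuming only public jax.numpy APIs: rank-1 inputs are joined along
# axis 0, higher-rank inputs along axis 1.
import jax.numpy as jnp

v = jnp.arange(3)
print(jnp.hstack([v, v]).shape)                              # (6,)
print(jnp.hstack([v.reshape(3, 1), v.reshape(3, 1)]).shape)  # (3, 2)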
+ + Examples: + Scalar values: + + >>> jnp.dstack([1, 2, 3]) + Array([[[1, 2, 3]]], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(3) + >>> y = jnp.ones(3) + >>> jnp.dstack([x, y]) + Array([[[0., 1.], + [1., 1.], + [2., 1.]]], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(1, 3) + >>> y = y.reshape(1, 3) + >>> jnp.dstack([x, y]) + Array([[[0., 1.], + [1., 1.], + [2., 1.]]], dtype=float32) + """ arrs: Array | list[Array] if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(atleast_3d)(tup) @@ -3200,8 +4328,56 @@ def dstack(tup: np.ndarray | Array | Sequence[ArrayLike], return concatenate(arrs, axis=2, dtype=dtype) -@util.implements(np.column_stack) def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: + """Stack arrays column-wise. + + JAX implementation of :func:`numpy.column_stack`. + + For arrays of two or more dimensions, this is equivalent to + :func:`jax.numpy.concatenate` with ``axis=1``. + + Args: + tup: a sequence of arrays to stack; each must have the same leading dimension. + Input arrays will be promoted to at least rank 2. If a single array is given + it will be treated equivalently to `tup = unstack(tup)`, but the implementation + will avoid explicit unstacking. + dtype: optional dtype of the resulting array. If not specified, the dtype + will be determined via type promotion rules described in :ref:`type-promotion`. + + Returns: + the stacked result. + + See also: + - :func:`jax.numpy.stack`: stack along arbitrary axes + - :func:`jax.numpy.concatenate`: concatenation along existing axes. + - :func:`jax.numpy.vstack`: stack vertically, i.e. along axis 0. + - :func:`jax.numpy.hstack`: stack horizontally, i.e. along axis 1. + - :func:`jax.numpy.hstack`: stack depth=wise, i.e. along axis 2. + + Examples: + Scalar values: + + >>> jnp.column_stack([1, 2, 3]) + Array([[1, 2, 3]], dtype=int32, weak_type=True) + + 1D arrays: + + >>> x = jnp.arange(3) + >>> y = jnp.ones(3) + >>> jnp.column_stack([x, y]) + Array([[0., 1.], + [1., 1.], + [2., 1.]], dtype=float32) + + 2D arrays: + + >>> x = x.reshape(3, 1) + >>> y = y.reshape(3, 1) + >>> jnp.column_stack([x, y]) + Array([[0., 1.], + [1., 1.], + [2., 1.]], dtype=float32) + """ arrs: Array | list[Array] | np.ndarray if isinstance(tup, (np.ndarray, Array)): arrs = jax.vmap(lambda x: atleast_2d(x).T)(tup) if tup.ndim < 3 else tup @@ -3209,7 +4385,7 @@ def column_stack(tup: np.ndarray | Array | Sequence[ArrayLike]) -> Array: # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("column_stack", *tup, emit_warning=True) arrs = [atleast_2d(arr).T if arr.ndim < 2 else arr for arr in map(asarray, tup)] - return concatenate(arrs, 1) + return concatenate(arrs, axis=1) @util.implements(np.choose, skip_params=['out']) @@ -3275,10 +4451,44 @@ def atleast_1d(x: ArrayLike, /) -> Array: @overload def atleast_1d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util.implements(np.atleast_1d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_1d(*arys: ArrayLike) -> Array | list[Array]: - # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. + """Convert inputs to arrays with at least 1 dimension. + + JAX implementation of :func:`numpy.atleast_1d`. + + Args: + zero or more arraylike arguments. + + Returns: + an array or list of arrays corresponding to the input values. Arrays + of shape ``()`` are converted to shape ``(1,)``, and arrays with other + shapes are returned unchanged. 
+ + See also: + - :func:`jax.numpy.asarray` + - :func:`jax.numpy.atleast_2d` + - :func:`jax.numpy.atleast_3d` + + Examples: + Scalar arguments are converted to 1D, length-1 arrays: + + >>> x = jnp.float32(1.0) + >>> jnp.atleast_1d(x) + Array([1.], dtype=float32) + + Higher dimensional inputs are returned unchanged: + + >>> y = jnp.arange(4) + >>> jnp.atleast_1d(y) + Array([0, 1, 2, 3], dtype=int32) + + Multiple arguments can be passed to the function at once, in which + case a list of results is returned: + + >>> jnp.atleast_1d(x, y) + [Array([1.], dtype=float32), Array([0, 1, 2, 3], dtype=int32)] + """ util.check_arraylike("atleast_1d", *arys, emit_warning=True) if len(arys) == 1: return array(arys[0], copy=False, ndmin=1) @@ -3295,9 +4505,52 @@ def atleast_2d(x: ArrayLike, /) -> Array: @overload def atleast_2d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util.implements(np.atleast_2d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_2d(*arys: ArrayLike) -> Array | list[Array]: + """Convert inputs to arrays with at least 2 dimensions. + + JAX implementation of :func:`numpy.atleast_2d`. + + Args: + zero or more arraylike arguments. + + Returns: + an array or list of arrays corresponding to the input values. Arrays + of shape ``()`` are converted to shape ``(1, 1)``, 1D arrays of shape + ``(N,)`` are converted to shape ``(1, N)``, and arrays of all other + shapes are returned unchanged. + + See also: + - :func:`jax.numpy.asarray` + - :func:`jax.numpy.atleast_1d` + - :func:`jax.numpy.atleast_3d` + + Examples: + Scalar arguments are converted to 2D, size-1 arrays: + + >>> x = jnp.float32(1.0) + >>> jnp.atleast_2d(x) + Array([[1.]], dtype=float32) + + One-dimensional arguments have a unit dimension prepended to the shape: + + >>> y = jnp.arange(4) + >>> jnp.atleast_2d(y) + Array([[0, 1, 2, 3]], dtype=int32) + + Higher dimensional inputs are returned unchanged: + + >>> z = jnp.ones((2, 3)) + >>> jnp.atleast_2d(z) + Array([[1., 1., 1.], + [1., 1., 1.]], dtype=float32) + + Multiple arguments can be passed to the function at once, in which + case a list of results is returned: + + >>> jnp.atleast_2d(x, y) + [Array([[1.]], dtype=float32), Array([[0, 1, 2, 3]], dtype=int32)] + """ # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("atleast_2d", *arys, emit_warning=True) if len(arys) == 1: @@ -3315,9 +4568,58 @@ def atleast_3d(x: ArrayLike, /) -> Array: @overload def atleast_3d(x: ArrayLike, y: ArrayLike, /, *arys: ArrayLike) -> list[Array]: ... -@util.implements(np.atleast_3d, update_doc=False, lax_description=_ARRAY_VIEW_DOC) @jit def atleast_3d(*arys: ArrayLike) -> Array | list[Array]: + """Convert inputs to arrays with at least 3 dimensions. + + JAX implementation of :func:`numpy.atleast_3d`. + + Args: + zero or more arraylike arguments. + + Returns: + an array or list of arrays corresponding to the input values. Arrays + of shape ``()`` are converted to shape ``(1, 1, 1)``, 1D arrays of + shape ``(N,)`` are converted to shape ``(1, N, 1)``, 2D arrays of + shape ``(M, N)`` are converted to shape ``(M, N, 1)``, and arrays + of all other shapes are returned unchanged. 
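# A minimal sketch of the shape promotions described for the atleast_* family
# above, assuming only public jax.numpy APIs.
import jax.numpy as jnp

s = jnp.float32(1.0)     # shape ()
v = jnp.arange(4)        # shape (4,)
m = jnp.ones((2, 3))     # shape (2, 3)
print(jnp.atleast_1d(s).shape, jnp.atleast_2d(s).shape, jnp.atleast_3d(s).shape)
# (1,) (1, 1) (1, 1, 1)
print(jnp.atleast_2d(v).shape, jnp.atleast_3d(v).shape)  # (1, 4) (1, 4, 1)
print(jnp.atleast_3d(m).shape)                           # (2, 3, 1)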
+ + See also: + - :func:`jax.numpy.asarray` + - :func:`jax.numpy.atleast_1d` + - :func:`jax.numpy.atleast_2d` + + Examples: + Scalar arguments are converted to 3D, size-1 arrays: + + >>> x = jnp.float32(1.0) + >>> jnp.atleast_3d(x) + Array([[[1.]]], dtype=float32) + + 1D arrays have a unit dimension prepended and appended: + + >>> y = jnp.arange(4) + >>> jnp.atleast_3d(y).shape + (1, 4, 1) + + 2D arrays have a unit dimension appended: + + >>> z = jnp.ones((2, 3)) + >>> jnp.atleast_3d(z).shape + (2, 3, 1) + + Multiple arguments can be passed to the function at once, in which + case a list of results is returned: + + >>> x3, y3 = jnp.atleast_3d(x, y) + >>> print(x3) + [[[1.]]] + >>> print(y3) + [[[0] + [1] + [2] + [3]]] + """ # TODO(jakevdp): Non-array input deprecated 2023-09-22; change to error. util.check_arraylike("atleast_3d", *arys, emit_warning=True) if len(arys) == 1: @@ -3342,22 +4644,73 @@ def _supports_buffer_protocol(obj): return True -_ARRAY_DOC = """ -This function will create arrays on JAX's default device. For control of the -device placement of data, see :func:`jax.device_put`. More information is -available in the JAX FAQ at :ref:`faq-data-placement` (full FAQ at -https://jax.readthedocs.io/en/latest/faq.html). -""" - -deprecations.register("jax-numpy-array-none") - -@util.implements(np.array, lax_description=_ARRAY_DOC, extra_params=""" -device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. -""") def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, order: str | None = "K", ndmin: int = 0, *, device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.array`. + + Args: + object: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + copy: specify whether to force a copy of the input. Default: True. + order: not implemented in JAX + ndmin: integer specifying the minimum number of dimensions in the + output array. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.asarray`: like `array`, but by default only copies + when necessary. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. 
+ + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.array(True) + Array(True, dtype=bool) + >>> jnp.array(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.array(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.array(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.array([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.array([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.array(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.array(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.array(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ if order is not None and order != "K": raise NotImplementedError("Only implemented for order='K'") @@ -3439,7 +4792,7 @@ def array(object: Any, dtype: DTypeLike | None = None, copy: bool = True, if all(not isinstance(leaf, Array) for leaf in leaves): # TODO(jakevdp): falling back to numpy here fails to overflow for lists # containing large integers; see discussion in - # https://github.com/google/jax/pull/6047. More correct would be to call + # https://github.com/jax-ml/jax/pull/6047. More correct would be to call # coerce_to_array on each leaf, but this may have performance implications. out = np.asarray(object, dtype=dtype) elif isinstance(object, Array): @@ -3479,8 +4832,6 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: return x -deprecations.register("jax-numpy-astype-complex-to-real") - def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array: @@ -3546,13 +4897,72 @@ def astype(x: ArrayLike, dtype: DTypeLike | None, return _array_copy(result) if copy else result -@util.implements(np.asarray, lax_description=_ARRAY_DOC, extra_params=""" -device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. -""") def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, *, copy: bool | None = None, device: xc.Device | Sharding | None = None) -> Array: + """Convert an object to a JAX array. + + JAX implementation of :func:`numpy.asarray`. + + Args: + a: an object that is convertible to an array. This includes JAX + arrays, NumPy arrays, Python scalars, Python collections like lists + and tuples, objects with an ``__array__`` method, and objects + supporting the Python buffer protocol. + dtype: optionally specify the dtype of the output array. If not + specified it will be inferred from the input. + order: not implemented in JAX + copy: optional boolean specifying the copy mode. If True, then always + return a copy. If False, then error if a copy is necessary. Default is + None, which will only copy when necessary. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + A JAX array constructed from the input. + + See also: + - :func:`jax.numpy.array`: like `asarray`, but defaults to `copy=True`. + - :func:`jax.numpy.from_dlpack`: construct a JAX array from an object + that implements the dlpack interface. 
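# A minimal sketch of the copy behaviour described above, assuming only public
# jax.numpy APIs: asarray is expected to reuse an existing JAX array when no
# conversion is needed, while jnp.array copies by default.
import jax.numpy as jnp

x = jnp.arange(4)
print(jnp.asarray(x) is x)   # expected True: no copy required
print(jnp.array(x) is x)     # expected False: jnp.array defaults to copy=True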
+ - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + Constructing JAX arrays from Python scalars: + + >>> jnp.asarray(True) + Array(True, dtype=bool) + >>> jnp.asarray(42) + Array(42, dtype=int32, weak_type=True) + >>> jnp.asarray(3.5) + Array(3.5, dtype=float32, weak_type=True) + >>> jnp.asarray(1 + 1j) + Array(1.+1.j, dtype=complex64, weak_type=True) + + Constructing JAX arrays from Python collections: + + >>> jnp.asarray([1, 2, 3]) # list of ints -> 1D array + Array([1, 2, 3], dtype=int32) + >>> jnp.asarray([(1, 2, 3), (4, 5, 6)]) # list of tuples of ints -> 2D array + Array([[1, 2, 3], + [4, 5, 6]], dtype=int32) + >>> jnp.asarray(range(5)) + Array([0, 1, 2, 3, 4], dtype=int32) + + Constructing JAX arrays from NumPy arrays: + + >>> jnp.asarray(np.linspace(0, 2, 5)) + Array([0. , 0.5, 1. , 1.5, 2. ], dtype=float32) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [2, 3, 5, 7]) + >>> jnp.asarray(pybuffer) + Array([2, 3, 5, 7], dtype=int32) + """ # For copy=False, the array API specifies that we raise a ValueError if the input supports # the buffer protocol but a copy is required. Since array() supports the buffer protocol # via numpy, this is only the case when the default device is not 'cpu' @@ -3568,8 +4978,50 @@ def asarray(a: Any, dtype: DTypeLike | None = None, order: str | None = None, return array(a, dtype=dtype, copy=bool(copy), order=order, device=device) -@util.implements(np.copy, lax_description=_ARRAY_DOC) def copy(a: ArrayLike, order: str | None = None) -> Array: + """Return a copy of the array. + + JAX implementation of :func:`numpy.copy`. + + Args: + a: arraylike object to copy + order: not implemented in JAX + + Returns: + a copy of the input array ``a``. + + See Also: + - :func:`jax.numpy.array`: create an array with or without a copy. + - :meth:`jax.Array.copy`: same function accessed as an array method. + + Examples: + Since JAX arrays are immutable, in most cases explicit array copies + are not necessary. One exception is when using a function with donated + arguments (see the ``donate_argnums`` argument to :func:`jax.jit`). + + >>> f = jax.jit(lambda x: 2 * x, donate_argnums=0) + >>> x = jnp.arange(4) + >>> y = f(x) + >>> print(y) + [0 2 4 6] + + Because we marked ``x`` as being donated, the original array is no longer + available: + + >>> print(x) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + RuntimeError: Array has been deleted with shape=int32[4]. + + In situations like this, an explicit copy will let you keep access to the + original buffer: + + >>> x = jnp.arange(4) + >>> y = f(x.copy()) + >>> print(y) + [0 2 4 6] + >>> print(x) + [0 1 2 3] + """ util.check_arraylike("copy", a) return array(a, copy=True, order=order) @@ -4007,9 +5459,50 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array: # General np.from* style functions mostly delegate to numpy. -@util.implements(np.frombuffer) def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float, count: int = -1, offset: int = 0) -> Array: + r"""Convert a buffer into a 1-D JAX array. + + JAX implementation of :func:`numpy.frombuffer`. + + Args: + buffer: an object containing the data. It must be either a bytes object with + a length that is an integer multiple of the dtype element size, or + it must be an object exporting the `Python buffer interface`_. + dtype: optional. 
Desired data type for the array. Default is ``float64``. + This specifes the dtype used to parse the buffer, but note that after parsing, + 64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64`` + flag is set to ``False``. + count: optional integer specifying the number of items to read from the buffer. + If -1 (default), all items from the buffer are read. + offset: optional integer specifying the number of bytes to skip at the beginning + of the buffer. Default is 0. + + Returns: + A 1-D JAX array representing the interpreted data from the buffer. + + See also: + - :func:`jax.numpy.fromstring`: convert a string of text into 1-D JAX array. + + Examples: + Using a bytes buffer: + + >>> buf = b"\x00\x01\x02\x03\x04" + >>> jnp.frombuffer(buf, dtype=jnp.uint8) + Array([0, 1, 2, 3, 4], dtype=uint8) + >>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1) + Array([1, 2, 3, 4], dtype=uint8) + + Constructing a JAX array via the Python buffer interface, using Python's + built-in :mod:`array` module. + + >>> from array import array + >>> pybuffer = array('i', [0, 1, 2, 3, 4]) + >>> jnp.frombuffer(pybuffer, dtype=jnp.int32) + Array([0, 1, 2, 3, 4], dtype=int32) + + .. _Python buffer interface: https://docs.python.org/3/c-api/buffer.html + """ return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset)) @@ -4118,8 +5611,31 @@ def fromfunction(function: Callable[..., Array], shape: Any, return function(*(arange(s, dtype=dtype) for s in shape), **kwargs) -@util.implements(np.fromstring) def fromstring(string: str, dtype: DTypeLike = float, count: int = -1, *, sep: str) -> Array: + """Convert a string of text into 1-D JAX array. + + JAX implementation of :func:`numpy.fromstring`. + + Args: + string: input string containing the data. + dtype: optional. Desired data type for the array. Default is ``float``. + count: optional integer specifying the number of items to read from the string. + If -1 (default), all items are read. + sep: the string used to separate values in the input string. + + Returns: + A 1-D JAX array containing the parsed data from the input string. + + See also: + - :func:`jax.numpy.frombuffer`: construct a JAX array from an object + that implements the buffer interface. + + Examples: + >>> jnp.fromstring("1 2 3", dtype=int, sep=" ") + Array([1, 2, 3], dtype=int32) + >>> jnp.fromstring("0.1, 0.2, 0.3", dtype=float, count=2, sep=",") + Array([0.1, 0.2], dtype=float32) + """ return asarray(np.fromstring(string=string, dtype=dtype, count=count, sep=sep)) @@ -4402,15 +5918,69 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: ... -@util.implements(np.linspace, extra_params=""" -device: (optional) :class:`~jax.Device` or :class:`~jax.sharding.Sharding` - to which the created array will be committed. -""") def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, retstep: bool = False, dtype: DTypeLike | None = None, axis: int = 0, *, device: xc.Device | Sharding | None = None) -> Array | tuple[Array, Array]: + """Return evenly-spaced numbers within an interval. + + JAX implementation of :func:`numpy.linspace`. + + Args: + start: scalar or array of starting values. + stop: scalar or array of stop values. + num: number of values to generate. Default: 50. + endpoint: if True (default) then include the ``stop`` value in the result. + If False, then exclude the ``stop`` value. 
+ retstep: If True, then return a ``(result, step)`` tuple, where ``step`` is the + interval between adjacent values in ``result``. + axis: integer axis along which to generate the linspace. Defaults to zero. + device: optional :class:`~jax.Device` or :class:`~jax.sharding.Sharding` + to which the created array will be committed. + + Returns: + An array ``values``, or a tuple ``(values, step)`` if ``retstep`` is True, where: + + - ``values`` is an array of evenly-spaced values from ``start`` to ``stop`` + - ``step`` is the interval between adjacent values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List of 5 values between 0 and 10: + + >>> jnp.linspace(0, 10, 5) + Array([ 0. , 2.5, 5. , 7.5, 10. ], dtype=float32) + + List of 8 values between 0 and 10, excluding the endpoint: + + >>> jnp.linspace(0, 10, 8, endpoint=False) + Array([0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75], dtype=float32) + + List of values and the step size between them + + >>> vals, step = jnp.linspace(0, 10, 9, retstep=True) + >>> vals + Array([ 0. , 1.25, 2.5 , 3.75, 5. , 6.25, 7.5 , 8.75, 10. ], dtype=float32) + >>> step + Array(1.25, dtype=float32) + + Multi-dimensional linspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 10]) + >>> jnp.linspace(start, stop, 5) + Array([[ 0. , 5. ], + [ 1.25, 6.25], + [ 2.5 , 7.5 ], + [ 3.75, 8.75], + [ 5. , 10. ]], dtype=float32) + """ num = core.concrete_dim_or_error(num, "'num' argument of jnp.linspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.linspace") return _linspace(start, stop, num, endpoint, retstep, dtype, axis, device=device) @@ -4473,10 +6043,69 @@ def _linspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return (result, delta) if retstep else result -@util.implements(np.logspace) def logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, base: ArrayLike = 10.0, dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate logarithmically-spaced values. + + JAX implementation of :func:`numpy.logspace`. + + Args: + start: scalar or array. Used to specify the start value. The start value is + ``base ** start``. + stop: scalar or array. Used to specify the stop value. The end value is + ``base ** stop``. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + base: scalar or array, optional, default=10. Specifies the base of the logarithm. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the logspace. + + Returns: + An array of logarithm. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.geomspace`: Generate geometrically-spaced values. + + Examples: + List 5 logarithmically spaced values between 1 (``10 ** 0``) and 100 + (``10 ** 2``): + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5) + Array([ 1. , 3.162, 10. , 31.623, 100. 
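# A minimal sketch of the endpoint and retstep options described above,
# assuming only public jax.numpy APIs: with endpoint=True the step is
# (stop - start) / (num - 1); with endpoint=False it is (stop - start) / num.
import jax.numpy as jnp

vals, step = jnp.linspace(0.0, 10.0, 5, retstep=True)
print(step)                                        # 2.5
print(jnp.linspace(0.0, 10.0, 5, endpoint=False))  # [0. 2. 4. 6. 8.]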
], dtype=float32) + + List 5 logarithmically-spaced values between 1(``10 ** 0``) and 100 + (``10 ** 2``), excluding endpoint: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 5, endpoint=False) + Array([ 1. , 2.512, 6.31 , 15.849, 39.811], dtype=float32) + + List 7 logarithmically-spaced values between 1 (``2 ** 0``) and 4 (``2 ** 2``) + with base 2: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(0, 2, 7, base=2) + Array([1. , 1.26 , 1.587, 2. , 2.52 , 3.175, 4. ], dtype=float32) + + Multi-dimensional logspace: + + >>> start = jnp.array([0, 5]) + >>> stop = jnp.array([5, 0]) + >>> base = jnp.array([2, 3]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.logspace(start, stop, 5, base=base) + Array([[ 1. , 243. ], + [ 2.378, 61.547], + [ 5.657, 15.588], + [ 13.454, 3.948], + [ 32. , 1. ]], dtype=float32) + """ num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.logspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.logspace") return _logspace(start, stop, num, endpoint, base, dtype, axis) @@ -4499,9 +6128,54 @@ def _logspace(start: ArrayLike, stop: ArrayLike, num: int = 50, return lax.convert_element_type(ufuncs.power(base, lin), dtype) -@util.implements(np.geomspace) def geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool = True, dtype: DTypeLike | None = None, axis: int = 0) -> Array: + """Generate geometrically-spaced values. + + JAX implementation of :func:`numpy.geomspace`. + + Args: + start: scalar or array. Specifies the starting values. + stop: scalar or array. Specifies the stop values. + num: int, optional, default=50. Number of values to generate. + endpoint: bool, optional, default=True. If True, then include the ``stop`` value + in the result. If False, then exclude the ``stop`` value. + dtype: optional. Specifies the dtype of the output. + axis: int, optional, default=0. Axis along which to generate the geomspace. + + Returns: + An array containing the geometrically-spaced values. + + See also: + - :func:`jax.numpy.arange`: Generate ``N`` evenly-spaced values given a starting + point and a step value. + - :func:`jax.numpy.linspace`: Generate evenly-spaced values. + - :func:`jax.numpy.logspace`: Generate logarithmically-spaced values. + + Examples: + List 5 geometrically-spaced values between 1 and 16: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 5) + Array([ 1., 2., 4., 8., 16.], dtype=float32) + + List 4 geomtrically-spaced values between 1 and 16, with ``endpoint=False``: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.geomspace(1, 16, 4, endpoint=False) + Array([1., 2., 4., 8.], dtype=float32) + + Multi-dimensional geomspace: + + >>> start = jnp.array([1, 1000]) + >>> stop = jnp.array([27, 1]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... 
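# A minimal sketch of the relationship used by the logspace implementation
# shown above, assuming only public jax.numpy APIs:
# logspace(start, stop, num, base=b) equals b ** linspace(start, stop, num).
import jax.numpy as jnp

assert bool(jnp.allclose(jnp.logspace(0, 2, 5),
                         10.0 ** jnp.linspace(0.0, 2.0, 5)))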
jnp.geomspace(start, stop, 4) + Array([[ 1., 1000.], + [ 3., 100.], + [ 9., 10.], + [ 27., 1.]], dtype=float32) + """ num = core.concrete_or_error(operator.index, num, "'num' argument of jnp.geomspace") axis = core.concrete_or_error(operator.index, axis, "'axis' argument of jnp.geomspace") return _geomspace(start, stop, num, endpoint, dtype, axis) @@ -4528,9 +6202,67 @@ def _geomspace(start: ArrayLike, stop: ArrayLike, num: int = 50, endpoint: bool return lax.convert_element_type(res, dtype) -@util.implements(np.meshgrid, lax_description=_ARRAY_VIEW_DOC) def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, indexing: str = 'xy') -> list[Array]: + """Construct N-dimensional grid arrays from N 1-dimensional vectors. + + JAX implementation of :func:`numpy.meshgrid`. + + Args: + xi: N arrays to convert to a grid. + copy: whether to copy the input arrays. JAX supports only ``copy=True``, + though under JIT compilation the compiler may opt to avoid copies. + sparse: if False (default), then each returned array will be of shape + ``[len(x1), len(x2), ..., len(xN)]``. If True, then returned arrays + will be of shape ``[1, 1, ..., len(xi), ..., 1, 1]``. + indexing: options are ``'xy'`` for cartesian indexing (default) or ``'ij'`` + for matrix indexing. + + Returns: + A length-N list of grid arrays. + + See also: + - :obj:`jax.numpy.mgrid`: create a meshgrid using indexing syntax. + - :obj:`jax.numpy.ogrid`: create an open meshgrid using indexing syntax. + + Examples: + For the following examples, we'll use these 1D arrays as inputs: + + >>> x = jnp.array([1, 2]) + >>> y = jnp.array([10, 20, 30]) + + 2D cartesian mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y) + >>> print(x_grid) + [[1 2] + [1 2] + [1 2]] + >>> print(y_grid) + [[10 10] + [20 20] + [30 30]] + + 2D sparse cartesian mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y, sparse=True) + >>> print(x_grid) + [[1 2]] + >>> print(y_grid) + [[10] + [20] + [30]] + + 2D matrix-index mesh grid: + + >>> x_grid, y_grid = jnp.meshgrid(x, y, indexing='ij') + >>> print(x_grid) + [[1 1 1] + [2 2 2]] + >>> print(y_grid) + [[10 20 30] + [10 20 30]] + """ util.check_arraylike("meshgrid", *xi) args = [asarray(x) for x in xi] if not copy: @@ -4550,19 +6282,53 @@ def meshgrid(*xi: ArrayLike, copy: bool = True, sparse: bool = False, return output -@custom_jvp -@util.implements(np.i0) @jit def i0(x: ArrayLike) -> Array: + r"""Calculate modified Bessel function of first kind, zeroth order. + + JAX implementation of :func:`numpy.i0`. + + Modified Bessel function of first kind, zeroth order is defined by: + + .. math:: + + \mathrm{i0}(x) = I_0(x) = \sum_{k=0}^{\infty} \frac{(x^2/4)^k}{(k!)^2} + + Args: + x: scalar or array. Specifies the argument of Bessel function. Complex inputs + are not supported. + + Returns: + An array containing the corresponding values of the modified Bessel function + of ``x``. + + See also: + - :func:`jax.scipy.special.i0`: Calculates the modified Bessel function of + zeroth order. + - :func:`jax.scipy.special.i1`: Calculates the modified Bessel function of + first order. + - :func:`jax.scipy.special.i0e`: Calculates the exponentially scaled modified + Bessel function of zeroth order.
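# A minimal sketch of the sparse option described above, assuming only public
# jax.numpy APIs: sparse grids broadcast to the dense grids.
import jax.numpy as jnp

x = jnp.array([1, 2])
y = jnp.array([10, 20, 30])
xd, yd = jnp.meshgrid(x, y)                # dense grids, shape (3, 2)
xs, ys = jnp.meshgrid(x, y, sparse=True)   # shapes (1, 2) and (3, 1)
assert bool(jnp.all(jnp.broadcast_to(xs, xd.shape) == xd))
assert bool(jnp.all(jnp.broadcast_to(ys, yd.shape) == yd))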
+ + Examples: + >>> x = jnp.array([-2, -1, 0, 1, 2]) + >>> jnp.i0(x) + Array([2.2795851, 1.266066 , 1.0000001, 1.266066 , 2.2795851], dtype=float32) + """ x_arr, = util.promote_args_inexact("i0", x) if not issubdtype(x_arr.dtype, np.floating): raise ValueError(f"Unsupported input type to jax.numpy.i0: {_dtype(x)}") - x_arr = lax.abs(x_arr) - return lax.mul(lax.exp(x_arr), lax.bessel_i0e(x_arr)) + return _i0(x_arr) + + +@custom_jvp +def _i0(x): + abs_x = lax.abs(x) + return lax.mul(lax.exp(abs_x), lax.bessel_i0e(abs_x)) -@i0.defjvp +@_i0.defjvp def _i0_jvp(primals, tangents): - primal_out, tangent_out = jax.jvp(i0.fun, primals, tangents) + primal_out, tangent_out = jax.jvp(_i0.fun, primals, tangents) return primal_out, where(primals[0] == 0, 0.0, tangent_out) def ix_(*args: ArrayLike) -> tuple[Array, ...]: @@ -4996,13 +6762,60 @@ def triu(m: ArrayLike, k: int = 0) -> Array: return lax.select(lax.broadcast(mask, m_shape[:-2]), zeros_like(m), m) -@util.implements(np.trace, skip_params=['out']) @partial(jit, static_argnames=('axis1', 'axis2', 'dtype')) def trace(a: ArrayLike, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Calculate sum of the diagonal of input along the given axes. + + JAX implementation of :func:`numpy.trace`. + + Args: + a: input array. Must have ``a.ndim >= 2``. + offset: optional, int, default=0. Diagonal offset from the main diagonal. + Can be positive or negative. + axis1: optional, default=0. The first axis along which to take the sum of + diagonal. Must be a static integer value. + axis2: optional, default=1. The second axis along which to take the sum of + diagonal. Must be a static integer value. + dtype: optional. The dtype of the output array. Should be provided as static + argument in JIT compilation. + out: Not used by JAX. + + Returns: + An array of dimension x.ndim-2 containing the sum of the diagonal elements + along axes (axis1, axis2) + + See also: + - :func:`jax.numpy.diag`: Returns the specified diagonal or constructs a diagonal + array + - :func:`jax.numpy.diagonal`: Returns the specified diagonal of an array. + - :func:`jax.numpy.diagflat`: Returns a 2-D array with the flattened input array + laid out on the diagonal. + + Examples: + >>> x = jnp.arange(1, 9).reshape(2, 2, 2) + >>> x + Array([[[1, 2], + [3, 4]], + + [[5, 6], + [7, 8]]], dtype=int32) + >>> jnp.trace(x) + Array([ 8, 10], dtype=int32) + >>> jnp.trace(x, offset=1) + Array([3, 4], dtype=int32) + >>> jnp.trace(x, axis1=1, axis2=2) + Array([ 5, 13], dtype=int32) + >>> jnp.trace(x, offset=1, axis1=1, axis2=2) + Array([2, 6], dtype=int32) + """ util.check_arraylike("trace", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.trace is not supported.") + + if _canonicalize_axis(axis1, ndim(a)) == _canonicalize_axis(axis2, ndim(a)): + raise ValueError(f"axis1 and axis2 can not be same. axis1={axis1} and axis2={axis2}") + dtypes.check_user_dtype_supported(dtype, "trace") a_shape = shape(a) @@ -5426,10 +7239,42 @@ def diag_indices_from(arr: ArrayLike) -> tuple[Array, ...]: return diag_indices(s[0], ndim=nd) -@util.implements(np.diagonal, lax_description=_ARRAY_VIEW_DOC) @partial(jit, static_argnames=('offset', 'axis1', 'axis2')) def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: + """Returns the specified diagonal of an array. + + JAX implementation of :func:`numpy.diagonal`. 
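# A minimal sketch of what the custom JVP above provides, assuming public jax
# and jax.numpy APIs: jnp.i0 has a well-defined derivative of 0 at x = 0
# (handled explicitly by the JVP rule), and d/dx I0(x) = I1(x) elsewhere.
import jax
import jax.numpy as jnp

print(jax.grad(jnp.i0)(0.0))   # expected 0.0
print(jax.grad(jnp.i0)(1.0))   # expected ~0.565 (= I1(1))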
+ + The JAX version always returns a copy of the input, although if this is used + within a JIT compilation, the compiler may avoid the copy. + + Args: + a: Input array. Must be at least 2-dimensional. + offset: optional, default=0. Diagonal offset from the main diagonal. + Must be a static integer value. Can be positive or negative. + axis1: optional, default=0. The first axis along which to take the diagonal. + axis2: optional, default=1. The second axis along which to take the diagonal. + + Returns: + A 1D array for 2D input, and in general a N-1 dimensional array + for N-dimensional input. + + See also: + - :func:`jax.numpy.diag` + - :func:`jax.numpy.diagflat` + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9]]) + >>> jnp.diagonal(x) + Array([1, 5, 9], dtype=int32) + >>> jnp.diagonal(x, offset=1) + Array([2, 6], dtype=int32) + >>> jnp.diagonal(x, offset=-1) + Array([4, 8], dtype=int32) + """ util.check_arraylike("diagonal", a) a_shape = shape(a) if ndim(a) < 2: @@ -5445,8 +7290,53 @@ def diagonal(a: ArrayLike, offset: int = 0, axis1: int = 0, return a[..., i, j] if offset >= 0 else a[..., j, i] -@util.implements(np.diag, lax_description=_ARRAY_VIEW_DOC) def diag(v: ArrayLike, k: int = 0) -> Array: + """Returns the specified diagonal or constructs a diagonal array. + + JAX implementation of :func:`numpy.diag`. + + The JAX version always returns a copy of the input, although if this is used + within a JIT compilation, the compiler may avoid the copy. + + Args: + v: Input array. Can be a 1-D array to create a diagonal matrix or a + 2-D array to extract a diagonal. + k: optional, default=0. Diagonal offset. Positive values place the diagonal + above the main diagonal, negative values place it below the main diagonal. + + Returns: + If `v` is a 2-D array, a 1-D array containing the diagonal elements. + If `v` is a 1-D array, a 2-D array with the input elements placed along the + specified diagonal. + + See also: + - :func:`jax.numpy.diagflat` + - :func:`jax.numpy.diagonal` + + Examples: + Creating a diagonal matrix from a 1-D array: + + >>> jnp.diag(jnp.array([1, 2, 3])) + Array([[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], dtype=int32) + + Specifying a diagonal offset: + + >>> jnp.diag(jnp.array([1, 2, 3]), k=1) + Array([[0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3], + [0, 0, 0, 0]], dtype=int32) + + Extracting a diagonal from a 2-D array: + + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9]]) + >>> jnp.diag(x) + Array([1, 5, 9], dtype=int32) + """ return _diag(v, operator.index(k)) @partial(jit, static_argnames=('k',)) @@ -5463,14 +7353,46 @@ def _diag(v, k): else: raise ValueError("diag input must be 1d or 2d") -_SCALAR_VALUE_DOC = """\ -This differs from np.diagflat for some scalar values of v, -jax always returns a two-dimensional array, whereas numpy may -return a scalar depending on the type of v. -""" - -@util.implements(np.diagflat, lax_description=_SCALAR_VALUE_DOC) def diagflat(v: ArrayLike, k: int = 0) -> Array: + """Return a 2-D array with the flattened input array laid out on the diagonal. + + JAX implementation of :func:`numpy.diagflat`. + + This differs from `np.diagflat` for some scalar values of `v`. JAX always returns + a two-dimensional array, whereas NumPy may return a scalar depending on the type + of `v`. + + Args: + v: Input array. Can be N-dimensional but is flattened to 1D. + k: optional, default=0. Diagonal offset. Positive values place the diagonal + above the main diagonal, negative values place it below the main diagonal. 
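# A minimal sketch relating the two functions documented above, assuming only
# public jax.numpy APIs: trace is the sum of the matching diagonal.
import jax.numpy as jnp

x = jnp.arange(1, 10).reshape(3, 3)
for k in (-1, 0, 1):
    assert bool(jnp.trace(x, offset=k) == jnp.diagonal(x, offset=k).sum())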
+ + Returns: + A 2D array with the input elements placed along the diagonal with the + specified offset (k). The remaining entries are filled with zeros. + + See also: + - :func:`jax.numpy.diag` + - :func:`jax.numpy.diagonal` + + Examples: + >>> jnp.diagflat(jnp.array([1, 2, 3])) + Array([[1, 0, 0], + [0, 2, 0], + [0, 0, 3]], dtype=int32) + >>> jnp.diagflat(jnp.array([1, 2, 3]), k=1) + Array([[0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3], + [0, 0, 0, 0]], dtype=int32) + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> jnp.diagflat(a) + Array([[1, 0, 0, 0], + [0, 2, 0, 0], + [0, 0, 3, 0], + [0, 0, 0, 4]], dtype=int32) + """ util.check_arraylike("diagflat", v) v_ravel = ravel(v) v_length = len(v_ravel) @@ -5486,16 +7408,48 @@ def diagflat(v: ArrayLike, k: int = 0) -> Array: return res -@util.implements(np.trim_zeros) -def trim_zeros(filt, trim='fb'): - filt = core.concrete_or_error(asarray, filt, - "Error arose in the `filt` argument of trim_zeros()") - nz = (filt == 0) +def trim_zeros(filt: ArrayLike, trim: str ='fb') -> Array: + """Trim leading and/or trailing zeros of the input array. + + JAX implementation of :func:`numpy.trim_zeros`. + + Args: + filt: input array. Must have ``filt.ndim == 1``. + trim: string, optional, default = ``fb``. Specifies from which end the input + is trimmed. + + - ``f`` - trims only the leading zeros. + - ``b`` - trims only the trailing zeros. + - ``fb`` - trims both leading and trailing zeros. + + Returns: + An array containig the trimmed input with same dtype as ``filt``. + + Examples: + >>> x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0]) + >>> jnp.trim_zeros(x) + Array([2, 0, 1, 4, 3], dtype=int32) + """ + # Non-array inputs are deprecated 2024-09-11 + util.check_arraylike("trim_zeros", filt, emit_warning=True) + core.concrete_or_error(None, filt, + "Error arose in the `filt` argument of trim_zeros()") + filt_arr = jax.numpy.asarray(filt) + del filt + if filt_arr.ndim != 1: + # Added on 2024-09-11 + if deprecations.is_accelerated("jax-numpy-trimzeros-not-1d-array"): + raise TypeError(f"'filt' must be 1-D array, but received {filt_arr.ndim}-D array.") + warnings.warn( + "Passing arrays with ndim != 1 to jnp.trim_zeros() is deprecated. Currently, it " + "works with Arrays having ndim != 1. In the future this will result in an error.", + DeprecationWarning, stacklevel=2) + nz = (filt_arr == 0) if reductions.all(nz): - return empty(0, _dtype(filt)) - start = argmin(nz) if 'f' in trim.lower() else 0 - end = argmin(nz[::-1]) if 'b' in trim.lower() else 0 - return filt[start:len(filt) - end] + return empty(0, filt_arr.dtype) + start: Array | int = argmin(nz) if 'f' in trim.lower() else 0 + end: Array | int = argmin(nz[::-1]) if 'b' in trim.lower() else 0 + return filt_arr[start:len(filt_arr) - end] def trim_zeros_tol(filt, tol, trim='fb'): @@ -5845,20 +7799,17 @@ def dot(a: ArrayLike, b: ArrayLike, *, batch_dims = ((), ()) a_ndim, b_ndim = ndim(a), ndim(b) if a_ndim == 0 or b_ndim == 0: - # TODO(jakevdp): lower this case to dot_general as well? 
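# A minimal sketch of the trim modes described for trim_zeros above, assuming
# only public jax.numpy APIs.
import jax.numpy as jnp

x = jnp.array([0, 0, 2, 0, 1, 4, 3, 0, 0, 0])
print(jnp.trim_zeros(x, trim='f'))   # [2 0 1 4 3 0 0 0]
print(jnp.trim_zeros(x, trim='b'))   # [0 0 2 0 1 4 3]
print(jnp.trim_zeros(x, trim='fb'))  # [2 0 1 4 3]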
- # Currently, doing so causes issues in remat tests due to #16805 - if preferred_element_type is not None: - a = a.astype(preferred_element_type) - b = b.astype(preferred_element_type) - result = lax.mul(a, b) + contract_dims: tuple[tuple[int, ...], tuple[int, ...]] = ((), ()) else: if b_ndim == 1: contract_dims = ((a_ndim - 1,), (0,)) else: contract_dims = ((a_ndim - 1,), (b_ndim - 2,)) - result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), - precision=precision, preferred_element_type=preferred_element_type) - return lax_internal._convert_element_type(result, preferred_element_type, output_weak_type) + result = lax.dot_general(a, b, dimension_numbers=(contract_dims, batch_dims), + precision=precision, + preferred_element_type=preferred_element_type) + return lax_internal._convert_element_type(result, preferred_element_type, + output_weak_type) @partial(jit, static_argnames=('precision', 'preferred_element_type'), inline=True) @@ -5871,7 +7822,7 @@ def matmul(a: ArrayLike, b: ArrayLike, *, JAX implementation of :func:`numpy.matmul`. Args: - a: first input array, of shape ``(..., N)``. + a: first input array, of shape ``(N,)`` or ``(..., K, N)``. b: second input array. Must have shape ``(N,)`` or ``(..., N, M)``. In the multi-dimensional case, leading dimensions must be broadcast-compatible with the leading dimensions of ``a``. @@ -6772,9 +8723,33 @@ def inner( preferred_element_type=preferred_element_type) -@util.implements(np.outer, skip_params=['out']) @partial(jit, inline=True) def outer(a: ArrayLike, b: ArrayLike, out: None = None) -> Array: + """Compute the outer product of two arrays. + + JAX implementation of :func:`numpy.outer`. + + Args: + a: first input array, if not 1D it will be flattened. + b: second input array, if not 1D it will be flattened. + out: unsupported by JAX. + + Returns: + The outer product of the inputs ``a`` and ``b``. Returned array + will be of shape ``(a.size, b.size)``. + + See also: + - :func:`jax.numpy.inner`: compute the inner product of two arrays. + - :func:`jax.numpy.einsum`: Einstein summation. + + Examples: + >>> a = jnp.array([1, 2, 3]) + >>> b = jnp.array([4, 5, 6]) + >>> jnp.outer(a, b) + Array([[ 4, 5, 6], + [ 8, 10, 12], + [12, 15, 18]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.outer is not supported.") util.check_arraylike("outer", a, b) @@ -6810,9 +8785,39 @@ def cross(a, b, axisa: int = -1, axisb: int = -1, axisc: int = -1, return moveaxis(c, 0, axisc) -@util.implements(np.kron) @jit def kron(a: ArrayLike, b: ArrayLike) -> Array: + """Compute the Kronecker product of two input arrays. + + JAX implementation of :func:`numpy.kron`. + + The Kronecker product is an operation on two matrices of arbitrary size that + produces a block matrix. Each element of the first matrix ``a`` is multiplied by + the entire second matrix ``b``. If ``a`` has shape (m, n) and ``b`` + has shape (p, q), the resulting matrix will have shape (m * p, n * q). + + Args: + a: first input array with any shape. + b: second input array with any shape. + + Returns: + A new array representing the Kronecker product of the inputs ``a`` and ``b``. + The shape of the output is the element-wise product of the input shapes. + + See also: + - :func:`jax.numpy.outer`: compute the outer product of two arrays. + + Examples: + >>> a = jnp.array([[1, 2], + ... [3, 4]]) + >>> b = jnp.array([[5, 6], + ... 
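# A minimal sketch of the jnp.dot scalar case rewritten above, assuming only
# public jax.numpy APIs: scalar operands now flow through the same
# lax.dot_general path (empty contracting dimensions), so options such as
# precision and preferred_element_type are handled uniformly with the
# non-scalar case.
import jax.numpy as jnp

a = jnp.int8(3)
b = jnp.arange(4, dtype=jnp.int8)
print(jnp.dot(a, b))                                           # expected [0 3 6 9]
print(jnp.dot(a, b, preferred_element_type=jnp.int32).dtype)   # expected int32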
[7, 8]]) + >>> jnp.kron(a, b) + Array([[ 5, 6, 10, 12], + [ 7, 8, 14, 16], + [15, 18, 20, 24], + [21, 24, 28, 32]], dtype=int32) + """ util.check_arraylike("kron", a, b) a, b = util.promote_dtypes(a, b) if ndim(a) < ndim(b): @@ -6825,11 +8830,51 @@ def kron(a: ArrayLike, b: ArrayLike) -> Array: return reshape(lax.mul(a_reshaped, b_reshaped), out_shape) -@util.implements(np.vander) @partial(jit, static_argnames=('N', 'increasing')) def vander( x: ArrayLike, N: int | None = None, increasing: bool = False ) -> Array: + """Generate a Vandermonde matrix. + + JAX implementation of :func:`numpy.vander`. + + Args: + x: input array. Must have ``x.ndim == 1``. + N: int, optional, default=None. Specifies the number of the columns the + output matrix. If not specified, ``N = len(x)``. + increasing: bool, optional, default=False. Specifies the order of the powers + of the columns. If ``True``, the powers increase from left to right, + :math:`[x^0, x^1, ..., x^{(N-1)}]`. By default, the powers decrease from left to + right :math:`[x^{(N-1)}, ..., x^1, x^0]`. + + Returns: + An array of shape ``[len(x), N]`` containing the generated Vandermonde matrix. + + Examples: + >>> x = jnp.array([1, 2, 3, 4]) + >>> jnp.vander(x) + Array([[ 1, 1, 1, 1], + [ 8, 4, 2, 1], + [27, 9, 3, 1], + [64, 16, 4, 1]], dtype=int32) + + If ``N = 2``, generates a Vandermonde matrix with ``2`` columns. + + >>> jnp.vander(x, N=2) + Array([[1, 1], + [2, 1], + [3, 1], + [4, 1]], dtype=int32) + + Generates the Vandermonde matrix in increaing order of powers, when + ``increasing=True``. + + >>> jnp.vander(x, increasing=True) + Array([[ 1, 1, 1, 1], + [ 1, 2, 4, 8], + [ 1, 3, 9, 27], + [ 1, 4, 16, 64]], dtype=int32) + """ util.check_arraylike("vander", x) x = asarray(x) if x.ndim != 1: @@ -6913,9 +8958,41 @@ def argwhere( return result.reshape(result.shape[0], ndim(a)) -@util.implements(np.argmax, skip_params=['out']) def argmax(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: + """Return the index of the maximum value of an array. + + JAX implementation of :func:`numpy.argmax`. + + Args: + a: input array + axis: optional integer specifying the axis along which to find the maximum + value. If ``axis`` is not specified, ``a`` will be flattened. + out: unused by JAX + keepdims: if True, then return an array with the same number of dimensions + as ``a``. + + Returns: + an array containing the index of the maximum value along the specified axis. + + See also: + - :func:`jax.numpy.argmin`: return the index of the minimum value. + - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values. + + Examples: + >>> x = jnp.array([1, 3, 5, 4, 2]) + >>> jnp.argmax(x) + Array(2, dtype=int32) + + >>> x = jnp.array([[1, 3, 2], + ... [5, 4, 1]]) + >>> jnp.argmax(x, axis=1) + Array([1, 0], dtype=int32) + + >>> jnp.argmax(x, axis=1, keepdims=True) + Array([[1], + [0]], dtype=int32) + """ util.check_arraylike("argmax", a) if out is not None: raise NotImplementedError("The 'out' argument to jnp.argmax is not supported.") @@ -6935,9 +9012,42 @@ def _argmax(a: Array, axis: int | None = None, keepdims: bool = False) -> Array: result = lax.argmax(a, _canonicalize_axis(axis, a.ndim), dtypes.canonicalize_dtype(int_)) return expand_dims(result, dims) if keepdims else result -@util.implements(np.argmin, skip_params=['out']) + def argmin(a: ArrayLike, axis: int | None = None, out: None = None, keepdims: bool | None = None) -> Array: + """Return the index of the minimum value of an array. 
+
+  JAX implementation of :func:`numpy.argmin`.
+
+  Args:
+    a: input array
+    axis: optional integer specifying the axis along which to find the minimum
+      value. If ``axis`` is not specified, ``a`` will be flattened.
+    out: unused by JAX
+    keepdims: if True, then return an array with the same number of dimensions
+      as ``a``.
+
+  Returns:
+    an array containing the index of the minimum value along the specified axis.
+
+  See also:
+    - :func:`jax.numpy.argmax`: return the index of the maximum value.
+    - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values.
+
+  Examples:
+    >>> x = jnp.array([1, 3, 5, 4, 2])
+    >>> jnp.argmin(x)
+    Array(0, dtype=int32)
+
+    >>> x = jnp.array([[1, 3, 2],
+    ...                [5, 4, 1]])
+    >>> jnp.argmin(x, axis=1)
+    Array([0, 2], dtype=int32)
+
+    >>> jnp.argmin(x, axis=1, keepdims=True)
+    Array([[0],
+           [2]], dtype=int32)
+  """
   util.check_arraylike("argmin", a)
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.argmin is not supported.")
@@ -6958,19 +9068,57 @@ def _argmin(a: Array, axis: int | None = None, keepdims: bool = False) -> Array:
   return expand_dims(result, dims) if keepdims else result


-_NANARG_DOC = """\
-Warning: jax.numpy.arg{} returns -1 for all-NaN slices and does not raise
-an error.
-"""
-
-
-@util.implements(np.nanargmax, lax_description=_NANARG_DOC.format("max"), skip_params=['out'])
 def nanargmax(
     a: ArrayLike,
     axis: int | None = None,
     out: None = None,
     keepdims: bool | None = None,
 ) -> Array:
+  """Return the index of the maximum value of an array, ignoring NaNs.
+
+  JAX implementation of :func:`numpy.nanargmax`.
+
+  Args:
+    a: input array
+    axis: optional integer specifying the axis along which to find the maximum
+      value. If ``axis`` is not specified, ``a`` will be flattened.
+    out: unused by JAX
+    keepdims: if True, then return an array with the same number of dimensions
+      as ``a``.
+
+  Returns:
+    an array containing the index of the maximum value along the specified axis.
+
+  Note:
+    In the case of an axis with all-NaN values, the returned index will be -1.
+    This differs from the behavior of :func:`numpy.nanargmax`, which raises an error.
+
+  See also:
+    - :func:`jax.numpy.argmax`: return the index of the maximum value.
+    - :func:`jax.numpy.nanargmin`: compute ``argmin`` while ignoring NaN values.
+
+  Examples:
+    >>> x = jnp.array([1, 3, 5, 4, jnp.nan])
+
+    Using a standard :func:`~jax.numpy.argmax` leads to potentially unexpected results:
+
+    >>> jnp.argmax(x)
+    Array(4, dtype=int32)
+
+    Using ``nanargmax`` returns the index of the maximum non-NaN value.
+
+    >>> jnp.nanargmax(x)
+    Array(2, dtype=int32)
+
+    >>> x = jnp.array([[1, 3, jnp.nan],
+    ...                [5, 4, jnp.nan]])
+    >>> jnp.nanargmax(x, axis=1)
+    Array([1, 0], dtype=int32)
+
+    >>> jnp.nanargmax(x, axis=1, keepdims=True)
+    Array([[1],
+           [0]], dtype=int32)
+  """
   if out is not None:
     raise NotImplementedError("The 'out' argument to jnp.nanargmax is not supported.")
   return _nanargmax(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims))
@@ -6987,13 +9135,50 @@ def _nanargmax(a, axis: int | None = None, keepdims: bool = False):
   return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)


-@util.implements(np.nanargmin, lax_description=_NANARG_DOC.format("min"), skip_params=['out'])
 def nanargmin(
     a: ArrayLike,
     axis: int | None = None,
     out: None = None,
     keepdims: bool | None = None,
 ) -> Array:
+
+  """Return the index of the minimum value of an array, ignoring NaNs.
+
+  JAX implementation of :func:`numpy.nanargmin`.
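# --- illustrative aside (not part of the patch) ---------------------------
# The Note above documents that the nan-aware arg-reductions return -1 for
# an all-NaN slice instead of raising, unlike NumPy. A minimal sketch of
# that behavior, assuming the patched docstrings describe it accurately:
import jax.numpy as jnp

x = jnp.array([[jnp.nan, jnp.nan],
               [1.0, 2.0]])
print(jnp.nanargmax(x, axis=1))  # expected: [-1  1]  (no error raised)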
+ + Args: + a: input array + axis: optional integer specifying the axis along which to find the maximum + value. If ``axis`` is not specified, ``a`` will be flattened. + out: unused by JAX + keepdims: if True, then return an array with the same number of dimensions + as ``a``. + + Returns: + an array containing the index of the minimum value along the specified axis. + + Note: + In the case of an axis with all-NaN values, the returned index will be -1. + This differs from the behavior of :func:`numpy.nanargmin`, which raises an error. + + See also: + - :func:`jax.numpy.argmin`: return the index of the minimum value. + - :func:`jax.numpy.nanargmax`: compute ``argmax`` while ignoring NaN values. + + Examples: + >>> x = jnp.array([jnp.nan, 3, 5, 4, 2]) + >>> jnp.nanargmin(x) + Array(4, dtype=int32) + + >>> x = jnp.array([[1, 3, jnp.nan], + ... [5, 4, jnp.nan]]) + >>> jnp.nanargmin(x, axis=1) + Array([0, 1], dtype=int32) + + >>> jnp.nanargmin(x, axis=1, keepdims=True) + Array([[0], + [1]], dtype=int32) + """ if out is not None: raise NotImplementedError("The 'out' argument to jnp.nanargmin is not supported.") return _nanargmin(a, None if axis is None else operator.index(axis), keepdims=bool(keepdims)) @@ -7073,11 +9258,40 @@ def sort( return lax.rev(result, dimensions=[dimension]) if descending else result -@util.implements(np.sort_complex) @jit def sort_complex(a: ArrayLike) -> Array: + """Return a sorted copy of complex array. + + JAX implementation of :func:`numpy.sort_complex`. + + Complex numbers are sorted lexicographically, meaning by their real part + first, and then by their imaginary part if real parts are equal. + + Args: + a: input array. If dtype is not complex, the array will be upcast to complex. + + Returns: + A sorted array of the same shape and complex dtype as the input. If ``a`` + is multi-dimensional, it is sorted along the last axis. + + See also: + - :func:`jax.numpy.sort`: Return a sorted copy of an array. + + Examples: + >>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j]) + >>> jnp.sort_complex(a) + Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64) + + Multi-dimensional arrays are sorted along the last axis: + + >>> a = jnp.array([[5, 3, 4], + ... [6, 9, 2]]) + >>> jnp.sort_complex(a) + Array([[3.+0.j, 4.+0.j, 5.+0.j], + [2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64) + """ util.check_arraylike("sort_complex", a) - a = lax.sort(asarray(a), dimension=0) + a = lax.sort(asarray(a)) return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype)) @util.implements(np.lexsort) @@ -7352,9 +9566,45 @@ def _roll_static(a: Array, shift: Sequence[int], axis: Sequence[int]) -> Array: dimension=ax) return a -@util.implements(np.roll) def roll(a: ArrayLike, shift: ArrayLike | Sequence[int], axis: int | Sequence[int] | None = None) -> Array: + """Roll the elements of an array along a specified axis. + + JAX implementation of :func:`numpy.roll`. + + Args: + a: input array. + shift: the number of positions to shift the specified axis. If an integer, + all axes are shifted by the same amount. If a tuple, the shift for each + axis is specified individually. + axis: the axis or axes to roll. If ``None``, the array is flattened, shifted, + and then reshaped to its original shape. + + Returns: + A copy of ``a`` with elements rolled along the specified axis or axes. + + See also: + - :func:`jax.numpy.rollaxis`: roll the specified axis to a given position. 
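# --- illustrative aside (not part of the patch) ---------------------------
# The `shift` argument described above may also be negative, which rolls
# elements toward the front; a quick sketch complementing the doctest
# examples that follow:
import jax.numpy as jnp

a = jnp.array([0, 1, 2, 3, 4, 5])
print(jnp.roll(a, -2))  # expected: [2 3 4 5 0 1]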
+ + Examples: + >>> a = jnp.array([0, 1, 2, 3, 4, 5]) + >>> jnp.roll(a, 2) + Array([4, 5, 0, 1, 2, 3], dtype=int32) + + Roll elements along a specific axis: + + >>> a = jnp.array([[ 0, 1, 2, 3], + ... [ 4, 5, 6, 7], + ... [ 8, 9, 10, 11]]) + >>> jnp.roll(a, 1, axis=0) + Array([[ 8, 9, 10, 11], + [ 0, 1, 2, 3], + [ 4, 5, 6, 7]], dtype=int32) + >>> jnp.roll(a, [2, 3], axis=[0, 1]) + Array([[ 5, 6, 7, 4], + [ 9, 10, 11, 8], + [ 1, 2, 3, 0]], dtype=int32) + """ util.check_arraylike("roll", a) arr = asarray(a) if axis is None: @@ -8290,11 +10540,11 @@ def _eliminate_deprecated_list_indexing(idx): if any(_should_unpack_list_index(i) for i in idx): msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[tuple(seq)]` instead of `arr[seq]`. " - "See https://github.com/google/jax/issues/4564 for more information.") + "See https://github.com/jax-ml/jax/issues/4564 for more information.") else: msg = ("Using a non-tuple sequence for multidimensional indexing is not allowed; " "use `arr[array(seq)]` instead of `arr[seq]`. " - "See https://github.com/google/jax/issues/4564 for more information.") + "See https://github.com/jax-ml/jax/issues/4564 for more information.") raise TypeError(msg) else: idx = (idx,) @@ -8340,7 +10590,10 @@ def _expand_bool_indices(idx, shape): i_shape = _shape(i) start = len(out) + ellipsis_offset - newaxis_offset expected_shape = shape[start: start + _ndim(i)] - if i_shape != expected_shape: + if len(i_shape) != len(expected_shape): + raise IndexError(f"too many boolean indices at index {dim_number}: got mask of shape " + f"{i_shape}, but only {len(expected_shape)} dimensions remain.") + if not all(s1 in (0, s2) for s1, s2 in zip(i_shape, expected_shape)): raise IndexError("boolean index did not match shape of indexed array in index " f"{dim_number}: got {i_shape}, expected {expected_shape}") out.extend(np.where(i)) @@ -8465,8 +10718,28 @@ def clamp_index(i: DimSize, which: str): return start, step, slice_size -@util.implements(np.blackman) def blackman(M: int) -> Array: + """Return a Blackman window of size M. + + JAX implementation of :func:`numpy.blackman`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Blackman window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.blackman(4)) + [-0. 0.63 0.63 -0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.blackman") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -8475,8 +10748,28 @@ def blackman(M: int) -> Array: return 0.42 - 0.5 * ufuncs.cos(2 * pi * n / (M - 1)) + 0.08 * ufuncs.cos(4 * pi * n / (M - 1)) -@util.implements(np.bartlett) def bartlett(M: int) -> Array: + """Return a Bartlett window of size M. + + JAX implementation of :func:`numpy.bartlett`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Bartlett window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.bartlett(4)) + [0. 0.67 0.67 0. ] + + See also: + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. 
+ - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.bartlett") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -8485,8 +10778,28 @@ def bartlett(M: int) -> Array: return 1 - ufuncs.abs(2 * n + 1 - M) / (M - 1) -@util.implements(np.hamming) def hamming(M: int) -> Array: + """Return a Hamming window of size M. + + JAX implementation of :func:`numpy.hamming`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hamming window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hamming(4)) + [0.08 0.77 0.77 0.08] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.hamming") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -8495,8 +10808,28 @@ def hamming(M: int) -> Array: return 0.54 - 0.46 * ufuncs.cos(2 * pi * n / (M - 1)) -@util.implements(np.hanning) def hanning(M: int) -> Array: + """Return a Hanning window of size M. + + JAX implementation of :func:`numpy.hanning`. + + Args: + M: The window size. + + Returns: + An array of size M containing the Hanning window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.hanning(4)) + [0. 0.75 0.75 0. ] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.kaiser`: return a Kaiser window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.hanning") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -8505,8 +10838,29 @@ def hanning(M: int) -> Array: return 0.5 * (1 - ufuncs.cos(2 * pi * n / (M - 1))) -@util.implements(np.kaiser) def kaiser(M: int, beta: ArrayLike) -> Array: + """Return a Kaiser window of size M. + + JAX implementation of :func:`numpy.kaiser`. + + Args: + M: The window size. + beta: The Kaiser window parameter. + + Returns: + An array of size M containing the Kaiser window. + + Examples: + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.kaiser(4, 1.5)) + [0.61 0.95 0.95 0.61] + + See also: + - :func:`jax.numpy.bartlett`: return a Bartlett window of size M. + - :func:`jax.numpy.blackman`: return a Blackman window of size M. + - :func:`jax.numpy.hamming`: return a Hamming window of size M. + - :func:`jax.numpy.hanning`: return a Hanning window of size M. + """ M = core.concrete_or_error(int, M, "M argument of jnp.kaiser") dtype = dtypes.canonicalize_dtype(float_) if M <= 1: @@ -8526,9 +10880,43 @@ def _gcd_body_fn(xs: tuple[Array, Array]) -> tuple[Array, Array]: where(x2 != 0, lax.rem(x1, x2), _lax_const(x2, 0))) return (where(x1 < x2, x2, x1), where(x1 < x2, x1, x2)) -@util.implements(np.gcd, module='numpy') @jit def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: + """Compute the greatest common divisor of two arrays. + + JAX implementation of :func:`numpy.gcd`. + + Args: + x1: First input array. The elements must have integer dtype. + x2: Second input array. The elements must have integer dtype. + + Returns: + An array containing the greatest common divisors of the corresponding + elements from the absolute values of `x1` and `x2`. 
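# --- illustrative aside (not part of the patch) ---------------------------
# `_gcd_body_fn` above performs one elementwise step of the Euclidean
# algorithm. A self-contained sketch of the same reduction, assuming the
# body is driven by `lax.while_loop`; `euclid_gcd` is a hypothetical name,
# not JAX API:
import jax.numpy as jnp
from jax import lax

def euclid_gcd(a, b):
  def cond(xs):
    return jnp.any(xs[1] != 0)
  def body(xs):
    x, y = xs
    # Map (x, y) -> (y, x mod y), leaving lanes with y == 0 untouched.
    return jnp.where(y != 0, y, x), jnp.where(y != 0, lax.rem(x, y), 0)
  x, _ = lax.while_loop(cond, body, (jnp.abs(a), jnp.abs(b)))
  return x

print(euclid_gcd(jnp.array([12, 18, 24]), jnp.array([5, 10, 15])))
# expected: [1 2 3], matching the jnp.gcd doctest below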
+ + See also: + - :func:`jax.numpy.lcm`: compute the least common multiple of two arrays. + + Examples: + Scalar inputs: + + >>> jnp.gcd(12, 18) + Array(6, dtype=int32, weak_type=True) + + Array inputs: + + >>> x1 = jnp.array([12, 18, 24]) + >>> x2 = jnp.array([5, 10, 15]) + >>> jnp.gcd(x1, x2) + Array([1, 2, 3], dtype=int32) + + Broadcasting: + + >>> x1 = jnp.array([12]) + >>> x2 = jnp.array([6, 9, 12]) + >>> jnp.gcd(x1, x2) + Array([ 6, 3, 12], dtype=int32) + """ util.check_arraylike("gcd", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) if not issubdtype(_dtype(x1), integer): @@ -8538,9 +10926,43 @@ def gcd(x1: ArrayLike, x2: ArrayLike) -> Array: return gcd -@util.implements(np.lcm, module='numpy') @jit def lcm(x1: ArrayLike, x2: ArrayLike) -> Array: + """Compute the least common multiple of two arrays. + + JAX implementation of :func:`numpy.lcm`. + + Args: + x1: First input array. The elements must have integer dtype. + x2: Second input array. The elements must have integer dtype. + + Returns: + An array containing the least common multiple of the corresponding + elements from the absolute values of `x1` and `x2`. + + See also: + - :func:`jax.numpy.gcd`: compute the greatest common divisor of two arrays. + + Examples: + Scalar inputs: + + >>> jnp.lcm(12, 18) + Array(36, dtype=int32, weak_type=True) + + Array inputs: + + >>> x1 = jnp.array([12, 18, 24]) + >>> x2 = jnp.array([5, 10, 15]) + >>> jnp.lcm(x1, x2) + Array([ 60, 90, 120], dtype=int32) + + Broadcasting: + + >>> x1 = jnp.array([12]) + >>> x2 = jnp.array([6, 9, 12]) + >>> jnp.lcm(x1, x2) + Array([12, 36, 12], dtype=int32) + """ util.check_arraylike("lcm", x1, x2) x1, x2 = util.promote_dtypes(x1, x2) x1, x2 = ufuncs.abs(x1), ufuncs.abs(x2) @@ -8813,7 +11235,7 @@ def body_fun(state, _): def _searchsorted_via_sort(sorted_arr: Array, query: Array, side: str, dtype: type) -> Array: working_dtype = int32 if sorted_arr.size + query.size < np.iinfo(np.int32).max else int64 def _rank(x): - idx = lax.iota(working_dtype, len(x)) + idx = lax.iota(working_dtype, x.shape[0]) return zeros_like(idx).at[argsort(x)].set(idx) query_flat = query.ravel() if side == 'left': @@ -8906,8 +11328,8 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', a, v = util.promote_dtypes(a, v) if sorter is not None: a = a[sorter] - dtype = int32 if len(a) <= np.iinfo(np.int32).max else int64 - if len(a) == 0: + dtype = int32 if a.shape[0] <= np.iinfo(np.int32).max else int64 + if a.shape[0] == 0: return zeros_like(v, dtype=dtype) impl = { 'scan': partial(_searchsorted_via_scan, False), @@ -8917,9 +11339,46 @@ def searchsorted(a: ArrayLike, v: ArrayLike, side: str = 'left', }[method] return impl(asarray(a), asarray(v), side, dtype) # type: ignore -@util.implements(np.digitize) -@partial(jit, static_argnames=('right',)) -def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: + +@partial(jit, static_argnames=('right', 'method')) +def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False, + *, method: str | None = None) -> Array: + """Convert an array to bin indices. + + JAX implementation of :func:`numpy.digitize`. + + Args: + x: array of values to digitize. + bins: 1D array of bin edges. Must be monotonically increasing or decreasing. + right: if true, the intervals include the right bin edges. If false (default) + the intervals include the left bin edges. + method: optional method argument to be passed to :func:`~jax.numpy.searchsorted`. + See that function for available options. 
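# --- illustrative aside (not part of the patch) ---------------------------
# The new `method` keyword above is simply forwarded to jnp.searchsorted,
# so any lowering accepted there (e.g. 'sort', used here as an assumption)
# should produce the same bin indices as the default:
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5])
bins = jnp.array([1, 2, 3])
print(jnp.digitize(x, bins))                 # expected: [1 2 2 1 3 3]
print(jnp.digitize(x, bins, method="sort"))  # expected: same result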
+ + Returns: + An integer array of the same shape as ``x`` indicating the bin number that + the values are in. + + See also: + - :func:`jax.numpy.searchsorted`: find insertion indices for values in a + sorted array. + - :func:`jax.numpy.histogram`: compute frequency of array values within + specified bins. + + Examples: + >>> x = jnp.array([1.0, 2.0, 2.5, 1.5, 3.0, 3.5]) + >>> bins = jnp.array([1, 2, 3]) + >>> jnp.digitize(x, bins) + Array([1, 2, 2, 1, 3, 3], dtype=int32) + >>> jnp.digitize(x, bins, right=True) + Array([0, 1, 2, 1, 2, 3], dtype=int32) + + ``digitize`` supports reverse-ordered bins as well: + + >>> bins = jnp.array([3, 2, 1]) + >>> jnp.digitize(x, bins) + Array([2, 1, 1, 2, 0, 0], dtype=int32) + """ util.check_arraylike("digitize", x, bins) right = core.concrete_or_error(bool, right, "right argument of jnp.digitize()") bins_arr = asarray(bins) @@ -8928,22 +11387,83 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False) -> Array: if bins_arr.shape[0] == 0: return zeros_like(x, dtype=int32) side = 'right' if not right else 'left' + kwds: dict[str, str] = {} if method is None else {'method': method} return where( bins_arr[-1] >= bins_arr[0], - searchsorted(bins_arr, x, side=side), - len(bins_arr) - searchsorted(bins_arr[::-1], x, side=side) + searchsorted(bins_arr, x, side=side, **kwds), + bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) ) -_PIECEWISE_DOC = """\ -Unlike `np.piecewise`, :py:func:`jax.numpy.piecewise` requires functions in -`funclist` to be traceable by JAX, as it is implemented via :func:`jax.lax.switch`. -See the :func:`jax.lax.switch` documentation for more information. -""" -@util.implements(np.piecewise, lax_description=_PIECEWISE_DOC) def piecewise(x: ArrayLike, condlist: Array | Sequence[ArrayLike], funclist: list[ArrayLike | Callable[..., Array]], *args, **kw) -> Array: + """Evaluate a function defined piecewise across the domain. + + JAX implementation of :func:`numpy.piecewise`, in terms of :func:`jax.lax.switch`. + + Note: + Unlike :func:`numpy.piecewise`, :func:`jax.numpy.piecewise` requires functions + in ``funclist`` to be traceable by JAX, as it is implemented via + :func:`jax.lax.switch`. + + Args: + x: array of input values. + condlist: boolean array or sequence of boolean arrays corresponding to the + functions in ``funclist``. If a sequence of arrays, the length of each + array must match the length of ``x`` + funclist: list of arrays or functions; must either be the same length as + ``condlist``, or have length ``len(condlist) + 1``, in which case the + last entry is the default applied when none of the conditions are True. + Alternatively, entries of ``funclist`` may be numerical values, in which + case they indicate a constant function. + args, kwargs: additional arguments are passed to each function in + ``funclist``. + + Returns: + An array which is the result of evaluating the functions on ``x`` at + the specified conditions. + + See also: + - :func:`jax.lax.switch`: choose between *N* functions based on an index. + - :func:`jax.lax.cond`: choose between two functions based on a boolean condition. + - :func:`jax.numpy.where`: choose between two results based on a boolean mask. + - :func:`jax.lax.select`: choose between two results based on a boolean mask. + - :func:`jax.lax.select_n`: choose between *N* results based on a boolean mask. 
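# --- illustrative aside (not part of the patch) ---------------------------
# jnp.piecewise is documented above as being built on jax.lax.switch, which
# selects one of N traceable branches by index. A minimal sketch of
# lax.switch itself:
import jax.numpy as jnp
from jax import lax

branches = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100]
print(lax.switch(1, branches, jnp.arange(3)))  # expected: [ 0 10 20]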
+ + Examples: + Here's an example of a function which is zero for negative values, and linear + for positive values: + + >>> x = jnp.array([-4, -3, -2, -1, 0, 1, 2, 3, 4]) + + >>> condlist = [x < 0, x >= 0] + >>> funclist = [lambda x: 0 * x, lambda x: x] + >>> jnp.piecewise(x, condlist, funclist) + Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32) + + ``funclist`` can also contain a simple scalar value for constant functions: + + >>> condlist = [x < 0, x >= 0] + >>> funclist = [0, lambda x: x] + >>> jnp.piecewise(x, condlist, funclist) + Array([0, 0, 0, 0, 0, 1, 2, 3, 4], dtype=int32) + + You can specify a default value by appending an extra condition to ``funclist``: + + >>> condlist = [x < -1, x > 1] + >>> funclist = [lambda x: 1 + x, lambda x: x - 1, 0] + >>> jnp.piecewise(x, condlist, funclist) + Array([-3, -2, -1, 0, 0, 0, 1, 2, 3], dtype=int32) + + ``condlist`` may also be a simple array of scalar conditions, in which case + the associated function applies to the whole range + + >>> condlist = jnp.array([False, True, False]) + >>> funclist = [lambda x: x * 0, lambda x: x * 10, lambda x: x * 100] + >>> jnp.piecewise(x, condlist, funclist) + Array([-40, -30, -20, -10, 0, 10, 20, 30, 40], dtype=int32) + """ util.check_arraylike("piecewise", x) nc, nf = len(condlist), len(funclist) if nf == nc + 1: diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index 729dc81adb90..79b47d9090af 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -28,6 +28,7 @@ from jax import jit, custom_jvp from jax import lax +from jax._src import deprecations from jax._src.lax import lax as lax_internal from jax._src.lax.lax import PrecisionLike from jax._src.lax import linalg as lax_linalg @@ -408,6 +409,8 @@ def matrix_rank( smaller than `rtol * largest_singular_value` are considered to be zero. If ``rtol`` is None (the default), a reasonable default is chosen based the floating point precision of the input. + tol: deprecated alias of the ``rtol`` argument. Will result in a + :class:`DeprecationWarning` if used. Returns: array of shape ``a.shape[-2]`` giving the matrix rank. @@ -433,11 +436,11 @@ def matrix_rank( if not isinstance(tol, DeprecatedArg): rtol = tol del tol - warnings.warn( - "The tol argument for linalg.matrix_rank is deprecated using it will soon raise " - "an error. To prepare for future releases, and suppress this warning, " - "please use rtol instead.", - DeprecationWarning, stacklevel=2 + deprecations.warn( + "jax-numpy-linalg-matrix_rank-tol", + ("The tol argument for linalg.matrix_rank is deprecated. " + "Please use rtol instead."), + stacklevel=2 ) M, = promote_dtypes_inexact(jnp.asarray(M)) if M.ndim < 2: @@ -498,7 +501,7 @@ def slogdet(a: ArrayLike, *, method: str | None = None) -> SlogdetResult: """ Compute the sign and (natural) logarithm of the determinant of an array. - JAX implementation of :func:`numpy.linalg.slotdet`. + JAX implementation of :func:`numpy.linalg.slogdet`. Args: a: array of shape ``(..., M, M)`` for which to compute the sign and log determinant. 
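# --- illustrative aside (not part of the patch) ---------------------------
# slogdet returns (sign, log|det|), so the determinant can be recovered as
# sign * exp(logabsdet); the `_det` helper added in the next hunk is written
# exactly this way, and its custom JVP (via `_cofactor_solve`) should agree
# with the standard identity grad(det)(A) == det(A) * inv(A).T for
# invertible A. A small numerical check, as a sketch:
import jax
import jax.numpy as jnp

a = jnp.array([[1.0, 2.0],
               [3.0, 4.0]])
sign, logabsdet = jnp.linalg.slogdet(a)
print(sign * jnp.exp(logabsdet), jnp.linalg.det(a))  # both approximately -2.0
print(jnp.allclose(jax.grad(jnp.linalg.det)(a),
                   jnp.linalg.det(a) * jnp.linalg.inv(a).T))  # expected: True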
@@ -659,6 +662,19 @@ def _det_3x3(a: Array) -> Array: @custom_jvp +def _det(a): + sign, logdet = slogdet(a) + return sign * ufuncs.exp(logdet).astype(sign.dtype) + + +@_det.defjvp +def _det_jvp(primals, tangents): + x, = primals + g, = tangents + y, z = _cofactor_solve(x, g) + return y, jnp.trace(z, axis1=-1, axis2=-2) + + @jit def det(a: ArrayLike) -> Array: """ @@ -689,21 +705,12 @@ def det(a: ArrayLike) -> Array: elif len(a_shape) >= 2 and a_shape[-1] == 3 and a_shape[-2] == 3: return _det_3x3(a) elif len(a_shape) >= 2 and a_shape[-1] == a_shape[-2]: - sign, logdet = slogdet(a) - return sign * ufuncs.exp(logdet).astype(sign.dtype) + return _det(a) else: msg = "Argument to _det() must have shape [..., n, n], got {}" raise ValueError(msg.format(a_shape)) -@det.defjvp -def _det_jvp(primals, tangents): - x, = primals - g, = tangents - y, z = _cofactor_solve(x, g) - return y, jnp.trace(z, axis1=-1, axis2=-2) - - def eig(a: ArrayLike) -> tuple[Array, Array]: """ Compute the eigenvalues and eigenvectors of a square array. @@ -891,6 +898,8 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, determined based on the floating point precision of the dtype. hermitian: if True, then the input is assumed to be Hermitian, and a more efficient algorithm is used (default: False) + rcond: deprecated alias of the ``rtol`` argument. Will result in a + :class:`DeprecationWarning` if used. Returns: An array of shape ``(..., N, M)`` containing the pseudo-inverse of ``a``. @@ -921,11 +930,11 @@ def pinv(a: ArrayLike, rtol: ArrayLike | None = None, if not isinstance(rcond, DeprecatedArg): rtol = rcond del rcond - warnings.warn( - "The rcond argument for linalg.pinv is deprecated using it will soon " - "raise an error. To prepare for future releases, and suppress this " - "warning, please use rtol instead.", - DeprecationWarning, stacklevel=2 + deprecations.warn( + "jax-numpy-linalg-pinv-rcond", + ("The rcond argument for linalg.pinv is deprecated. " + "Please use rtol instead."), + stacklevel=2 ) return _pinv(a, rtol, hermitian) diff --git a/jax/_src/numpy/polynomial.py b/jax/_src/numpy/polynomial.py index ca4e3ebaf6a2..cce8bb8e6f7f 100644 --- a/jax/_src/numpy/polynomial.py +++ b/jax/_src/numpy/polynomial.py @@ -317,7 +317,7 @@ def poly(seq_of_zeros: ArrayLike) -> Array: - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given coefficients. - Example: + Examples: Scalar inputs: @@ -407,7 +407,7 @@ def polyval(p: ArrayLike, x: ArrayLike, *, unroll: int = 16) -> Array: - :func:`jax.numpy.roots`: Computes the roots of a polynomial for given coefficients. - Example: + Examples: >>> p = jnp.array([2, 5, 1]) >>> jnp.polyval(p, 3) Array(34., dtype=float32) @@ -455,7 +455,7 @@ def polyadd(a1: ArrayLike, a2: ArrayLike) -> Array: - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = jnp.array([2, 3]) >>> x2 = jnp.array([5, 4, 1]) >>> jnp.polyadd(x1, x2) @@ -637,7 +637,7 @@ def polymul(a1: ArrayLike, a2: ArrayLike, *, trim_leading_zeros: bool = False) - - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = np.array([2, 1, 0]) >>> x2 = np.array([0, 5, 0, 3]) >>> np.polymul(x1, x2) @@ -702,7 +702,7 @@ def polydiv(u: ArrayLike, v: ArrayLike, *, trim_leading_zeros: bool = False) -> - :func:`jax.numpy.polysub`: Computes the difference of two polynomials. - :func:`jax.numpy.polymul`: Computes the product of two polynomials. 
- Example: + Examples: >>> x1 = jnp.array([5, 7, 9]) >>> x2 = jnp.array([4, 1]) >>> np.polydiv(x1, x2) @@ -755,7 +755,7 @@ def polysub(a1: ArrayLike, a2: ArrayLike) -> Array: - :func:`jax.numpy.polydiv`: Computes the quotient and remainder of polynomial division. - Example: + Examples: >>> x1 = jnp.array([2, 3]) >>> x2 = jnp.array([5, 4, 1]) >>> jnp.polysub(x1, x2) diff --git a/jax/_src/numpy/reductions.py b/jax/_src/numpy/reductions.py index fa8899325879..3436b00cfce1 100644 --- a/jax/_src/numpy/reductions.py +++ b/jax/_src/numpy/reductions.py @@ -28,8 +28,8 @@ from jax import lax from jax._src import api from jax._src import core +from jax._src import deprecations from jax._src import dtypes -from jax._src.numpy import ufuncs from jax._src.numpy.util import ( _broadcast_to, check_arraylike, _complex_elem_type, promote_dtypes_inexact, promote_dtypes_numeric, _where, implements) @@ -765,7 +765,8 @@ def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, else: normalizer = core.dimension_as_value(_axis_size(a, axis)) else: - normalizer = sum(_broadcast_to(where, np.shape(a)), axis, dtype=dtype, keepdims=keepdims) + normalizer = sum(_broadcast_to(where, np.shape(a)), axis, + dtype=computation_dtype, keepdims=keepdims) return lax.div( sum(a, axis, dtype=computation_dtype, keepdims=keepdims, where=where), @@ -967,7 +968,7 @@ def _var_promote_types(a_dtype: DTypeLike, dtype: DTypeLike | None) -> tuple[DTy msg = ("jax.numpy.var does not yet support real dtype parameters when " "computing the variance of an array of complex values. The " "semantics of numpy.var seem unclear in this case. Please comment " - "on https://github.com/google/jax/issues/2283 if this behavior is " + "on https://github.com/jax-ml/jax/issues/2283 if this behavior is " "important to you.") raise ValueError(msg) computation_dtype = dtype @@ -1786,73 +1787,257 @@ def __call__(self, a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None) -> Array: ... -# TODO(jakevdp): should we change these semantics to match those of numpy? -CUML_REDUCTION_LAX_DESCRIPTION = """ -Unlike the numpy counterpart, when ``dtype`` is not specified the output dtype will always -match the dtype of the input. 
-""" - -def _make_cumulative_reduction(np_reduction: Any, reduction: Callable[..., Array], - fill_nan: bool = False, fill_value: ArrayLike = 0, - promote_integers: bool = False) -> CumulativeReduction: - @implements(np_reduction, skip_params=['out'], - lax_description=CUML_REDUCTION_LAX_DESCRIPTION) - def cumulative_reduction(a: ArrayLike, axis: Axis = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: - return _cumulative_reduction(a, _ensure_optional_axes(axis), dtype, out) - - @partial(api.jit, static_argnames=('axis', 'dtype')) - def _cumulative_reduction(a: ArrayLike, axis: Axis = None, - dtype: DTypeLike | None = None, out: None = None) -> Array: - check_arraylike(np_reduction.__name__, a) - if out is not None: - raise NotImplementedError(f"The 'out' argument to jnp.{np_reduction.__name__} " - f"is not supported.") - dtypes.check_user_dtype_supported(dtype, np_reduction.__name__) - - if axis is None or _isscalar(a): - a = lax.reshape(a, (np.size(a),)) - if axis is None: - axis = 0 +def _cumulative_reduction( + name: str, reduction: Callable[..., Array], + a: ArrayLike, axis: int | None, dtype: DTypeLike | None, out: None, + fill_nan: bool = False, fill_value: ArrayLike = 0, + promote_integers: bool = False) -> Array: + """Helper function for implementing cumulative reductions.""" + check_arraylike(name, a) + if out is not None: + raise NotImplementedError(f"The 'out' argument to jnp.{name} is not supported") + dtypes.check_user_dtype_supported(dtype, name) + + if axis is None or _isscalar(a): + a = lax.reshape(a, (np.size(a),)) + if axis is None: + axis = 0 + + a_shape = list(np.shape(a)) + num_dims = len(a_shape) + axis = _canonicalize_axis(axis, num_dims) + + if fill_nan: + a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + + a_type: DType = dtypes.dtype(a) + result_type: DTypeLike = dtypes.dtype(dtype or a) + if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): + result_type = _promote_integer_dtype(result_type) + result_type = dtypes.canonicalize_dtype(result_type) + + if a_type != np.bool_ and dtype == np.bool_: + a = lax_internal.asarray(a).astype(np.bool_) + + a = lax.convert_element_type(a, result_type) + result = reduction(a, axis) + + # We downcast to boolean because we accumulate in integer types + if dtype is not None and dtypes.issubdtype(dtype, np.bool_): + result = lax.convert_element_type(result, np.bool_) + return result + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def cumsum(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative sum of elements along an axis. + + JAX implementation of :func:`numpy.cumsum`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated sum along the given axis. + + See also: + - :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard. + - :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods. + - :func:`jax.numpy.nancumsum`: cumulative sum ignoring NaN values. + - :func:`jax.numpy.sum`: sum along axis + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... 
[4, 5, 6]]) + >>> jnp.cumsum(x) # flattened cumulative sum + Array([ 1, 3, 6, 10, 15, 21], dtype=int32) + >>> jnp.cumsum(x, axis=1) # cumulative sum along axis 1 + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + """ + return _cumulative_reduction("cumsum", lax.cumsum, a, axis, dtype, out) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def cumprod(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative product of elements along an axis. + + JAX implementation of :func:`numpy.cumprod`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated product along the given axis. + + See also: + - :meth:`jax.numpy.multiply.accumulate`: cumulative product via ufunc methods. + - :func:`jax.numpy.nancumprod`: cumulative product ignoring NaN values. + - :func:`jax.numpy.prod`: product along axis + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumprod(x) # flattened cumulative product + Array([ 1, 2, 6, 24, 120, 720], dtype=int32) + >>> jnp.cumprod(x, axis=1) # cumulative product along axis 1 + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + """ + return _cumulative_reduction("cumprod", lax.cumprod, a, axis, dtype, out) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def nancumsum(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative sum of elements along an axis, ignoring NaN values. + + JAX implementation of :func:`numpy.nancumsum`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated sum along the given axis. + + See also: + - :func:`jax.numpy.cumsum`: cumulative sum without ignoring NaN values. + - :func:`jax.numpy.cumulative_sum`: cumulative sum via the array API standard. + - :meth:`jax.numpy.add.accumulate`: cumulative sum via ufunc methods. + - :func:`jax.numpy.sum`: sum along axis + + Examples: + >>> x = jnp.array([[1., 2., jnp.nan], + ... 
[4., jnp.nan, 6.]]) - a_shape = list(np.shape(a)) - num_dims = len(a_shape) - axis = _canonicalize_axis(axis, num_dims) + The standard cumulative sum will propagate NaN values: - if fill_nan: - a = _where(lax_internal._isnan(a), _lax_const(a, fill_value), a) + >>> jnp.cumsum(x) + Array([ 1., 3., nan, nan, nan, nan], dtype=float32) - result_type: DTypeLike = dtypes.dtype(dtype or a) - if dtype is None and promote_integers or dtypes.issubdtype(result_type, np.bool_): - result_type = _promote_integer_dtype(result_type) - result_type = dtypes.canonicalize_dtype(result_type) + :func:`~jax.numpy.nancumsum` will ignore NaN values, effectively replacing + them with zeros: - a = lax.convert_element_type(a, result_type) - result = reduction(a, axis) + >>> jnp.nancumsum(x) + Array([ 1., 3., 3., 7., 7., 13.], dtype=float32) - # We downcast to boolean because we accumulate in integer types - if dtypes.issubdtype(dtype, np.bool_): - result = lax.convert_element_type(result, np.bool_) - return result + Cumulative sum along axis 1: - return cumulative_reduction + >>> jnp.nancumsum(x, axis=1) + Array([[ 1., 3., 3.], + [ 4., 4., 10.]], dtype=float32) + """ + return _cumulative_reduction("nancumsum", lax.cumsum, a, axis, dtype, out, + fill_nan=True, fill_value=0) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def nancumprod(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Cumulative product of elements along an axis, ignoring NaN values. + JAX implementation of :func:`numpy.nancumprod`. + + Args: + a: N-dimensional array to be accumulated. + axis: integer axis along which to accumulate. If None (default), then + array will be flattened and accumulated along the flattened axis. + dtype: optionally specify the dtype of the output. If not specified, + then the output dtype will match the input dtype. + out: unused by JAX + + Returns: + An array containing the accumulated product along the given axis. + + See also: + - :func:`jax.numpy.cumprod`: cumulative product without ignoring NaN values. + - :meth:`jax.numpy.multiply.accumulate`: cumulative product via ufunc methods. + - :func:`jax.numpy.prod`: product along axis + + Examples: + >>> x = jnp.array([[1., 2., jnp.nan], + ... 
[4., jnp.nan, 6.]]) + + The standard cumulative product will propagate NaN values: + + >>> jnp.cumprod(x) + Array([ 1., 2., nan, nan, nan, nan], dtype=float32) + + :func:`~jax.numpy.nancumprod` will ignore NaN values, effectively replacing + them with ones: + + >>> jnp.nancumprod(x) + Array([ 1., 2., 2., 8., 8., 48.], dtype=float32) + + Cumulative product along axis 1: + + >>> jnp.nancumprod(x, axis=1) + Array([[ 1., 2., 2.], + [ 4., 4., 24.]], dtype=float32) + """ + return _cumulative_reduction("nancumprod", lax.cumprod, a, axis, dtype, out, + fill_nan=True, fill_value=1) + + +@partial(api.jit, static_argnames=('axis', 'dtype')) +def _cumsum_with_promotion(a: ArrayLike, axis: int | None = None, + dtype: DTypeLike | None = None, out: None = None) -> Array: + """Utility function to compute cumsum with integer promotion.""" + return _cumulative_reduction("_cumsum_with_promotion", lax.cumsum, + a, axis, dtype, out, promote_integers=True) -cumsum = _make_cumulative_reduction(np.cumsum, lax.cumsum, fill_nan=False) -cumprod = _make_cumulative_reduction(np.cumprod, lax.cumprod, fill_nan=False) -nancumsum = _make_cumulative_reduction(np.nancumsum, lax.cumsum, - fill_nan=True, fill_value=0) -nancumprod = _make_cumulative_reduction(np.nancumprod, lax.cumprod, - fill_nan=True, fill_value=1) -_cumsum_with_promotion = _make_cumulative_reduction( - np.cumsum, lax.cumsum, fill_nan=False, promote_integers=True -) -@implements(getattr(np, 'cumulative_sum', None)) def cumulative_sum( x: ArrayLike, /, *, axis: int | None = None, dtype: DTypeLike | None = None, include_initial: bool = False) -> Array: + """Cumulative sum along the axis of an array. + + JAX implementation of :func:`numpy.cumulative_sum`. + + Args: + x: N-dimensional array + axis: integer axis along which to accumulate. If ``x`` is one-dimensional, + this argument is optional. + dtype: optional dtype of the output. + include_initial: if True, then include the initial value in the cumulative + sum. Default is False. + + Returns: + An array containing the accumulated values. + + See Also: + - :func:`jax.numpy.cumsum`: alternative API for cumulative sum. + - :func:`jax.numpy.nancumsum`: cumulative sum while ignoring NaN values. + - :func:`jax.numpy.add.accumulate`: cumulative sum via the ufunc API. + + Examples: + >>> x = jnp.array([[1, 2, 3], + ... [4, 5, 6]]) + >>> jnp.cumulative_sum(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + >>> jnp.cumulative_sum(x, axis=1, include_initial=True) + Array([[ 0, 1, 3, 6], + [ 0, 4, 9, 15]], dtype=int32) + """ check_arraylike("cumulative_sum", x) x = lax_internal.asarray(x) if x.ndim == 0: @@ -1882,36 +2067,116 @@ def cumulative_sum( # Quantiles # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.quantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def quantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + """Compute the quantile of the data along the specified axis. + + JAX implementation of :func:`numpy.quantile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain floating-point values between ``0.0`` and ``1.0``. 
+ axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified quantiles along the specified axes. + + See also: + - :func:`jax.numpy.nanquantile`: compute the quantile while ignoring NaNs + - :func:`jax.numpy.percentile`: compute the percentile (0-100) + + Examples: + Computing the median and quartiles of an array, with linear interpolation: + + >>> x = jnp.arange(10) + >>> q = jnp.array([0.25, 0.5, 0.75]) + >>> jnp.quantile(x, q) + Array([2.25, 4.5 , 6.75], dtype=float32) + + Computing the quartiles using nearest-value interpolation: + + >>> jnp.quantile(x, q, method='nearest') + Array([2., 4., 7.], dtype=float32) + """ check_arraylike("quantile", a, q) if overwrite_input or out is not None: - msg = ("jax.numpy.quantile does not support overwrite_input=True or " - "out != None") - raise ValueError(msg) + raise ValueError("jax.numpy.quantile does not support overwrite_input=True " + "or out != None") if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'quantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'quantile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, False) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.nanquantile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanquantile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: DeprecatedArg | str = DeprecatedArg()) -> Array: + """Compute the quantile of the data along the specified axis, ignoring NaNs. + + JAX implementation of :func:`numpy.nanquantile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain floating-point values between ``0.0`` and ``1.0``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified quantiles along the specified axes. 
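# --- illustrative aside (not part of the patch) ---------------------------
# The five `method` options listed above differ only in how the fractional
# index q * (n - 1) is resolved. A quick comparison on jnp.arange(10.0),
# where q = 0.25 lands at fractional index 2.25:
import jax.numpy as jnp

x = jnp.arange(10.0)
for m in ("linear", "lower", "higher", "midpoint", "nearest"):
  print(m, jnp.quantile(x, 0.25, method=m))
# expected: linear 2.25, lower 2.0, higher 3.0, midpoint 2.5, nearest 2.0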
+ + See also: + - :func:`jax.numpy.quantile`: compute the quantile without ignoring nans + - :func:`jax.numpy.nanpercentile`: compute the percentile (0-100) + + Examples: + Computing the median and quartiles of a 1D array: + + >>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6]) + >>> q = jnp.array([0.25, 0.5, 0.75]) + + Because of the NaN value, :func:`jax.numpy.quantile` returns all NaNs, + while :func:`~jax.numpy.nanquantile` ignores them: + + >>> jnp.quantile(x, q) + Array([nan, nan, nan], dtype=float32) + >>> jnp.nanquantile(x, q) + Array([1.5, 3. , 4.5], dtype=float32) + """ check_arraylike("nanquantile", a, q) if overwrite_input or out is not None: msg = ("jax.numpy.nanquantile does not support overwrite_input=True or " "out != None") raise ValueError(msg) if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'nanquantile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'nanquantile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return _quantile(lax_internal.asarray(a), lax_internal.asarray(q), axis, method, keepdims, True) @@ -1957,9 +2222,9 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, a_shape = a.shape if squash_nans: - a = _where(ufuncs.isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. + a = _where(lax_internal._isnan(a), np.nan, a) # Ensure nans are positive so they sort to the end. a = lax.sort(a, dimension=axis) - counts = sum(ufuncs.logical_not(ufuncs.isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) + counts = sum(lax_internal.bitwise_not(lax_internal._isnan(a)), axis=axis, dtype=q.dtype, keepdims=keepdims) shape_after_reduction = counts.shape q = lax.expand_dims( q, tuple(range(q_ndim, len(shape_after_reduction) + q_ndim))) @@ -1985,7 +2250,7 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, index[axis] = high high_value = a[tuple(index)] else: - a = _where(any(ufuncs.isnan(a), axis=axis, keepdims=True), np.nan, a) + a = _where(any(lax_internal._isnan(a), axis=axis, keepdims=True), np.nan, a) a = lax.sort(a, dimension=axis) n = lax.convert_element_type(a_shape[axis], lax_internal._dtype(q)) q = lax.mul(q, n - 1) @@ -2038,33 +2303,116 @@ def _quantile(a: Array, q: Array, axis: int | tuple[int, ...] | None, return lax.convert_element_type(result, a.dtype) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.percentile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def percentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + """Compute the percentile of the data along the specified axis. + + JAX implementation of :func:`numpy.percentile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain integer or floating point values between ``0`` and ``100``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. 
Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified percentiles along the specified axes. + + See also: + - :func:`jax.numpy.quantile`: compute the quantile (0.0-1.0) + - :func:`jax.numpy.nanpercentile`: compute the percentile while ignoring NaNs + + Examples: + Computing the median and quartiles of a 1D array: + + >>> x = jnp.array([0, 1, 2, 3, 4, 5, 6]) + >>> q = jnp.array([25, 50, 75]) + >>> jnp.percentile(x, q) + Array([1.5, 3. , 4.5], dtype=float32) + + Computing the same percentiles with nearest rather than linear interpolation: + + >>> jnp.percentile(x, q, method='nearest') + Array([1., 3., 4.], dtype=float32) + """ check_arraylike("percentile", a, q) q, = promote_dtypes_inexact(q) if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'percentile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'percentile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return quantile(a, q / 100, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) # TODO(jakevdp): interpolation argument deprecated 2024-05-16 -@implements(np.nanpercentile, skip_params=['out', 'overwrite_input']) @partial(api.jit, static_argnames=('axis', 'overwrite_input', 'interpolation', 'keepdims', 'method')) def nanpercentile(a: ArrayLike, q: ArrayLike, axis: int | tuple[int, ...] | None = None, out: None = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, *, interpolation: str | DeprecatedArg = DeprecatedArg()) -> Array: + """Compute the percentile of the data along the specified axis, ignoring NaN values. + + JAX implementation of :func:`numpy.nanpercentile`. + + Args: + a: N-dimensional array input. + q: scalar or 1-dimensional array specifying the desired quantiles. ``q`` + should contain integer or floating point values between ``0`` and ``100``. + axis: optional axis or tuple of axes along which to compute the quantile + out: not implemented by JAX; will error if not None + overwrite_input: not implemented by JAX; will error if not False + method: specify the interpolation method to use. Options are one of + ``["linear", "lower", "higher", "midpoint", "nearest"]``. + default is ``linear``. + keepdims: if True, then the returned array will have the same number of + dimensions as the input. Default is False. + interpolation: deprecated alias of the ``method`` argument. Will result + in a :class:`DeprecationWarning` if used. + + Returns: + An array containing the specified percentiles along the specified axes. + + See also: + - :func:`jax.numpy.nanquantile`: compute the nan-aware quantile (0.0-1.0) + - :func:`jax.numpy.percentile`: compute the percentile without special + handling of NaNs. 
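# --- illustrative aside (not part of the patch) ---------------------------
# As the hunks above show, percentile/nanpercentile simply rescale q by 100
# and defer to quantile/nanquantile, so the two APIs agree:
import jax.numpy as jnp

x = jnp.array([0, 1, 2, 3, 4, 5, 6])
print(jnp.percentile(x, 25))   # expected: 1.5
print(jnp.quantile(x, 0.25))   # expected: 1.5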
+ + Examples: + Computing the median and quartiles of a 1D array: + + >>> x = jnp.array([0, 1, 2, jnp.nan, 3, 4, 5, 6]) + >>> q = jnp.array([25, 50, 75]) + + Because of the NaN value, :func:`jax.numpy.percentile` returns all NaNs, + while :func:`~jax.numpy.nanpercentile` ignores them: + + >>> jnp.percentile(x, q) + Array([nan, nan, nan], dtype=float32) + >>> jnp.nanpercentile(x, q) + Array([1.5, 3. , 4.5], dtype=float32) + """ check_arraylike("nanpercentile", a, q) - q = ufuncs.true_divide(q, 100.0) + q, = promote_dtypes_inexact(q) + q = q / 100 if not isinstance(interpolation, DeprecatedArg): - warnings.warn("The interpolation= argument to 'nanpercentile' is deprecated. " - "Use 'method=' instead.", DeprecationWarning, stacklevel=2) + deprecations.warn( + "jax-numpy-quantile-interpolation", + ("The interpolation= argument to 'nanpercentile' is deprecated. " + "Use 'method=' instead."), stacklevel=2) method = interpolation return nanquantile(a, q, axis=axis, out=out, overwrite_input=overwrite_input, method=method, keepdims=keepdims) diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 2e953c67abd6..6491a7617d8d 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -21,6 +21,7 @@ import numpy as np +import jax from jax import jit from jax import lax @@ -33,7 +34,7 @@ sort, where, zeros) from jax._src.numpy.reductions import any, cumsum from jax._src.numpy.ufuncs import isnan -from jax._src.numpy.util import check_arraylike +from jax._src.numpy.util import check_arraylike, promote_dtypes from jax._src.util import canonicalize_axis from jax._src.typing import Array, ArrayLike @@ -41,24 +42,50 @@ _lax_const = lax_internal._const -@partial(jit, static_argnames=('invert',)) -def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool) -> Array: +@partial(jit, static_argnames=('assume_unique', 'invert', 'method')) +def _in1d(ar1: ArrayLike, ar2: ArrayLike, invert: bool, + method='auto', assume_unique=False) -> Array: check_arraylike("in1d", ar1, ar2) - ar1_flat = ravel(ar1) - ar2_flat = ravel(ar2) - # Note: an algorithm based on searchsorted has better scaling, but in practice - # is very slow on accelerators because it relies on lax control flow. 
If XLA - # ever supports binary search natively, we should switch to this: - # ar2_flat = jnp.sort(ar2_flat) - # ind = jnp.searchsorted(ar2_flat, ar1_flat) - # if invert: - # return ar1_flat != ar2_flat[ind] - # else: - # return ar1_flat == ar2_flat[ind] - if invert: - return (ar1_flat[:, None] != ar2_flat[None, :]).all(-1) + arr1, arr2 = promote_dtypes(ar1, ar2) + arr1, arr2 = arr1.ravel(), arr2.ravel() + if arr1.size == 0 or arr2.size == 0: + return (ones if invert else zeros)(arr1.shape, dtype=bool) + if method in ['auto', 'compare_all']: + if invert: + return (arr1[:, None] != arr2[None, :]).all(-1) + else: + return (arr1[:, None] == arr2[None, :]).any(-1) + elif method == 'binary_search': + arr2 = lax.sort(arr2) + ind = jax.numpy.searchsorted(arr2, arr1) + if invert: + return arr1 != arr2[ind] + else: + return arr1 == arr2[ind] + elif method == 'sort': + if assume_unique: + ind_out: slice | Array = slice(None) + else: + arr1, ind_out = unique(arr1, size=len(arr1), return_inverse=True, fill_value=arr2.max()) + aux, ind = lax.sort_key_val(concatenate([arr1, arr2]), arange(arr1.size + arr2.size)) + if invert: + return ones(arr1.shape, bool).at[ind[:-1]].set(aux[1:] != aux[:-1], mode='drop')[ind_out] + else: + return zeros(arr1.shape, bool).at[ind[:-1]].set(aux[1:] == aux[:-1], mode='drop')[ind_out] else: - return (ar1_flat[:, None] == ar2_flat[None, :]).any(-1) + raise ValueError(f"{method=} is not implemented; options are " + "'compare_all', 'binary_search', 'sort', and 'auto'") + + +def _concat_unique(arr1: Array, arr2: Array) -> tuple[Array, Array]: + """Utility to concatenate the unique values from two arrays.""" + arr1, arr2 = ravel(arr1), ravel(arr2) + arr1, num_unique1 = _unique(arr1, axis=0, size=arr1.size, return_true_size=True) + arr2, num_unique2 = _unique(arr2, axis=0, size=arr2.size, return_true_size=True) + arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) + arr = lax.dynamic_update_slice(arr, arr1, (0,)) + arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + return arr, num_unique1 + num_unique2 def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, @@ -68,10 +95,9 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, JAX implementation of :func:`numpy.setdiff1d`. Because the size of the output of ``setdiff1d`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.setdiff1d`` to be used in such contexts. - transformations. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified statically + for ``jnp.setdiff1d`` to be used in such contexts. Args: ar1: first array of elements to be differenced. @@ -108,7 +134,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. - The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:64 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + The error occurred while tracing the function setdiff1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:64 for jit. 
This concrete value was not available in Python because it depends on the value of the argument ar1. In order to ensure statically-known output shapes, you can pass a static ``size`` argument: @@ -138,7 +164,7 @@ def setdiff1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, return full_like(arr1, fill_value, shape=size or 0) if not assume_unique: arr1 = cast(Array, unique(arr1, size=size and arr1.size)) - mask = _in1d(arr1, ar2, invert=True) + mask = _in1d(arr1, ar2, invert=True, assume_unique=assume_unique) if size is None: return arr1[mask] else: @@ -156,10 +182,9 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, JAX implementation of :func:`numpy.union1d`. Because the size of the output of ``union1d`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.union1d`` to be used in such contexts. - transformations. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.union1d`` to be used in such contexts. Args: ar1: first array of elements to be unioned. @@ -192,7 +217,7 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, Traceback (most recent call last): ... ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[4]. - The error occurred while tracing the function union1d at /Users/vanderplas/github/google/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. + The error occurred while tracing the function union1d at /Users/vanderplas/github/jax-ml/jax/jax/_src/numpy/setops.py:101 for jit. This concrete value was not available in Python because it depends on the value of the argument ar1. 
In order to ensure statically-known output shapes, you can pass a static ``size`` argument: @@ -222,7 +247,39 @@ def union1d(ar1: ArrayLike, ar2: ArrayLike, return cast(Array, out) -def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Array: +@partial(jit, static_argnames=['assume_unique', 'size']) +def _setxor1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, *, + assume_unique: bool, size: int, ) -> Array: + # Ensured by caller + assert arr1.ndim == arr2.ndim == 1 + assert arr1.dtype == arr2.dtype + + if assume_unique: + arr = concatenate([arr1, arr2]) + aux = sort(concatenate([arr1, arr2])) + flag = concatenate((bool(aux.size), aux[1:] != aux[:-1], True), axis=None) + else: + arr, num_unique = _concat_unique(arr1, arr2) + mask = arange(arr.size + 1) < num_unique + 1 + _, aux = lax.sort([~mask[1:], arr], is_stable=True, num_keys=2) + flag = mask & concatenate((bool(aux.size), aux[1:] != aux[:-1], False), + axis=None).at[num_unique].set(True) + aux_mask = flag[1:] & flag[:-1] + num_results = aux_mask.sum() + if aux.size: + indices = nonzero(aux_mask, size=size, fill_value=len(aux))[0] + vals = aux.at[indices].get(mode='fill', fill_value=0) + else: + vals = zeros(size, aux.dtype) + if fill_value is None: + vals = where(arange(len(vals)) < num_results, vals, vals.max()) + return where(arange(len(vals)) < num_results, vals, vals.min()) + else: + return where(arange(len(vals)) < num_results, vals, fill_value) + + +def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, *, + size: int | None = None, fill_value: ArrayLike | None = None) -> Array: """Compute the set-wise xor of elements in two arrays. JAX implementation of :func:`numpy.setxor1d`. @@ -236,6 +293,12 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr assume_unique: if True, assume the input arrays contain unique values. This allows a more efficient implementation, but if ``assume_unique`` is True and the input arrays contain duplicates, the behavior is undefined. default: False. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``, + and returned indices will be padded with an out-of-bound index. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the smallest value + in the xor result. Returns: An array of values that are found in exactly one of the input arrays. 
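As a sketch of how the new ``size`` and ``fill_value`` arguments documented above are expected to behave (the outputs shown follow the padding semantics described in the docstring rather than an executed run):

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.setxor1d(ar1, ar2, size=6, fill_value=-1)
Array([ 1,  2,  5,  6, -1, -1], dtype=int32)
>>> jax.jit(lambda a, b: jnp.setxor1d(a, b, size=6, fill_value=-1))(ar1, ar2)
Array([ 1,  2,  5,  6, -1, -1], dtype=int32)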
@@ -252,50 +315,119 @@ def setxor1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False) -> Arr Array([1, 2, 5, 6], dtype=int32) """ check_arraylike("setxor1d", ar1, ar2) - ar1 = core.concrete_or_error(None, ar1, "The error arose in setxor1d()") - ar2 = core.concrete_or_error(None, ar2, "The error arose in setxor1d()") + arr1, arr2 = promote_dtypes(ravel(ar1), ravel(ar2)) + del ar1, ar2 - ar1 = ravel(ar1) - ar2 = ravel(ar2) + if size is not None: + return _setxor1d_size(arr1, arr2, fill_value=fill_value, + assume_unique=assume_unique, size=size) if not assume_unique: - ar1 = unique(ar1) - ar2 = unique(ar2) - - aux = concatenate((ar1, ar2)) + arr1 = unique(arr1) + arr2 = unique(arr2) + aux = concatenate((arr1, arr2)) if aux.size == 0: return aux - aux = sort(aux) - flag = concatenate((array([True]), aux[1:] != aux[:-1], array([True]))) + flag = concatenate((True, aux[1:] != aux[:-1], True), axis=None) return aux[flag[1:] & flag[:-1]] @partial(jit, static_argnames=['return_indices']) -def _intersect1d_sorted_mask(ar1: ArrayLike, ar2: ArrayLike, return_indices: bool = False) -> tuple[Array, ...]: - # JIT-compatible helper function for intersect1d - ar = concatenate((ar1, ar2)) +def _intersect1d_sorted_mask(arr1: Array, arr2: Array, + return_indices: bool) -> tuple[Array, Array, Array | None]: + """JIT-compatible helper function for intersect1d""" + assert arr1.ndim == arr2.ndim == 1 + arr = concatenate((arr1, arr2)) if return_indices: - iota = lax.broadcasted_iota(np.int64, np.shape(ar), dimension=0) - aux, indices = lax.sort_key_val(ar, iota) + iota = lax.broadcasted_iota(np.int64, np.shape(arr), dimension=0) + aux, indices = lax.sort_key_val(arr, iota) else: - aux = sort(ar) - + aux = sort(arr) + indices = None mask = aux[1:] == aux[:-1] + return aux, mask, indices + + +@partial(jit, static_argnames=['fill_value', 'assume_unique', 'size', 'return_indices']) +def _intersect1d_size(arr1: Array, arr2: Array, fill_value: ArrayLike | None, assume_unique: bool, + size: int, return_indices: bool) -> Array | tuple[Array, Array, Array]: + """Jit-compatible helper function for intersect1d with size specified.""" + # Ensured by caller + assert arr1.ndim == arr2.ndim == 1 + assert arr1.dtype == arr2.dtype + + # First step: we concatenate the unique values of arr1 and arr2. + # The resulting values are: + # num_unique1/num_unique2: number of unique values in arr1/arr2 + # aux[:num_unique1 + num_unique2] contains the sorted concatenated + # unique values drawn from arr1 and arr2. 
+ # aux_sorted_indices: indices mapping aux to concatenation of arr1 and arr2 + # ind1[:num_unique1], ind2[:num_unique2]: indices of sorted unique + # values in arr1/arr2 + # mask: boolean mask of relevant values in aux & aux_sorted_indices + if assume_unique: + ind1, num_unique1 = arange(arr1.size), asarray(arr1.size) + ind2, num_unique2 = arange(arr2.size), asarray(arr2.size) + arr = concatenate([arr1, arr2]) + aux, aux_sort_indices = lax.sort([arr, arange(arr.size)], is_stable=True, num_keys=1) + mask = ones(arr.size, dtype=bool) + else: + arr1, ind1, num_unique1 = _unique(arr1, 0, size=arr1.size, return_index=True, return_true_size=True, fill_value=0) + arr2, ind2, num_unique2 = _unique(arr2, 0, size=arr2.size, return_index=True, return_true_size=True, fill_value=0) + arr = zeros(arr1.size + arr2.size, dtype=dtypes.result_type(arr1, arr2)) + arr = lax.dynamic_update_slice(arr, arr1, (0,)) + arr = lax.dynamic_update_slice(arr, arr2, (num_unique1,)) + mask = arange(arr.size) < num_unique1 + num_unique2 + _, aux, aux_sort_indices = lax.sort([~mask, arr, arange(arr.size)], is_stable=True, num_keys=2) + + # Second step: extract the intersection values from aux + # Since we've sorted the unique entries in arr1 and arr2, any place where + # adjacent entries are equal is a value of the intersection. + # relevant results here: + # num_results: number of values in the intersection of arr1 and arr2 + # vals: array where vals[:num_results] contains the intersection of arr1 and arr2, + # and vals[num_results:] contains the appropriate fill_value. + aux_mask = (aux[1:] == aux[:-1]) & mask[1:] + num_results = aux_mask.sum() + if aux.size: + val_indices = nonzero(aux_mask, size=size, fill_value=aux.size)[0] + vals = aux.at[val_indices].get(mode='fill', fill_value=0) + else: + vals = zeros(size, aux.dtype) + if fill_value is None: + vals = where(arange(len(vals)) < num_results, vals, vals.max()) + vals = where(arange(len(vals)) < num_results, vals, vals.min()) + else: + vals = where(arange(len(vals)) < num_results, vals, fill_value) + + # Third step: extract the indices of the intersection values. + # This requires essentially unwinding aux_sort_indices and ind1/ind2 to find + # the appropriate list of indices from the original arrays. if return_indices: - return aux, mask, indices + arr1_indices = aux_sort_indices.at[val_indices].get(mode='fill', fill_value=arr1.size) + arr1_indices = where(arange(len(arr1_indices)) < num_results, arr1_indices, arr1.size) + arr2_indices = aux_sort_indices.at[val_indices + 1].get(mode='fill', fill_value=arr2.size) - num_unique1 + arr2_indices = where(arange(len(arr2_indices)) < num_results, arr2_indices, arr2.size) + if not assume_unique: + arr1_indices = ind1.at[arr1_indices].get(mode='fill', fill_value=ind1.size) + arr2_indices = ind2.at[arr2_indices].get(mode='fill', fill_value=ind2.size) + return vals, arr1_indices, arr2_indices else: - return aux, mask + return vals def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, - return_indices: bool = False) -> Array | tuple[Array, Array, Array]: + return_indices: bool = False, *, size: int | None = None, + fill_value: ArrayLike | None = None) -> Array | tuple[Array, Array, Array]: """Compute the set intersection of two 1D arrays. JAX implementation of :func:`numpy.intersect1d`. - Because the size of the output of ``intersect1d`` is data-dependent, the function is not - compatible with JIT or other JAX transformations. 
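A sketch of how the static-size path above surfaces in the public ``intersect1d`` API (expected outputs, assuming the ``fill_value`` padding described in the argument documentation below; not an executed run):

>>> ar1 = jnp.array([1, 2, 3, 4])
>>> ar2 = jnp.array([3, 4, 5, 6])
>>> jnp.intersect1d(ar1, ar2, size=3, fill_value=0)
Array([3, 4, 0], dtype=int32)
>>> jax.jit(lambda a, b: jnp.intersect1d(a, b, size=3, fill_value=0))(ar1, ar2)
Array([3, 4, 0], dtype=int32)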
+ Because the size of the output of ``intersect1d`` is data-dependent, the function + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.intersect1d`` to be used in such contexts. Args: ar1: first array of values to intersect. @@ -305,6 +437,12 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, arrays contain duplicates, the behavior is undefined. default: False. return_indices: If True, return arrays of indices specifying where the intersected values first appear in the input arrays. + size: if specified, return only the first ``size`` sorted elements. If there are fewer + elements than ``size`` indicates, the return value will be padded with ``fill_value``, + and returned indices will be padded with an out-of-bound index. + fill_value: when ``size`` is specified and there are fewer than the indicated number of + elements, fill the remaining entries ``fill_value``. Defaults to the smallest value + in the intersection. Returns: An array ``intersection``, or if ``return_indices=True``, a tuple of arrays @@ -353,41 +491,42 @@ def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: bool = False, Array(True, dtype=bool) """ check_arraylike("intersect1d", ar1, ar2) - ar1 = core.concrete_or_error(None, ar1, "The error arose in intersect1d()") - ar2 = core.concrete_or_error(None, ar2, "The error arose in intersect1d()") + arr1, arr2 = promote_dtypes(ar1, ar2) + del ar1, ar2 + arr1 = ravel(arr1) + arr2 = ravel(arr2) + + if size is not None: + return _intersect1d_size(arr1, arr2, return_indices=return_indices, + size=size, fill_value=fill_value, assume_unique=assume_unique) if not assume_unique: if return_indices: - ar1, ind1 = unique(ar1, return_index=True) - ar2, ind2 = unique(ar2, return_index=True) + arr1, ind1 = unique(arr1, return_index=True) + arr2, ind2 = unique(arr2, return_index=True) else: - ar1 = unique(ar1) - ar2 = unique(ar2) - else: - ar1 = ravel(ar1) - ar2 = ravel(ar2) + arr1 = unique(arr1) + arr2 = unique(arr2) - if return_indices: - aux, mask, aux_sort_indices = _intersect1d_sorted_mask(ar1, ar2, return_indices) - else: - aux, mask = _intersect1d_sorted_mask(ar1, ar2, return_indices) + aux, mask, aux_sort_indices = _intersect1d_sorted_mask(arr1, arr2, return_indices) int1d = aux[:-1][mask] if return_indices: - ar1_indices = aux_sort_indices[:-1][mask] - ar2_indices = aux_sort_indices[1:][mask] - np.size(ar1) + assert aux_sort_indices is not None + arr1_indices = aux_sort_indices[:-1][mask] + arr2_indices = aux_sort_indices[1:][mask] - np.size(arr1) if not assume_unique: - ar1_indices = ind1[ar1_indices] - ar2_indices = ind2[ar2_indices] - - return int1d, ar1_indices, ar2_indices + arr1_indices = ind1[arr1_indices] + arr2_indices = ind2[arr2_indices] + return int1d, arr1_indices, arr2_indices else: return int1d def isin(element: ArrayLike, test_elements: ArrayLike, - assume_unique: bool = False, invert: bool = False) -> Array: + assume_unique: bool = False, invert: bool = False, *, + method='auto') -> Array: """Determine whether elements in ``element`` appear in ``test_elements``. JAX implementation of :func:`numpy.isin`. @@ -397,7 +536,11 @@ def isin(element: ArrayLike, test_elements: ArrayLike, test_elements: N-dimensional array of test values to check for the presence of each element. invert: If True, return ``~isin(element, test_elements)``. Default is False. 
- assume_unique: unused by JAX + assume_unique: if true, input arrays are assumed to be unique, which can + lead to more efficient computation. If the input arrays are not unique + and assume_unique is set to True, the results are undefined. + method: string specifying the method used to compute the result. Supported + options are 'compare_all', 'binary_search', 'sort', and 'auto' (default). Returns: A boolean array of shape ``element.shape`` that specifies whether each element @@ -409,9 +552,9 @@ def isin(element: ArrayLike, test_elements: ArrayLike, >>> jnp.isin(elements, test_elements) Array([ True, False, True, False], dtype=bool) """ - del assume_unique # unused check_arraylike("isin", element, test_elements) - result = _in1d(element, test_elements, invert=invert) + result = _in1d(element, test_elements, invert=invert, + method=method, assume_unique=assume_unique) return result.reshape(np.shape(element)) @@ -517,9 +660,9 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal JAX implementation of :func:`numpy.unique`. Because the size of the output of ``unique`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: ar: N-dimensional array from which unique values will be extracted. @@ -729,9 +872,9 @@ def unique_all(x: ArrayLike, /, *, size: int | None = None, and `equal_nan` set to True. Because the size of the output of ``unique_all`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. @@ -810,9 +953,9 @@ def unique_counts(x: ArrayLike, /, *, size: int | None = None, :func:`jax.numpy.unique` with `return_counts` and `equal_nan` set to True. Because the size of the output of ``unique_counts`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. @@ -870,9 +1013,9 @@ def unique_inverse(x: ArrayLike, /, *, size: int | None = None, :func:`jax.numpy.unique` with `return_inverse` and `equal_nan` set to True. Because the size of the output of ``unique_inverse`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. 
The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified + statically for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. @@ -935,9 +1078,9 @@ def unique_values(x: ArrayLike, /, *, size: int | None = None, :func:`jax.numpy.unique` with `equal_nan` set to True. Because the size of the output of ``unique_values`` is data-dependent, the function - semantics are not typically compatible with :func:`~jax.jit` and other JAX - transformations. The JAX version adds the optional ``size`` argument which - must be specified statically for ``jnp.unique`` to be used in such contexts. + is not typically compatible with :func:`~jax.jit` and other JAX transformations. + The JAX version adds the optional ``size`` argument which must be specified statically + for ``jnp.unique`` to be used in such contexts. Args: x: N-dimensional array from which unique values will be extracted. diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py index 2e114193af13..3473e8a7468a 100644 --- a/jax/_src/numpy/ufunc_api.py +++ b/jax/_src/numpy/ufunc_api.py @@ -25,13 +25,11 @@ import jax from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.lax import lax as lax_internal -from jax._src.numpy import reductions -from jax._src.numpy.lax_numpy import _eliminate_deprecated_list_indexing, append, take +import jax._src.numpy.lax_numpy as jnp from jax._src.numpy.reductions import _moveaxis -from jax._src.numpy.util import implements, check_arraylike, _broadcast_to, _where +from jax._src.numpy.util import check_arraylike, _broadcast_to, _where from jax._src.numpy.vectorize import vectorize from jax._src.util import canonicalize_axis, set_module -from jax._src import pjit import numpy as np @@ -42,81 +40,126 @@ """ -def get_if_single_primitive(fun: Callable[..., Any], *args: Any) -> jax.core.Primitive | None: - """ - If fun(*args) lowers to a single primitive with inputs and outputs matching - function inputs and outputs, return that primitive. Otherwise return None. - """ - try: - jaxpr = jax.make_jaxpr(fun)(*args) - except: - return None - while len(jaxpr.eqns) == 1: - eqn = jaxpr.eqns[0] - if (eqn.invars, eqn.outvars) != (jaxpr.jaxpr.invars, jaxpr.jaxpr.outvars): - return None - elif (eqn.primitive == pjit.pjit_p and - all(pjit.is_unspecified(sharding) for sharding in - (*eqn.params['in_shardings'], *eqn.params['out_shardings']))): - jaxpr = jaxpr.eqns[0].params['jaxpr'] - else: - return jaxpr.eqns[0].primitive - return None +@set_module('jax.numpy') +class ufunc: + """Universal functions which operation element-by-element on arrays. + JAX implementation of :class:`numpy.ufunc`. -_primitive_reducers: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.sum, - lax_internal.mul_p: reductions.prod, -} + This is a class for JAX-backed implementations of NumPy's ufunc APIs. + Most users will never need to instantiate :class:`ufunc`, but rather + will use the pre-defined ufuncs in :mod:`jax.numpy`. + For constructing your own ufuncs, see :func:`jax.numpy.frompyfunc`. 
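For instance, a custom ufunc might be constructed and used as follows (an illustrative sketch; the lambda-based ufunc and the printed outputs are assumptions, not taken from this change):

>>> mul = jnp.frompyfunc(lambda a, b: a * b, nin=2, nout=1, identity=1)
>>> x = jnp.array([1, 2, 3, 4])
>>> mul(x, x)
Array([ 1,  4,  9, 16], dtype=int32)
>>> mul.reduce(x)
Array(24, dtype=int32)
>>> mul.outer(jnp.array([1, 2]), jnp.array([10, 20]))
Array([[10, 20],
       [20, 40]], dtype=int32)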
-_primitive_accumulators: dict[jax.core.Primitive, Callable[..., Any]] = { - lax_internal.add_p: reductions.cumsum, - lax_internal.mul_p: reductions.cumprod, -} + Examples: + Universal functions are functions that apply element-wise to broadcasted + arrays, but they also come with a number of extra attributes and methods. + As an example, consider the function :obj:`jax.numpy.add`. The object + acts as a function that applies addition to broadcasted arrays in an + element-wise manner: -@set_module('jax.numpy') -class ufunc: - """Functions that operate element-by-element on whole arrays. + >>> x = jnp.array([1, 2, 3, 4, 5]) + >>> jnp.add(x, 1) + Array([2, 3, 4, 5, 6], dtype=int32) + + Each :class:`ufunc` object includes a number of attributes that describe + its behavior: + + >>> jnp.add.nin # number of inputs + 2 + >>> jnp.add.nout # number of outputs + 1 + >>> jnp.add.identity # identity value, or None if no identity exists + 0 + + Binary ufuncs like :obj:`jax.numpy.add` include number of methods to + apply the function to arrays in different manners. + + The :meth:`~ufunc.outer` method applies the function to the + pair-wise outer-product of the input array values: + + >>> jnp.add.outer(x, x) + Array([[ 2, 3, 4, 5, 6], + [ 3, 4, 5, 6, 7], + [ 4, 5, 6, 7, 8], + [ 5, 6, 7, 8, 9], + [ 6, 7, 8, 9, 10]], dtype=int32) - This is a class for LAX-backed implementations of numpy ufuncs. + The :meth:`ufunc.reduce` method perfoms a reduction over the array. + For example, :meth:`jnp.add.reduce` is equivalent to ``jnp.sum``: + + >>> jnp.add.reduce(x) + Array(15, dtype=int32) + + The :meth:`ufunc.accumulate` method performs a cumulative reduction + over the array. For example, :meth:`jnp.add.accumulate` is equivalent + to :func:`jax.numpy.cumulative_sum`: + + >>> jnp.add.accumulate(x) + Array([ 1, 3, 6, 10, 15], dtype=int32) + + The :meth:`ufunc.at` method applies the function at particular indices in the + array; for ``jnp.add`` the computation is similar to :func:`jax.lax.scatter_add`: + + >>> jnp.add.at(x, 0, 100, inplace=False) + Array([101, 2, 3, 4, 5], dtype=int32) + + And the :meth:`ufunc.reduceat` method performs a number of ``reduce`` + operations bewteen specified indices of an array; for ``jnp.add`` the + operation is similar to :func:`jax.ops.segment_sum`: + + >>> jnp.add.reduceat(x, jnp.array([0, 2])) + Array([ 3, 12], dtype=int32) + + In this case, the first element is ``x[0:2].sum()``, and the second element + is ``x[2:].sum()``. """ def __init__(self, func: Callable[..., Any], /, nin: int, nout: int, *, name: str | None = None, nargs: int | None = None, - identity: Any = None, update_doc=False): + identity: Any = None, + call: Callable[..., Any] | None = None, + reduce: Callable[..., Any] | None = None, + accumulate: Callable[..., Any] | None = None, + at: Callable[..., Any] | None = None, + reduceat: Callable[..., Any] | None = None, + ): + self.__doc__ = func.__doc__ + self.__name__ = name or func.__name__ # We want ufunc instances to work properly when marked as static, # and for this reason it's important that their properties not be # mutated. We prevent this by storing them in a dunder attribute, # and accessing them via read-only properties. 
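# A sketch of how the new per-method hooks are intended to be used: a
# pre-defined ufunc can supply specialized implementations, and the generic
# scan-based fallbacks are used only when a hook is None. The construction
# below is illustrative (the bound callables are assumptions, not taken from
# this change):
#
#   add = ufunc(operator.add, nin=2, nout=1, identity=0,
#               call=jax.numpy.add, reduce=reductions.sum,
#               accumulate=reductions.cumsum)
#
# With hooks like these in place, add.reduce(x) dispatches directly to the
# specialized reduction instead of the scan-based fallback defined later in
# this class.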
- if update_doc:
- self.__doc__ = func.__doc__
- self.__name__ = name or func.__name__
self.__static_props = {
'func': func,
- 'call': vectorize(func),
'nin': operator.index(nin),
'nout': operator.index(nout),
'nargs': operator.index(nargs or nin),
- 'identity': identity
+ 'identity': identity,
+ 'call': call,
+ 'reduce': reduce,
+ 'accumulate': accumulate,
+ 'at': at,
+ 'reduceat': reduceat,
}
_func = property(lambda self: self.__static_props['func'])
- _call = property(lambda self: self.__static_props['call'])
nin = property(lambda self: self.__static_props['nin'])
nout = property(lambda self: self.__static_props['nout'])
nargs = property(lambda self: self.__static_props['nargs'])
identity = property(lambda self: self.__static_props['identity'])
def __hash__(self) -> int:
- # Do not include _call, because it is computed from _func.
+ # In both __hash__ and __eq__, we do not consider call, reduce, etc.
+ # because they are considered implementation details rather than
+ # necessary parts of object identity.
return hash((self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs))
def __eq__(self, other: Any) -> bool:
- # Do not include _call, because it is computed from _func.
return isinstance(other, ufunc) and (
(self._func, self.__name__, self.identity, self.nin, self.nout, self.nargs) ==
(other._func, other.__name__, other.identity, other.nin, other.nout, other.nargs))
@@ -124,20 +167,71 @@ def __eq__(self, other: Any) -> bool:
def __repr__(self) -> str:
return f"<jnp.ufunc '{self.__name__}'>"
- def __call__(self, *args: ArrayLike,
- out: None = None, where: None = None,
- **kwargs: Any) -> Any:
+ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> Any:
+ check_arraylike(self.__name__, *args)
if out is not None:
raise NotImplementedError(f"out argument of {self}")
if where is not None:
raise NotImplementedError(f"where argument of {self}")
- return self._call(*args, **kwargs)
+ call = self.__static_props['call'] or self._call_vectorized
+ return call(*args)
+
+ @partial(jax.jit, static_argnames=['self'])
+ def _call_vectorized(self, *args):
+ return vectorize(self._func)(*args)
- @implements(np.ufunc.reduce, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype', 'out', 'keepdims'])
- def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None,
+ def reduce(self, a: ArrayLike, axis: int = 0,
+ dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array:
+ """Reduction operation derived from a binary function.
+
+ JAX implementation of :meth:`numpy.ufunc.reduce`.
+
+ Args:
+ a: Input array.
+ axis: integer specifying the axis over which to reduce. default=0
+ dtype: optionally specify the type of the output array.
+ out: Unused by JAX
+ keepdims: If True, reduced axes are left in the result with size 1.
+ If False (default) then reduced axes are squeezed out.
+ initial: int or array, Default=None. Initial value for the reduction.
+ where: boolean mask, default=None. The elements to be included in the reduction.
+ Must be broadcast-compatible with the input.
+
+ Returns:
+ An array containing the result of the reduction operation.
+
+ Examples:
+ Consider the following array:
+
+ >>> x = jnp.array([[1, 2, 3],
+ ...
[4, 5, 6]]) + + :meth:`jax.numpy.add.reduce` is equivalent to :func:`jax.numpy.sum` + along ``axis=0``: + + >>> jnp.add.reduce(x) + Array([5, 7, 9], dtype=int32) + >>> x.sum(0) + Array([5, 7, 9], dtype=int32) + + Similarly, :meth:`jax.numpy.logical_and.reduce` is equivalent to + :func:`jax.numpy.all`: + + >>> jnp.logical_and.reduce(x > 2) + Array([False, False, True], dtype=bool) + >>> jnp.all(x > 2, axis=0) + Array([False, False, True], dtype=bool) + + Some reductions do not correspond to any built-in aggregation function; + for example here is the reduction of :func:`jax.numpy.bitwise_or` along + the first axis of ``x``: + + >>> jnp.bitwise_or.reduce(x, axis=1) + Array([3, 7], dtype=int32) + """ check_arraylike(f"{self.__name__}.reduce", a) if self.nin != 2: raise ValueError("reduce only supported for binary ufuncs") @@ -154,14 +248,10 @@ def reduce(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, "so to use a where mask one has to specify 'initial'.") if lax_internal._dtype(where) != bool: raise ValueError(f"where argument must have dtype=bool; got dtype={lax_internal._dtype(where)}") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - reducer = self._reduce_via_scan - else: - reducer = _primitive_reducers.get(primitive, self._reduce_via_scan) - return reducer(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + reduce = self.__static_props['reduce'] or self._reduce_via_scan + return reduce(a, axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + def _reduce_via_scan(self, arr: ArrayLike, axis: int | None = 0, dtype: DTypeLike | None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array: assert self.nin == 2 and self.nout == 1 @@ -202,9 +292,9 @@ def _reduce_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | Non def body_fun(i, val): if where is None: - return self._call(val, arr[i].astype(dtype)) + return self(val, arr[i].astype(dtype)) else: - return _where(where[i], self._call(val, arr[i].astype(dtype)), val) + return _where(where[i], self(val, arr[i].astype(dtype)), val) start_value: ArrayLike if initial is None: @@ -221,22 +311,63 @@ def body_fun(i, val): result = result.reshape(final_shape) return result - @implements(np.ufunc.accumulate, module="numpy.ufunc") @partial(jax.jit, static_argnames=['self', 'axis', 'dtype']) def accumulate(self, a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, out: None = None) -> Array: + """Accumulate operation derived from binary ufunc. + + JAX implementation of :func:`numpy.ufunc.accumulate`. + + Args: + a: N-dimensional array over which to accumulate. + axis: integer axis over which accumulation will be performed (default = 0) + dtype: optionally specify the type of the output array. + out: Unused by JAX + + Returns: + An array containing the accumulated result. + + Examples: + Consider the following array: + + >>> x = jnp.array([[1, 2, 3], + ... 
[4, 5, 6]]) + + :meth:`jax.numpy.add.accumulate` is equivalent to + :func:`jax.numpy.cumsum` along the specified axis: + >>> jnp.add.accumulate(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + >>> jnp.cumsum(x, axis=1) + Array([[ 1, 3, 6], + [ 4, 9, 15]], dtype=int32) + + Similarly, :meth:`jax.numpy.multiply.accumulate` is equivalent to + :func:`jax.numpy.cumprod` along the specified axis: + + >>> jnp.multiply.accumulate(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + >>> jnp.cumprod(x, axis=1) + Array([[ 1, 2, 6], + [ 4, 20, 120]], dtype=int32) + + For other binary ufuncs, the accumulation is an operation not available + via standard APIs. For example, :meth:`jax.numpy.bitwise_or.accumulate` + is essentially a bitwise cumulative ``any``: + + >>> jnp.bitwise_or.accumulate(x, axis=1) + Array([[1, 3, 3], + [4, 5, 7]], dtype=int32) + """ if self.nin != 2: raise ValueError("accumulate only supported for binary ufuncs") if self.nout != 1: raise ValueError("accumulate only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.accumulate()") - primitive = get_if_single_primitive(self._call, *(self.nin * [lax_internal._one(a)])) - if primitive is None: - accumulator = self._accumulate_via_scan - else: - accumulator = _primitive_accumulators.get(primitive, self._accumulate_via_scan) - return accumulator(a, axis=axis, dtype=dtype) + accumulate = self.__static_props['accumulate'] or self._accumulate_via_scan + return accumulate(a, axis=axis, dtype=dtype) def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None) -> Array: @@ -254,21 +385,54 @@ def _accumulate_via_scan(self, arr: ArrayLike, axis: int = 0, arr = _moveaxis(arr, axis, 0) def scan_fun(carry, _): i, x = carry - y = _where(i == 0, arr[0].astype(dtype), self._call(x.astype(dtype), arr[i].astype(dtype))) + y = _where(i == 0, arr[0].astype(dtype), self(x.astype(dtype), arr[i].astype(dtype))) return (i + 1, y), y _, result = jax.lax.scan(scan_fun, (0, arr[0].astype(dtype)), None, length=arr.shape[0]) return _moveaxis(result, 0, axis) - @implements(np.ufunc.at, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0], static_argnames=['inplace']) def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, inplace: bool = True) -> Array: + """Update elements of an array via the specified unary or binary ufunc. + + JAX implementation of :func:`numpy.ufunc.at`. + + Note: + :meth:`numpy.ufunc.at` mutates arrays in-place. JAX arrays are immutable, + so :meth:`jax.numpy.ufunc.at` cannot replicate these semantics. Instead, JAX + will return the updated value, but requires explicitly passing ``inplace=False`` + as a reminder of this difference. + + Args: + a: N-dimensional array to update + indices: index, slice, or tuple of indices and slices. + b: array of values for binary ufunc updates. + inplace: must be set to False to indicate that an updated copy will be returned. + + Returns: + an updated copy of the input array. 
+
+ Examples:
+
+ Add numbers to specified indices:
+
+ >>> x = jnp.ones(10, dtype=int)
+ >>> indices = jnp.array([2, 5, 7])
+ >>> values = jnp.array([10, 20, 30])
+ >>> jnp.add.at(x, indices, values, inplace=False)
+ Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
+
+ This is roughly equivalent to JAX's :meth:`jax.numpy.ndarray.at` method
+ called this way:
+
+ >>> x.at[indices].add(values)
+ Array([ 1, 1, 11, 1, 1, 21, 1, 31, 1, 1], dtype=int32)
+ """
if inplace:
raise NotImplementedError(_AT_INPLACE_WARNING)
- if b is None:
- return self._at_via_scan(a, indices)
- else:
- return self._at_via_scan(a, indices, b)
+
+ at = self.__static_props['at'] or self._at_via_scan
+ return at(a, indices) if b is None else at(a, indices, b)
def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
assert len(args) in {0, 1}
@@ -276,14 +440,14 @@ def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
dtype = jax.eval_shape(self._func, lax_internal._one(a), *(lax_internal._one(arg) for arg in args)).dtype
a = lax_internal.asarray(a).astype(dtype)
args = tuple(lax_internal.asarray(arg).astype(dtype) for arg in args)
- indices = _eliminate_deprecated_list_indexing(indices)
+ indices = jnp._eliminate_deprecated_list_indexing(indices)
if not indices:
return a
shapes = [np.shape(i) for i in indices if not isinstance(i, slice)]
shape = shapes and jax.lax.broadcast_shapes(*shapes)
if not shape:
- return a.at[indices].set(self._call(a.at[indices].get(), *args))
+ return a.at[indices].set(self(a.at[indices].get(), *args))
if args:
arg = _broadcast_to(args[0], (*shape, *args[0].shape[len(shape):]))
@@ -293,28 +457,65 @@ def _at_via_scan(self, a: ArrayLike, indices: Any, *args: Any) -> Array:
def scan_fun(carry, x):
i, a = carry
idx = tuple(ind if isinstance(ind, slice) else ind[i] for ind in indices)
- a = a.at[idx].set(self._call(a.at[idx].get(), *(arg[i] for arg in args)))
+ a = a.at[idx].set(self(a.at[idx].get(), *(arg[i] for arg in args)))
return (i + 1, a), x
carry, _ = jax.lax.scan(scan_fun, (0, a), None, len(indices[0]))
return carry[1]
- @implements(np.ufunc.reduceat, module="numpy.ufunc")
@partial(jax.jit, static_argnames=['self', 'axis', 'dtype'])
def reduceat(self, a: ArrayLike, indices: Any, axis: int = 0,
dtype: DTypeLike | None = None, out: None = None) -> Array:
+ """Reduce an array between specified indices via a binary ufunc.
+
+ JAX implementation of :meth:`numpy.ufunc.reduceat`.
+
+ Args:
+ a: N-dimensional array to reduce
+ indices: a 1-dimensional array of increasing integer values which encodes
+ segments of the array to be reduced.
+ axis: integer specifying the axis along which to reduce: default=0.
+ dtype: optionally specify the dtype of the output array.
+ out: unused by JAX
+
+ Returns:
+ An array containing the reduced values.
+
+ Examples:
+ The ``reduceat`` method lets you efficiently compute reduction operations
+ over array segments. For example:
+
+ >>> x = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])
+ >>> indices = jnp.array([0, 2, 5])
+ >>> jnp.add.reduceat(x, indices)
+ Array([ 3, 12, 21], dtype=int32)
+
+ This is more-or-less equivalent to the following:
+
+ >>> jnp.array([x[0:2].sum(), x[2:5].sum(), x[5:].sum()])
+ Array([ 3, 12, 21], dtype=int32)
+
+ For some binary ufuncs, JAX provides similar APIs within :mod:`jax.ops`.
+ For example, :meth:`jax.add.reduceat` is similar to :func:`jax.ops.segment_sum`, + although in this case the segments are defined via an array of segment ids: + + >>> segments = jnp.array([0, 0, 1, 1, 1, 2, 2, 2]) + >>> jax.ops.segment_sum(x, segments) + Array([ 3, 12, 21], dtype=int32) + """ if self.nin != 2: raise ValueError("reduceat only supported for binary ufuncs") if self.nout != 1: raise ValueError("reduceat only supported for functions returning a single value") if out is not None: raise NotImplementedError(f"out argument of {self.__name__}.reduceat()") - return self._reduceat_via_scan(a, indices, axis=axis, dtype=dtype) + + reduceat = self.__static_props['reduceat'] or self._reduceat_via_scan + return reduceat(a, indices, axis=axis, dtype=dtype) def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, dtype: DTypeLike | None = None) -> Array: check_arraylike(f"{self.__name__}.reduceat", a, indices) a = lax_internal.asarray(a) - idx_tuple = _eliminate_deprecated_list_indexing(indices) + idx_tuple = jnp._eliminate_deprecated_list_indexing(indices) assert len(idx_tuple) == 1 indices = idx_tuple[0] if a.ndim == 0: @@ -326,27 +527,62 @@ def _reduceat_via_scan(self, a: ArrayLike, indices: Any, axis: int = 0, if axis is None or isinstance(axis, (tuple, list)): raise ValueError("reduceat requires a single integer axis.") axis = canonicalize_axis(axis, a.ndim) - out = take(a, indices, axis=axis) - ind = jax.lax.expand_dims(append(indices, a.shape[axis]), + out = jnp.take(a, indices, axis=axis) + ind = jax.lax.expand_dims(jnp.append(indices, a.shape[axis]), list(np.delete(np.arange(out.ndim), axis))) ind_start = jax.lax.slice_in_dim(ind, 0, ind.shape[axis] - 1, axis=axis) ind_end = jax.lax.slice_in_dim(ind, 1, ind.shape[axis], axis=axis) def loop_body(i, out): return _where((i > ind_start) & (i < ind_end), - self._call(out, take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), + self(out, jnp.take(a, jax.lax.expand_dims(i, (0,)), axis=axis)), out) return jax.lax.fori_loop(0, a.shape[axis], loop_body, out) - @implements(np.ufunc.outer, module="numpy.ufunc") @partial(jax.jit, static_argnums=[0]) - def outer(self, A: ArrayLike, B: ArrayLike, /, **kwargs) -> Array: + def outer(self, A: ArrayLike, B: ArrayLike, /) -> Array: + """Apply the function to all pairs of values in ``A`` and ``B``. + + JAX implementation of :meth:`numpy.ufunc.outer`. 
+ + Args: + A: N-dimensional array + B: N-dimensional array + + Returns: + An array of shape `tuple(*A.shape, *B.shape)` + + Examples: + A times-table for integers 1...10 created via + :meth:`jax.numpy.multiply.outer`: + + >>> x = jnp.arange(1, 11) + >>> print(jnp.multiply.outer(x, x)) + [[ 1 2 3 4 5 6 7 8 9 10] + [ 2 4 6 8 10 12 14 16 18 20] + [ 3 6 9 12 15 18 21 24 27 30] + [ 4 8 12 16 20 24 28 32 36 40] + [ 5 10 15 20 25 30 35 40 45 50] + [ 6 12 18 24 30 36 42 48 54 60] + [ 7 14 21 28 35 42 49 56 63 70] + [ 8 16 24 32 40 48 56 64 72 80] + [ 9 18 27 36 45 54 63 72 81 90] + [ 10 20 30 40 50 60 70 80 90 100]] + + For input arrays with ``N`` and ``M`` dimensions respectively, the output + will have dimesion ``N + M``: + + >>> x = jnp.ones((1, 3, 5)) + >>> y = jnp.ones((2, 4)) + >>> jnp.add.outer(x, y).shape + (1, 3, 5, 2, 4) + """ if self.nin != 2: raise ValueError("outer only supported for binary ufuncs") if self.nout != 1: raise ValueError("outer only supported for functions returning a single value") check_arraylike(f"{self.__name__}.outer", A, B) _ravel = lambda A: jax.lax.reshape(A, (np.size(A),)) - result = jax.vmap(jax.vmap(partial(self._call, **kwargs), (None, 0)), (0, None))(_ravel(A), _ravel(B)) + result = jax.vmap(jax.vmap(self, (None, 0)), (0, None))(_ravel(A), _ravel(B)) return result.reshape(*np.shape(A), *np.shape(B)) @@ -363,4 +599,4 @@ def frompyfunc(func: Callable[..., Any], /, nin: int, nout: int, Returns: wrapped : jax.numpy.ufunc wrapper of func. """ - return ufunc(func, nin, nout, identity=identity, update_doc=True) + return ufunc(func, nin, nout, identity=identity) diff --git a/jax/_src/numpy/ufuncs.py b/jax/_src/numpy/ufuncs.py index 531d0bec813f..dc265b8e87e1 100644 --- a/jax/_src/numpy/ufuncs.py +++ b/jax/_src/numpy/ufuncs.py @@ -21,6 +21,7 @@ from collections.abc import Callable from functools import partial import operator +from typing import Any import numpy as np @@ -29,11 +30,14 @@ from jax._src.api import jit from jax._src.custom_derivatives import custom_jvp from jax._src.lax import lax -from jax._src.typing import Array, ArrayLike +from jax._src.lax import other as lax_other +from jax._src.typing import Array, ArrayLike, DTypeLike from jax._src.numpy.util import ( check_arraylike, promote_args, promote_args_inexact, promote_args_numeric, promote_dtypes_inexact, promote_dtypes_numeric, promote_shapes, _where, implements, check_no_float0s) +from jax._src.numpy.ufunc_api import ufunc +from jax._src.numpy import reductions _lax_const = lax._const @@ -52,9 +56,48 @@ def _replace_inf(x: ArrayLike) -> Array: def _to_bool(x: Array) -> Array: return x if x.dtype == bool else lax.ne(x, _lax_const(x, 0)) -@implements(np.fabs, module='numpy') + @partial(jit, inline=True) def fabs(x: ArrayLike, /) -> Array: + """Compute the element-wise absolute values of the real-valued input. + + JAX implementation of :obj:`numpy.fabs`. + + Args: + x: input array or scalar. Must not have a complex dtype. + + Returns: + An array with same shape as ``x`` and dtype float, containing the element-wise + absolute values. + + See also: + - :func:`jax.numpy.absolute`: Computes the absolute values of the input including + complex dtypes. + - :func:`jax.numpy.abs`: Computes the absolute values of the input including + complex dtypes. 
+ + Examples: + For integer inputs: + + >>> x = jnp.array([-5, -9, 1, 10, 15]) + >>> jnp.fabs(x) + Array([ 5., 9., 1., 10., 15.], dtype=float32) + + For float type inputs: + + >>> x1 = jnp.array([-1.342, 5.649, 3.927]) + >>> jnp.fabs(x1) + Array([1.342, 5.649, 3.927], dtype=float32) + + For boolean inputs: + + >>> x2 = jnp.array([True, False]) + >>> jnp.fabs(x2) + Array([1., 0.], dtype=float32) + """ + check_arraylike('fabs', x) + if dtypes.issubdtype(dtypes.dtype(x), np.complexfloating): + raise TypeError("ufunc 'fabs' does not support complex dtypes") return lax.abs(*promote_args_inexact('fabs', x)) @implements(getattr(np, 'bitwise_invert', np.invert), module='numpy') @@ -72,95 +115,706 @@ def bitwise_not(x: ArrayLike, /) -> Array: def invert(x: ArrayLike, /) -> Array: return lax.bitwise_not(*promote_args('invert', x)) -@implements(np.negative, module='numpy') + @partial(jit, inline=True) -def negative(x: ArrayLike, /) -> Array: +def _negative(x: ArrayLike, /) -> Array: + """Return element-wise negative values of the input. + + JAX implementation of :obj:`numpy.negative`. + + Args: + x: input array or scalar. + + Returns: + An array with same shape and dtype as ``x`` containing ``-x``. + + See also: + - :func:`jax.numpy.positive`: Returns element-wise positive values of the input. + - :func:`jax.numpy.sign`: Returns element-wise indication of sign of the input. + + Note: + ``jnp.negative``, when applied over ``unsigned integer``, produces the result + of their two's complement negation, which typically results in unexpected + large positive values due to integer underflow. + + Examples: + For real-valued inputs: + + >>> x = jnp.array([0., -3., 7]) + >>> jnp.negative(x) + Array([-0., 3., -7.], dtype=float32) + + For complex inputs: + + >>> x1 = jnp.array([1-2j, -3+4j, 5-6j]) + >>> jnp.negative(x1) + Array([-1.+2.j, 3.-4.j, -5.+6.j], dtype=complex64) + + For unit32: + + >>> x2 = jnp.array([5, 0, -7]).astype(jnp.uint32) + >>> x2 + Array([ 5, 0, 4294967289], dtype=uint32) + >>> jnp.negative(x2) + Array([4294967291, 0, 7], dtype=uint32) + """ return lax.neg(*promote_args('negative', x)) -@implements(np.positive, module='numpy') + @partial(jit, inline=True) def positive(x: ArrayLike, /) -> Array: + """Return element-wise positive values of the input. + + JAX implementation of :obj:`numpy.positive`. + + Args: + x: input array or scalar + + Returns: + An array of same shape and dtype as ``x`` containing ``+x``. + + Note: + ``jnp.positive`` is equivalent to ``x.copy()`` and is defined only for the + types that support arithmetic operations. + + See also: + - :func:`jax.numpy.negative`: Returns element-wise negative values of the input. + - :func:`jax.numpy.sign`: Returns element-wise indication of sign of the input. + + Examples: + For real-valued inputs: + + >>> x = jnp.array([-5, 4, 7., -9.5]) + >>> jnp.positive(x) + Array([-5. , 4. , 7. , -9.5], dtype=float32) + >>> x.copy() + Array([-5. , 4. , 7. 
, -9.5], dtype=float32) + + For complex inputs: + + >>> x1 = jnp.array([1-2j, -3+4j, 5-6j]) + >>> jnp.positive(x1) + Array([ 1.-2.j, -3.+4.j, 5.-6.j], dtype=complex64) + >>> x1.copy() + Array([ 1.-2.j, -3.+4.j, 5.-6.j], dtype=complex64) + + For uint32: + + >>> x2 = jnp.array([6, 0, -4]).astype(jnp.uint32) + >>> x2 + Array([ 6, 0, 4294967292], dtype=uint32) + >>> jnp.positive(x2) + Array([ 6, 0, 4294967292], dtype=uint32) + """ return lax.asarray(*promote_args('positive', x)) -@implements(np.sign, module='numpy') + @partial(jit, inline=True) def sign(x: ArrayLike, /) -> Array: + r"""Return an element-wise indication of sign of the input. + + JAX implementation of :obj:`numpy.sign`. + + The sign of ``x`` for real-valued input is: + + .. math:: + \mathrm{sign}(x) = \begin{cases} + 1, & x > 0\\ + 0, & x = 0\\ + -1, & x < 0 + \end{cases} + + For complex valued input, ``jnp.sign`` returns a unit vector repesenting the + phase. For generalized case, the sign of ``x`` is given by: + + .. math:: + \mathrm{sign}(x) = \begin{cases} + \frac{x}{abs(x)}, & x \ne 0\\ + 0, & x = 0 + \end{cases} + + Args: + x: input array or scalar. + + Returns: + An array with same shape and dtype as ``x`` containing the sign indication. + + See also: + - :func:`jax.numpy.positive`: Returns element-wise positive values of the input. + - :func:`jax.numpy.negative`: Returns element-wise negative values of the input. + + Examples: + For Real-valued inputs: + + >>> x = jnp.array([0., -3., 7.]) + >>> jnp.sign(x) + Array([ 0., -1., 1.], dtype=float32) + + For complex-inputs: + + >>> x1 = jnp.array([1, 3+4j, 5j]) + >>> jnp.sign(x1) + Array([1. +0.j , 0.6+0.8j, 0. +1.j ], dtype=complex64) + """ return lax.sign(*promote_args('sign', x)) -@implements(np.floor, module='numpy') + @partial(jit, inline=True) def floor(x: ArrayLike, /) -> Array: + """Round input to the nearest integer downwards. + + JAX implementation of :obj:`numpy.floor`. + + Args: + x: input array or scalar. Must not have complex dtype. + + Returns: + An array with same shape and dtype as ``x`` containing the values rounded to + the nearest integer that is less than or equal to the value itself. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + zero. + - :func:`jax.numpy.ceil`: Rounds the input up to the nearest integer. + + Examples: + >>> key = jax.random.key(42) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 1.44 -1.77 -3.07] + [ 3.86 2.25 -3.08] + [-1.55 -2.48 1.32]] + >>> jnp.floor(x) + Array([[ 1., -2., -4.], + [ 3., 2., -4.], + [-2., -3., 1.]], dtype=float32) + """ check_arraylike('floor', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) return lax.floor(*promote_args_inexact('floor', x)) -@implements(np.ceil, module='numpy') + @partial(jit, inline=True) def ceil(x: ArrayLike, /) -> Array: + """Round input to the nearest integer upwards. + + JAX implementation of :obj:`numpy.ceil`. + + Args: + x: input array or scalar. Must not have complex dtype. + + Returns: + An array with same shape and dtype as ``x`` containing the values rounded to + the nearest integer that is greater than or equal to the value itself. + + See also: + - :func:`jax.numpy.fix`: Rounds the input to the nearest interger towards zero. + - :func:`jax.numpy.trunc`: Rounds the input to the nearest interger towards + zero. 
+ - :func:`jax.numpy.floor`: Rounds the input down to the nearest integer. + + Examples: + >>> key = jax.random.key(1) + >>> x = jax.random.uniform(key, (3, 3), minval=-5, maxval=5) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(x) + [[ 2.55 -1.87 -3.76] + [ 0.48 3.85 -1.94] + [ 3.2 4.56 -1.43]] + >>> jnp.ceil(x) + Array([[ 3., -1., -3.], + [ 1., 4., -1.], + [ 4., 5., -1.]], dtype=float32) + """ check_arraylike('ceil', x) if dtypes.isdtype(dtypes.dtype(x), ('integral', 'bool')): return lax.asarray(x) return lax.ceil(*promote_args_inexact('ceil', x)) -@implements(np.exp, module='numpy') + @partial(jit, inline=True) def exp(x: ArrayLike, /) -> Array: + """Calculate element-wise exponential of the input. + + JAX implementation of :obj:`numpy.exp`. + + Args: + x: input array or scalar + + Returns: + An array containing the exponential of each element in ``x``, promotes to + inexact dtype. + + See also: + - :func:`jax.numpy.log`: Calculates element-wise logarithm of the input. + - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the + input. + - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of + the input. + + Examples: + ``jnp.exp`` follows the properties of exponential such as :math:`e^{(a+b)} + = e^a * e^b`. + + >>> x1 = jnp.array([2, 4, 3, 1]) + >>> x2 = jnp.array([1, 3, 2, 3]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.exp(x1+x2)) + [ 20.09 1096.63 148.41 54.6 ] + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.exp(x1)*jnp.exp(x2)) + [ 20.09 1096.63 148.41 54.6 ] + + This property holds for complex input also: + + >>> jnp.allclose(jnp.exp(3-4j), jnp.exp(3)*jnp.exp(-4j)) + Array(True, dtype=bool) + """ return lax.exp(*promote_args_inexact('exp', x)) -@implements(np.log, module='numpy') + @partial(jit, inline=True) def log(x: ArrayLike, /) -> Array: + """Calculate element-wise natural logarithm of the input. + + JAX implementation of :obj:`numpy.log`. + + Args: + x: input array or scalar. + + Returns: + An array containing the logarithm of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. + - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. + - :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input. + + Examples: + ``jnp.log`` and ``jnp.exp`` are inverse functions of each other. Applying + ``jnp.log`` on the result of ``jnp.exp(x)`` yields the original input ``x``. + + >>> x = jnp.array([2, 3, 4, 5]) + >>> jnp.log(jnp.exp(x)) + Array([2., 3., 4., 5.], dtype=float32) + + Using ``jnp.log`` we can demonstrate well-known properties of logarithms, such + as :math:`log(a*b) = log(a)+log(b)`. + + >>> x1 = jnp.array([2, 1, 3, 1]) + >>> x2 = jnp.array([1, 3, 2, 4]) + >>> jnp.allclose(jnp.log(x1*x2), jnp.log(x1)+jnp.log(x2)) + Array(True, dtype=bool) + """ return lax.log(*promote_args_inexact('log', x)) -@implements(np.expm1, module='numpy') + @partial(jit, inline=True) def expm1(x: ArrayLike, /) -> Array: + """Calculate ``exp(x)-1`` of each element of the input. + + JAX implementation of :obj:`numpy.expm1`. + + Args: + x: input array or scalar. + + Returns: + An array containing ``exp(x)-1`` of each element in ``x``, promotes to inexact + dtype. + + Note: + ``jnp.expm1`` has much higher precision than the naive computation of + ``exp(x)-1`` for small values of ``x``. 
+ + See also: + - :func:`jax.numpy.log1p`: Calculates element-wise logarithm of one plus input. + - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. + - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of + the input. + + Examples: + >>> x = jnp.array([2, -4, 3, -1]) + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.expm1(x)) + [ 6.39 -0.98 19.09 -0.63] + >>> with jnp.printoptions(precision=2, suppress=True): + ... print(jnp.exp(x)-1) + [ 6.39 -0.98 19.09 -0.63] + + For values very close to 0, ``jnp.expm1(x)`` is much more accurate than + ``jnp.exp(x)-1``: + + >>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) + >>> jnp.expm1(x1) + Array([1.0000500e-04, 1.0000005e-06, 2.0000000e-10], dtype=float32) + >>> jnp.exp(x1)-1 + Array([1.00016594e-04, 9.53674316e-07, 0.00000000e+00], dtype=float32) + """ return lax.expm1(*promote_args_inexact('expm1', x)) -@implements(np.log1p, module='numpy') + @partial(jit, inline=True) def log1p(x: ArrayLike, /) -> Array: + """Calculates element-wise logarithm of one plus input, ``log(x+1)``. + + JAX implementation of :obj:`numpy.log1p`. + + Args: + x: input array or scalar. + + Returns: + An array containing the logarithm of one plus of each element in ``x``, + promotes to inexact dtype. + + Note: + ``jnp.log1p`` is more accurate than when using the naive computation of + ``log(x+1)`` for small values of ``x``. + + See also: + - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the + input. + - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. + - :func:`jax.numpy.log`: Calculates element-wise logarithm of the input. + + Examples: + >>> x = jnp.array([2, 5, 9, 4]) + >>> jnp.allclose(jnp.log1p(x), jnp.log(x+1)) + Array(True, dtype=bool) + + For values very close to 0, ``jnp.log1p(x)`` is more accurate than + ``jnp.log(x+1)``: + + >>> x1 = jnp.array([1e-4, 1e-6, 2e-10]) + >>> jnp.expm1(jnp.log1p(x1)) # doctest: +SKIP + Array([1.00000005e-04, 9.99999997e-07, 2.00000003e-10], dtype=float32) + >>> jnp.expm1(jnp.log(x1+1)) # doctest: +SKIP + Array([1.000166e-04, 9.536743e-07, 0.000000e+00], dtype=float32) + """ return lax.log1p(*promote_args_inexact('log1p', x)) -@implements(np.sin, module='numpy') + @partial(jit, inline=True) def sin(x: ArrayLike, /) -> Array: + """Compute a trigonometric sine of each element of input. + + JAX implementation of :obj:`numpy.sin`. + + Args: + x: array or scalar. Angle in radians. + + Returns: + An array containing the sine of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of + input. + - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of + input. + - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of + trigonometric sine of each element of input. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 3*pi/4, pi]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... print(jnp.sin(x)) + [ 0.707 1. 0.707 -0. ] + """ return lax.sin(*promote_args_inexact('sin', x)) -@implements(np.cos, module='numpy') + @partial(jit, inline=True) def cos(x: ArrayLike, /) -> Array: + """Compute a trigonometric cosine of each element of input. + + JAX implementation of :obj:`numpy.cos`. + + Args: + x: scalar or array. Angle in radians. + + Returns: + An array containing the cosine of each element in ``x``, promotes to inexact + dtype. 
+ + See also: + - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. + - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of + input. + - :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: Computes the inverse of + trigonometric cosine of each element of input. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 3*pi/4, 5*pi/6]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... print(jnp.cos(x)) + [ 0.707 -0. -0.707 -0.866] + """ return lax.cos(*promote_args_inexact('cos', x)) -@implements(np.tan, module='numpy') + @partial(jit, inline=True) def tan(x: ArrayLike, /) -> Array: + """Compute a trigonometric tangent of each element of input. + + JAX implementation of :obj:`numpy.tan`. + + Args: + x: scalar or array. Angle in radians. + + Returns: + An array containing the tangent of each element in ``x``, promotes to inexact + dtype. + + See also: + - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. + - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of + input. + - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric tangent of each element of input. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([0, pi/6, pi/4, 3*pi/4, 5*pi/6]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... print(jnp.tan(x)) + [ 0. 0.577 1. -1. -0.577] + """ return lax.tan(*promote_args_inexact('tan', x)) -@implements(np.arcsin, module='numpy') + @partial(jit, inline=True) def arcsin(x: ArrayLike, /) -> Array: + r"""Compute element-wise inverse of trigonometric sine of input. + + JAX implementation of :obj:`numpy.arcsin`. + + Args: + x: input array or scalar. + + Returns: + An array containing the inverse trigonometric sine of each element of ``x`` + in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype. + + Note: + - ``jnp.arcsin`` returns ``nan`` when ``x`` is real-valued and not in the closed + interval ``[-1, 1]``. + - ``jnp.arcsin`` follows the branch cut convention of :func:`numpy.arcsin` for + complex inputs. + + See also: + - :func:`jax.numpy.sin`: Computes a trigonometric sine of each element of input. + - :func:`jax.numpy.arccos` and :func:`jax.numpy.acos`: Computes the inverse of + trigonometric cosine of each element of input. + - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric tangent of each element of input. + + Examples: + >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsin(x) + Array([ nan, -1.571, -0.524, 0. , 0.524, 1.571, nan], dtype=float32) + + For complex-valued inputs: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arcsin(3+4j) + Array(0.634+2.306j, dtype=complex64, weak_type=True) + """ return lax.asin(*promote_args_inexact('arcsin', x)) -@implements(np.arccos, module='numpy') + @partial(jit, inline=True) def arccos(x: ArrayLike, /) -> Array: + """Compute element-wise inverse of trigonometric cosine of input. + + JAX implementation of :obj:`numpy.arccos`. + + Args: + x: input array or scalar. + + Returns: + An array containing the inverse trigonometric cosine of each element of ``x`` + in radians in the range ``[0, pi]``, promoting to inexact dtype. + + Note: + - ``jnp.arccos`` returns ``nan`` when ``x`` is real-valued and not in the closed + interval ``[-1, 1]``. + - ``jnp.arccos`` follows the branch cut convention of :func:`numpy.arccos` for + complex inputs. 
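# Illustrative sketch (not part of the patch): arcsin and arccos, documented above,
# are complementary on the closed interval [-1, 1], i.e. arcsin(x) + arccos(x) == pi/2.
import jax.numpy as jnp

x = jnp.linspace(-1.0, 1.0, 5)
print(jnp.allclose(jnp.arcsin(x) + jnp.arccos(x), jnp.pi / 2))  # True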
+ + See also: + - :func:`jax.numpy.cos`: Computes a trigonometric cosine of each element of + input. + - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of + trigonometric sine of each element of input. + - :func:`jax.numpy.arctan` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric tangent of each element of input. + + Examples: + >>> x = jnp.array([-2, -1, -0.5, 0, 0.5, 1, 2]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccos(x) + Array([ nan, 3.142, 2.094, 1.571, 1.047, 0. , nan], dtype=float32) + + For complex inputs: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arccos(4-1j) + Array(0.252+2.097j, dtype=complex64, weak_type=True) + """ return lax.acos(*promote_args_inexact('arccos', x)) -@implements(np.arctan, module='numpy') + @partial(jit, inline=True) def arctan(x: ArrayLike, /) -> Array: + """Compute element-wise inverse of trigonometric tangent of input. + + JAX implement of :obj:`numpy.arctan`. + + Args: + x: input array or scalar. + + Returns: + An array containing the inverse trigonometric tangent of each element ``x`` + in radians in the range ``[-pi/2, pi/2]``, promoting to inexact dtype. + + Note: + ``jnp.arctan`` follows the branch cut convention of :func:`numpy.arctan` for + complex inputs. + + See also: + - :func:`jax.numpy.tan`: Computes a trigonometric tangent of each element of + input. + - :func:`jax.numpy.arcsin` and :func:`jax.numpy.asin`: Computes the inverse of + trigonometric sine of each element of input. + - :func:`jax.numpy.arccos` and :func:`jax.numpy.atan`: Computes the inverse of + trigonometric cosine of each element of input. + + Examples: + >>> x = jnp.array([-jnp.inf, -20, -1, 0, 1, 20, jnp.inf]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctan(x) + Array([-1.571, -1.521, -0.785, 0. , 0.785, 1.521, 1.571], dtype=float32) + + For complex-valued inputs: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.arctan(2+7j) + Array(1.532+0.133j, dtype=complex64, weak_type=True) + """ return lax.atan(*promote_args_inexact('arctan', x)) -@implements(np.sinh, module='numpy') + @partial(jit, inline=True) def sinh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic sine of input. + + JAX implementation of :obj:`numpy.sinh`. + + The hyperbolic sine is defined by: + + .. math:: + + sinh(x) = \frac{e^x - e^{-x}}{2} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic sine of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.sinh`` is equivalent to computing ``-1j * jnp.sin(1j * x)``. + + See also: + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arcsinh`: Computes the element-wise inverse of hyperbolic + sine of the input. + + Examples: + >>> x = jnp.array([[-2, 3, 5], + ... [0, -1, 4]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sinh(x) + Array([[-3.627, 10.018, 74.203], + [ 0. , -1.175, 27.29 ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.sin(1j * x) + Array([[-3.627+0.j, 10.018-0.j, 74.203-0.j], + [ 0. -0.j, -1.175+0.j, 27.29 -0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... 
jnp.sinh(3-2j) + Array(-4.169-9.154j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.sin(1j * (3-2j)) + Array(-4.169-9.154j, dtype=complex64, weak_type=True) + """ return lax.sinh(*promote_args_inexact('sinh', x)) -@implements(np.cosh, module='numpy') + @partial(jit, inline=True) def cosh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic cosine of input. + + JAX implementation of :obj:`numpy.cosh`. + + The hyperbolic cosine is defined by: + + .. math:: + + cosh(x) = \frac{e^x + e^{-x}}{2} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic cosine of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.cosh`` is equivalent to computing ``jnp.cos(1j * x)``. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.tanh`: Computes the element-wise hyperbolic tangent of the + input. + - :func:`jax.numpy.arccosh`: Computes the element-wise inverse of hyperbolic + cosine of the input. + + Examples: + >>> x = jnp.array([[3, -1, 0], + ... [4, 7, -5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cosh(x) + Array([[ 10.068, 1.543, 1. ], + [ 27.308, 548.317, 74.21 ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cos(1j * x) + Array([[ 10.068+0.j, 1.543+0.j, 1. +0.j], + [ 27.308+0.j, 548.317+0.j, 74.21 +0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cosh(5+1j) + Array(40.096+62.44j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.cos(1j * (5+1j)) + Array(40.096+62.44j, dtype=complex64, weak_type=True) + """ return lax.cosh(*promote_args_inexact('cosh', x)) @implements(np.arcsinh, module='numpy') @@ -178,9 +832,57 @@ def arccosh(x: ArrayLike, /) -> Array: result = _where(real(result) < 0, lax.neg(result), result) return result -@implements(np.tanh, module='numpy') + @partial(jit, inline=True) def tanh(x: ArrayLike, /) -> Array: + r"""Calculate element-wise hyperbolic tangent of input. + + JAX implementation of :obj:`numpy.tanh`. + + The hyperbolic tangent is defined by: + + .. math:: + + tanh(x) = \frac{sinh(x)}{cosh(x)} = \frac{e^x - e^{-x}}{e^x + e^{-x}} + + Args: + x: input array or scalar. + + Returns: + An array containing the hyperbolic tangent of each element of ``x``, promoting + to inexact dtype. + + Note: + ``jnp.tanh`` is equivalent to computing ``-1j * jnp.tan(1j * x)``. + + See also: + - :func:`jax.numpy.sinh`: Computes the element-wise hyperbolic sine of the input. + - :func:`jax.numpy.cosh`: Computes the element-wise hyperbolic cosine of the + input. + - :func:`jax.numpy.arctanh`: Computes the element-wise inverse of hyperbolic + tangent of the input. + + Examples: + >>> x = jnp.array([[-1, 0, 1], + ... [3, -2, 5]]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.tanh(x) + Array([[-0.762, 0. , 0.762], + [ 0.995, -0.964, 1. ]], dtype=float32) + >>> with jnp.printoptions(precision=3, suppress=True): + ... -1j * jnp.tan(1j * x) + Array([[-0.762+0.j, 0. -0.j, 0.762-0.j], + [ 0.995-0.j, -0.964+0.j, 1. -0.j]], dtype=complex64, weak_type=True) + + For complex-valued input: + + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.tanh(2-5j) + Array(1.031+0.021j, dtype=complex64, weak_type=True) + >>> with jnp.printoptions(precision=3, suppress=True): + ... 
-1j * jnp.tan(1j * (2-5j)) + Array(1.031+0.021j, dtype=complex64, weak_type=True) + """ return lax.tanh(*promote_args_inexact('tanh', x)) @implements(np.arctanh, module='numpy') @@ -188,9 +890,36 @@ def tanh(x: ArrayLike, /) -> Array: def arctanh(x: ArrayLike, /) -> Array: return lax.atanh(*promote_args_inexact('arctanh', x)) -@implements(np.sqrt, module='numpy') + @partial(jit, inline=True) def sqrt(x: ArrayLike, /) -> Array: + """Calculates element-wise non-negative square root of the input array. + + JAX implementation of :obj:`numpy.sqrt`. + + Args: + x: input array or scalar. + + Returns: + An array containing the non-negative square root of the elements of ``x``. + + Note: + - For real-valued negative inputs, ``jnp.sqrt`` produces a ``nan`` output. + - For complex-valued negative inputs, ``jnp.sqrt`` produces a ``complex`` output. + + See also: + - :func:`jax.numpy.square`: Calculates the element-wise square of the input. + - :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential + of ``x2``. + + Examples: + >>> x = jnp.array([-8-6j, 1j, 4]) + >>> with jnp.printoptions(precision=3, suppress=True): + ... jnp.sqrt(x) + Array([1. -3.j , 0.707+0.707j, 2. +0.j ], dtype=complex64) + >>> jnp.sqrt(-1) + Array(nan, dtype=float32, weak_type=True) + """ return lax.sqrt(*promote_args_inexact('sqrt', x)) @implements(np.cbrt, module='numpy') @@ -198,31 +927,151 @@ def sqrt(x: ArrayLike, /) -> Array: def cbrt(x: ArrayLike, /) -> Array: return lax.cbrt(*promote_args_inexact('cbrt', x)) -@implements(np.add, module='numpy') @partial(jit, inline=True) -def add(x: ArrayLike, y: ArrayLike, /) -> Array: +def _add(x: ArrayLike, y: ArrayLike, /) -> Array: + """Add two arrays element-wise. + + JAX implementation of :obj:`numpy.add`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``+`` operator for + JAX arrays. + + Args: + x, y: arrays to add. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise addition. + + Examples: + Calling ``add`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.add(x, 10) + Array([10, 11, 12, 13], dtype=int32) + + Calling ``add`` via the ``+`` operator: + + >>> x + 10 + Array([10, 11, 12, 13], dtype=int32) + """ x, y = promote_args("add", x, y) return lax.add(x, y) if x.dtype != bool else lax.bitwise_or(x, y) -@implements(np.multiply, module='numpy') @partial(jit, inline=True) -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: +def _multiply(x: ArrayLike, y: ArrayLike, /) -> Array: + """Multiply two arrays element-wise. + + JAX implementation of :obj:`numpy.multiply`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``*`` operator for + JAX arrays. + + Args: + x, y: arrays to multiply. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise multiplication. 
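# Illustrative sketch (not part of the patch): per the boolean branches in the
# _add/_multiply bodies above, add lowers to bitwise_or and multiply to
# bitwise_and when both inputs are boolean.
import jax.numpy as jnp

a = jnp.array([True, True, False, False])
b = jnp.array([True, False, True, False])
print(jnp.add(a, b))       # [ True  True  True False]
print(jnp.multiply(a, b))  # [ True False False False]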
+ + Examples: + Calling ``multiply`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.multiply(x, 10) + Array([ 0, 10, 20, 30], dtype=int32) + + Calling ``multiply`` via the ``*`` operator: + + >>> x * 10 + Array([ 0, 10, 20, 30], dtype=int32) + """ x, y = promote_args("multiply", x, y) return lax.mul(x, y) if x.dtype != bool else lax.bitwise_and(x, y) -@implements(np.bitwise_and, module='numpy') @partial(jit, inline=True) -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise AND operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_and`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``&`` operator for + JAX arrays. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise AND. + + Examples: + Calling ``bitwise_and`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_and(x, 1) + Array([0, 1, 0, 1], dtype=int32) + + Calling ``bitwise_and`` via the ``&`` operator: + + >>> x & 1 + Array([0, 1, 0, 1], dtype=int32) + """ return lax.bitwise_and(*promote_args("bitwise_and", x, y)) -@implements(np.bitwise_or, module='numpy') @partial(jit, inline=True) -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise OR operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_or`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``|`` operator for + JAX arrays. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise OR. + + Examples: + Calling ``bitwise_or`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_or(x, 1) + Array([1, 1, 3, 3], dtype=int32) + + Calling ``bitwise_or`` via the ``|`` operator: + + >>> x | 1 + Array([1, 1, 3, 3], dtype=int32) + """ return lax.bitwise_or(*promote_args("bitwise_or", x, y)) -@implements(np.bitwise_xor, module='numpy') @partial(jit, inline=True) -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: +def _bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the bitwise XOR operation elementwise. + + JAX implementation of :obj:`numpy.bitwise_xor`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + This function provides the implementation of the ``^`` operator for + JAX arrays. + + Args: + x, y: integer or boolean arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise bitwise XOR. 
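# Illustrative sketch (not part of the patch): the bitwise ops documented above
# satisfy the usual identity x ^ y == (x | y) & ~(x & y).
import jax.numpy as jnp

a = jnp.array([0b1100, 0b1010])
b = jnp.array([0b1010, 0b0110])
lhs = jnp.bitwise_xor(a, b)
rhs = jnp.bitwise_and(jnp.bitwise_or(a, b), jnp.bitwise_not(jnp.bitwise_and(a, b)))
print(jnp.array_equal(lhs, rhs))  # True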
+ + Examples: + Calling ``bitwise_xor`` explicitly: + + >>> x = jnp.arange(4) + >>> jnp.bitwise_xor(x, 1) + Array([1, 0, 3, 2], dtype=int32) + + Calling ``bitwise_xor`` via the ``^`` operator: + + >>> x ^ 1 + Array([1, 0, 3, 2], dtype=int32) + """ return lax.bitwise_xor(*promote_args("bitwise_xor", x, y)) @implements(np.left_shift, module='numpy') @@ -245,84 +1094,463 @@ def equal(x: ArrayLike, y: ArrayLike, /) -> Array: def not_equal(x: ArrayLike, y: ArrayLike, /) -> Array: return lax.ne(*promote_args("not_equal", x, y)) -@implements(np.subtract, module='numpy') -@partial(jit, inline=True) -def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.sub(*promote_args("subtract", x, y)) +@implements(np.subtract, module='numpy') +@partial(jit, inline=True) +def subtract(x: ArrayLike, y: ArrayLike, /) -> Array: + return lax.sub(*promote_args("subtract", x, y)) + +@implements(np.arctan2, module='numpy') +@partial(jit, inline=True) +def arctan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + return lax.atan2(*promote_args_inexact("arctan2", x1, x2)) + + +@partial(jit, inline=True) +def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise minimum of the input arrays. + + JAX implementation of :obj:`numpy.minimum`. + + Args: + x: input array or scalar. + y: input array or scalar. Both ``x`` and ``y`` should either have same shape + or be broadcast compatible. + + Returns: + An array containing the element-wise minimum of ``x`` and ``y``. + + Note: + For each pair of elements, ``jnp.minimum`` returns: + - smaller of the two if both elements are finite numbers. + - ``nan`` if one element is ``nan``. + + See also: + - :func:`jax.numpy.maximum`: Returns element-wise maximum of the input arrays. + - :func:`jax.numpy.fmin`: Returns element-wise minimum of the input arrays, + ignoring NaNs. + - :func:`jax.numpy.amin`: Returns the minimum of array elements along a given + axis. + - :func:`jax.numpy.nanmin`: Returns the minimum of the array elements along + a given axis, ignoring NaNs. + + Examples: + Inputs with ``x.shape == y.shape``: + + >>> x = jnp.array([2, 3, 5, 1]) + >>> y = jnp.array([-3, 6, -4, 7]) + >>> jnp.minimum(x, y) + Array([-3, 3, -4, 1], dtype=int32) + + Inputs having broadcast compatibility: + + >>> x1 = jnp.array([[1, 5, 2], + ... [-3, 4, 7]]) + >>> y1 = jnp.array([-2, 3, 6]) + >>> jnp.minimum(x1, y1) + Array([[-2, 3, 2], + [-3, 3, 6]], dtype=int32) + + Inputs with ``nan``: + + >>> nan = jnp.nan + >>> x2 = jnp.array([[2.5, nan, -2], + ... [nan, 5, 6], + ... [-4, 3, 7]]) + >>> y2 = jnp.array([1, nan, 5]) + >>> jnp.minimum(x2, y2) + Array([[ 1., nan, -2.], + [nan, nan, 5.], + [-4., nan, 5.]], dtype=float32) + """ + return lax.min(*promote_args("minimum", x, y)) + + +@partial(jit, inline=True) +def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise maximum of the input arrays. + + JAX implementation of :obj:`numpy.maximum`. + + Args: + x: input array or scalar. + y: input array or scalar. Both ``x`` and ``y`` should either have same shape + or be broadcast compatible. + + Returns: + An array containing the element-wise maximum of ``x`` and ``y``. + + Note: + For each pair of elements, ``jnp.maximum`` returns: + - larger of the two if both elements are finite numbers. + - ``nan`` if one element is ``nan``. + + See also: + - :func:`jax.numpy.minimum`: Returns element-wise minimum of the input + arrays. + - :func:`jax.numpy.fmax`: Returns element-wise maximum of the input arrays, + ignoring NaNs. 
+ - :func:`jax.numpy.amax`: Retruns the maximum of array elements along a given + axis. + - :func:`jax.numpy.nanmax`: Returns the maximum of the array elements along + a given axis, ignoring NaNs. + + Examples: + Inputs with ``x.shape == y.shape``: + + >>> x = jnp.array([1, -5, 3, 2]) + >>> y = jnp.array([-2, 4, 7, -6]) + >>> jnp.maximum(x, y) + Array([1, 4, 7, 2], dtype=int32) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[-2, 5, 7, 4], + ... [1, -6, 3, 8]]) + >>> y1 = jnp.array([-5, 3, 6, 9]) + >>> jnp.maximum(x1, y1) + Array([[-2, 5, 7, 9], + [ 1, 3, 6, 9]], dtype=int32) + + Inputs having ``nan``: + + >>> nan = jnp.nan + >>> x2 = jnp.array([nan, -3, 9]) + >>> y2 = jnp.array([[4, -2, nan], + ... [-3, -5, 10]]) + >>> jnp.maximum(x2, y2) + Array([[nan, -2., nan], + [nan, -3., 10.]], dtype=float32) + """ + return lax.max(*promote_args("maximum", x, y)) + + +@partial(jit, inline=True) +def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: + """Calculate element-wise base ``x`` exponential of ``y``. + + JAX implementation of :obj:`numpy.float_power`. + + Args: + x: scalar or array. Specifies the bases. + y: scalar or array. Specifies the exponents. ``x`` and ``y`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the base ``x`` exponentials of ``y``, promoting to the + inexact dtype. + + See also: + - :func:`jax.numpy.exp`: Calculates element-wise exponential of the input. + - :func:`jax.numpy.exp2`: Calculates base-2 exponential of each element of + the input. + + Examples: + Inputs with same shape: + + >>> x = jnp.array([3, 1, -5]) + >>> y = jnp.array([2, 4, -1]) + >>> jnp.float_power(x, y) + Array([ 9. , 1. , -0.2], dtype=float32) + + Inputs with broacast compatibility: + + >>> x1 = jnp.array([[2, -4, 1], + ... [-1, 2, 3]]) + >>> y1 = jnp.array([-2, 1, 4]) + >>> jnp.float_power(x1, y1) + Array([[ 0.25, -4. , 1. ], + [ 1. , 2. , 81. ]], dtype=float32) + + ``jnp.float_power`` produces ``nan`` for negative values raised to a non-integer + values. + + >>> jnp.float_power(-3, 1.7) + Array(nan, dtype=float32, weak_type=True) + """ + return lax.pow(*promote_args_inexact("float_power", x, y)) + + +@partial(jit, inline=True) +def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise next floating point value after ``x`` towards ``y``. + + JAX implementation of :obj:`numpy.nextafter`. + + Args: + x: scalar or array. Specifies the value after which the next number is found. + y: scalar or array. Specifies the direction towards which the next number is + found. ``x`` and ``y`` should either have same shape or be broadcast + compatible. + + Returns: + An array containing the next representable number of ``x`` in the direction + of ``y``. + + Examples: + >>> jnp.nextafter(2, 1) # doctest: +SKIP + Array(1.9999999, dtype=float32, weak_type=True) + >>> x = jnp.array([3, -2, 1]) + >>> y = jnp.array([2, -1, 2]) + >>> jnp.nextafter(x, y) # doctest: +SKIP + Array([ 2.9999998, -1.9999999, 1.0000001], dtype=float32) + """ + return lax.nextafter(*promote_args_inexact("nextafter", x, y)) + +# Logical ops +@partial(jit, inline=True) +def _logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical AND operation elementwise. + + JAX implementation of :obj:`numpy.logical_and`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. 
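# Illustrative sketch (not part of the patch): nextafter, documented above, relates
# directly to the machine epsilon -- the next float32 after 1.0 towards 2.0 is 1.0 + eps.
import jax.numpy as jnp

step = jnp.nextafter(jnp.float32(1), jnp.float32(2)) - 1
print(step == jnp.finfo(jnp.float32).eps)  # True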
+ + Returns: + Array containing the result of the element-wise logical AND. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_and(x, 1) + Array([False, True, True, True], dtype=bool) + """ + return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) + +@partial(jit, inline=True) +def _logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical OR operation elementwise. + + JAX implementation of :obj:`numpy.logical_or`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical OR. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_or(x, 1) + Array([ True, True, True, True], dtype=bool) + """ + return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) + +@partial(jit, inline=True) +def _logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: + """Compute the logical XOR operation elementwise. + + JAX implementation of :obj:`numpy.logical_xor`. This is a universal function, + and supports the additional APIs described at :class:`jax.numpy.ufunc`. + + Args: + x, y: input arrays. Must be broadcastable to a common shape. + + Returns: + Array containing the result of the element-wise logical XOR. + + Examples: + >>> x = jnp.arange(4) + >>> jnp.logical_xor(x, 1) + Array([ True, False, False, False], dtype=bool) + """ + return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) + +@implements(np.logical_not, module='numpy') +@partial(jit, inline=True) +def logical_not(x: ArrayLike, /) -> Array: + return lax.bitwise_not(*map(_to_bool, promote_args("logical_not", x))) + +# Comparison ops +def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], + x: Array, y: Array): + if dtypes.issubdtype(x.dtype, np.complexfloating): + return lax.select(lax.eq(x.real, y.real), + lax_op(x.imag, y.imag), + lax_op(x.real, y.real)) + return lax_op(x, y) + +@partial(jit, inline=True) +def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x >= y``. + + JAX implementation of :obj:`numpy.greater_equal`. + + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must either have same shape or be + broadcast compatible. + + Returns: + An array containing boolean values. ``True`` if the elements of ``x >= y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. + - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. + - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. + + Examples: + Scalar inputs: + + >>> jnp.greater_equal(4, 7) + Array(False, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([2, 5, -1]) + >>> y = jnp.array([-6, 4, 3]) + >>> jnp.greater_equal(x, y) + Array([ True, True, False], dtype=bool) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[3, -1, 4], + ... [5, 9, -6]]) + >>> y1 = jnp.array([-1, 4, 2]) + >>> jnp.greater_equal(x1, y1) + Array([[ True, False, True], + [ True, True, False]], dtype=bool) + """ + return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) + + +@partial(jit, inline=True) +def greater(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x > y``. + + JAX implementation of :obj:`numpy.greater`. + + Args: + x: input array or scalar. 
+ y: input array or scalar. ``x`` and ``y`` must either have same shape or be + broadcast compatible. + + Returns: + An array containing boolean values. ``True`` if the elements of ``x > y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. + - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of + ``x >= y``. + - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. + + Examples: + Scalar inputs: + + >>> jnp.greater(5, 2) + Array(True, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([5, 9, -2]) + >>> y = jnp.array([4, -1, 6]) + >>> jnp.greater(x, y) + Array([ True, True, False], dtype=bool) + + Inputs with broadcast compatibility: + + >>> x1 = jnp.array([[5, -6, 7], + ... [-2, 5, 9]]) + >>> y1 = jnp.array([-4, 3, 10]) + >>> jnp.greater(x1, y1) + Array([[ True, False, False], + [ True, True, False]], dtype=bool) + """ + return _complex_comparison(lax.gt, *promote_args("greater", x, y)) + + +@partial(jit, inline=True) +def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x <= y``. + + JAX implementation of :obj:`numpy.less_equal`. + + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must have either same shape or be + broadcast compatible. + + Returns: + An array containing the boolean values. ``True`` if the elements of ``x <= y``, + and ``False`` otherwise. + + See also: + - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of + ``x >= y``. + - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. + - :func:`jax.numpy.less`: Returns element-wise truth value of ``x < y``. + + Examples: + Scalar inputs: + + >>> jnp.less_equal(6, -2) + Array(False, dtype=bool, weak_type=True) + + Inputs with same shape: + + >>> x = jnp.array([-4, 1, 7]) + >>> y = jnp.array([2, -3, 8]) + >>> jnp.less_equal(x, y) + Array([ True, False, True], dtype=bool) -@implements(np.arctan2, module='numpy') -@partial(jit, inline=True) -def arctan2(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.atan2(*promote_args_inexact("arctan2", x, y)) + Inputs with broadcast compatibility: -@implements(np.minimum, module='numpy') -@partial(jit, inline=True) -def minimum(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.min(*promote_args("minimum", x, y)) + >>> x1 = jnp.array([2, -5, 9]) + >>> y1 = jnp.array([[1, -6, 5], + ... [-2, 4, -6]]) + >>> jnp.less_equal(x1, y1) + Array([[False, False, False], + [False, True, False]], dtype=bool) + """ + return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) -@implements(np.maximum, module='numpy') -@partial(jit, inline=True) -def maximum(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.max(*promote_args("maximum", x, y)) -@implements(np.float_power, module='numpy') @partial(jit, inline=True) -def float_power(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.pow(*promote_args_inexact("float_power", x, y)) +def less(x: ArrayLike, y: ArrayLike, /) -> Array: + """Return element-wise truth value of ``x < y``. -@implements(np.nextafter, module='numpy') -@partial(jit, inline=True) -def nextafter(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.nextafter(*promote_args_inexact("nextafter", x, y)) + JAX implementation of :obj:`numpy.less`. 
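# Illustrative sketch (not part of the patch): the _complex_comparison helper above
# orders complex values lexicographically -- real parts first, imaginary parts on ties.
import jax.numpy as jnp

a = jnp.array([1 + 2j, 1 + 2j, 3 + 0j])
b = jnp.array([1 + 3j, 0 + 9j, 3 + 0j])
print(jnp.greater(a, b))        # [False  True False]
print(jnp.greater_equal(a, b))  # [False  True  True]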
-# Logical ops -@implements(np.logical_and, module='numpy') -@partial(jit, inline=True) -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.bitwise_and(*map(_to_bool, promote_args("logical_and", x, y))) + Args: + x: input array or scalar. + y: input array or scalar. ``x`` and ``y`` must either have same shape or be + broadcast compatible. -@implements(np.logical_or, module='numpy') -@partial(jit, inline=True) -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.bitwise_or(*map(_to_bool, promote_args("logical_or", x, y))) + Returns: + An array containing boolean values. ``True`` if the elements of ``x < y``, + and ``False`` otherwise. -@implements(np.logical_xor, module='numpy') -@partial(jit, inline=True) -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: - return lax.bitwise_xor(*map(_to_bool, promote_args("logical_xor", x, y))) + See also: + - :func:`jax.numpy.greater`: Returns element-wise truth value of ``x > y``. + - :func:`jax.numpy.greater_equal`: Returns element-wise truth value of + ``x >= y``. + - :func:`jax.numpy.less_equal`: Returns element-wise truth value of ``x <= y``. -@implements(np.logical_not, module='numpy') -@partial(jit, inline=True) -def logical_not(x: ArrayLike, /) -> Array: - return lax.bitwise_not(*map(_to_bool, promote_args("logical_not", x))) + Examples: + Scalar inputs: -# Comparison ops -def _complex_comparison(lax_op: Callable[[ArrayLike, ArrayLike], Array], - x: Array, y: Array): - if dtypes.issubdtype(x.dtype, np.complexfloating): - return lax.select(lax.eq(x.real, y.real), - lax_op(x.imag, y.imag), - lax_op(x.real, y.real)) - return lax_op(x, y) + >>> jnp.less(3, 7) + Array(True, dtype=bool, weak_type=True) -@implements(np.greater_equal, module='numpy') -@partial(jit, inline=True) -def greater_equal(x: ArrayLike, y: ArrayLike, /) -> Array: - return _complex_comparison(lax.ge, *promote_args("greater_equal", x, y)) + Inputs with same shape: -@implements(np.greater, module='numpy') -@partial(jit, inline=True) -def greater(x: ArrayLike, y: ArrayLike, /) -> Array: - return _complex_comparison(lax.gt, *promote_args("greater", x, y)) + >>> x = jnp.array([5, 9, -3]) + >>> y = jnp.array([1, 6, 4]) + >>> jnp.less(x, y) + Array([False, False, True], dtype=bool) -@implements(np.less_equal, module='numpy') -@partial(jit, inline=True) -def less_equal(x: ArrayLike, y: ArrayLike, /) -> Array: - return _complex_comparison(lax.le, *promote_args("less_equal", x, y)) + Inputs with broadcast compatibility: -@implements(np.less, module='numpy') -@partial(jit, inline=True) -def less(x: ArrayLike, y: ArrayLike, /) -> Array: + >>> x1 = jnp.array([[2, -4, 6, -8], + ... [-1, 5, -3, 7]]) + >>> y1 = jnp.array([0, 3, -5, 9]) + >>> jnp.less(x1, y1) + Array([[False, True, False, True], + [ True, False, False, True]], dtype=bool) + """ return _complex_comparison(lax.lt, *promote_args("less", x, y)) # Array API aliases @@ -357,16 +1585,16 @@ def atanh(x: ArrayLike, /) -> Array: return arctanh(*promote_args('atanh', x)) @partial(jit, inline=True) -def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: +def atan2(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Alias of :func:`jax.numpy.arctan2`""" - return arctan2(*promote_args('atan2', x, y)) + return arctan2(*promote_args('atan2', x1, x2)) @jit def bitwise_count(x: ArrayLike, /) -> Array: r"""Counts the number of 1 bits in the binary representation of the absolute value of each element of ``x``. - LAX-backend implementation of :func:`numpy.bitwise_count`. 
+ JAX implementation of :obj:`numpy.bitwise_count`. Args: x: Input array, only accepts integer subtypes @@ -400,7 +1628,7 @@ def bitwise_count(x: ArrayLike, /) -> Array: def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: r"""Right shift the bits of ``x1`` to the amount specified in ``x2``. - LAX-backend implementation of :func:`numpy.right_shift`. + JAX implementation of :obj:`numpy.right_shift`. Args: x1: Input array, only accepts unsigned integer subtypes @@ -446,20 +1674,18 @@ def right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic return lax_fn(x1, x2) -@implements(getattr(np, "bitwise_right_shift", np.right_shift), module='numpy') + @partial(jit, inline=True) def bitwise_right_shift(x1: ArrayLike, x2: ArrayLike, /) -> Array: - x1, x2 = promote_args_numeric("bitwise_right_shift", x1, x2) - lax_fn = lax.shift_right_logical if \ - np.issubdtype(x1.dtype, np.unsignedinteger) else lax.shift_right_arithmetic - return lax_fn(x1, x2) + """Alias of :func:`jax.numpy.right_shift`.""" + return right_shift(x1, x2) @partial(jit, inline=True) def absolute(x: ArrayLike, /) -> Array: r"""Calculate the absolute value element-wise. - LAX-backend implementation of :func:`numpy.absolute`. + JAX implementation of :obj:`numpy.absolute`. This is the same function as :func:`jax.numpy.abs`. @@ -500,7 +1726,7 @@ def abs(x: ArrayLike, /) -> Array: def rint(x: ArrayLike, /) -> Array: """Rounds the elements of x to the nearest integer - LAX-backend implementation of :func:`numpy.rint`. + JAX implementation of :obj:`numpy.rint`. Args: x: Input array @@ -513,7 +1739,7 @@ def rint(x: ArrayLike, /) -> Array: If an element of x is exactly half way, e.g. ``0.5`` or ``1.5``, rint will round to the nearest even integer. - Example: + Examples: >>> x1 = jnp.array([5, 4, 7]) >>> jnp.rint(x1) Array([5., 4., 7.], dtype=float32) @@ -539,7 +1765,7 @@ def rint(x: ArrayLike, /) -> Array: def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Copies the sign of each element in ``x2`` to the corresponding element in ``x1``. - LAX-backend implementation of :func:`numpy.copysign`. + JAX implementation of :obj:`numpy.copysign`. Args: x1: Input array @@ -574,20 +1800,54 @@ def copysign(x1: ArrayLike, x2: ArrayLike, /) -> Array: return _where(signbit(x2).astype(bool), -lax.abs(x1), lax.abs(x1)) -@implements(np.true_divide, module='numpy') @partial(jit, inline=True) def true_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculates the division of x1 by x2 element-wise + + JAX implementation of :func:`numpy.true_divide`. + + Args: + x1: Input array, the dividend + x2: Input array, the divisor + + Returns: + An array containing the elementwise quotients, will always use + floating point division. + + Examples: + >>> x1 = jnp.array([3, 4, 5]) + >>> x2 = 2 + >>> jnp.true_divide(x1, x2) + Array([1.5, 2. 
, 2.5], dtype=float32) + + >>> x1 = 24 + >>> x2 = jnp.array([3, 4, 6j]) + >>> jnp.true_divide(x1, x2) + Array([8.+0.j, 6.+0.j, 0.-4.j], dtype=complex64) + + >>> x1 = jnp.array([1j, 9+5j, -4+2j]) + >>> x2 = 3j + >>> jnp.true_divide(x1, x2) + Array([0.33333334+0.j , 1.6666666 -3.j , + 0.6666667 +1.3333334j], dtype=complex64) + + See Also: + :func:`jax.numpy.floor_divide` for integer division + """ x1, x2 = promote_args_inexact("true_divide", x1, x2) return lax.div(x1, x2) -divide = true_divide + +def divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.true_divide`.""" + return true_divide(x1, x2) @jit def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: """Calculates the floor division of x1 by x2 element-wise - LAX-backend implementation of :func:`numpy.floor_divide`. + JAX implementation of :obj:`numpy.floor_divide`. Args: x1: Input array, the dividend @@ -598,6 +1858,14 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: to the nearest integer towards negative infinity. This is equivalent to ``x1 // x2`` in Python. + Note: + ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` + and ``x2`` + + See Also: + :func:`jax.numpy.divide` and :func:`jax.numpy.true_divide` for floating point + division. + Examples: >>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) @@ -613,12 +1881,6 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: >>> x2 = jnp.array([2.0, 2.5, 3.0], dtype=jnp.float32) >>> jnp.floor_divide(x1, x2) Array([3., 2., 2.], dtype=float32) - - Note: - ``x1 // x2`` is equivalent to ``jnp.floor_divide(x1, x2)`` for arrays ``x1`` and ``x2`` - - See Also: - :func:`jnp.divide` and :func:`jnp.true_divide` for floating point division """ x1, x2 = promote_args_numeric("floor_divide", x1, x2) dtype = dtypes.dtype(x1) @@ -639,7 +1901,7 @@ def floor_divide(x1: ArrayLike, x2: ArrayLike, /) -> Array: def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: """Calculates the integer quotient and remainder of x1 by x2 element-wise - LAX-backend implementation of :func:`numpy.divmod`. + JAX implementation of :obj:`numpy.divmod`. Args: x1: Input array, the dividend @@ -648,6 +1910,10 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: Returns: A tuple of arrays ``(x1 // x2, x1 % x2)``. + See Also: + - :func:`jax.numpy.floor_divide`: floor division function + - :func:`jax.numpy.remainder`: remainder function + Examples: >>> x1 = jnp.array([10, 20, 30]) >>> x2 = jnp.array([3, 4, 7]) @@ -665,10 +1931,6 @@ def divmod(x1: ArrayLike, x2: ArrayLike, /) -> tuple[Array, Array]: >>> jnp.divmod(x1, x2) (Array([3., 2., 1.], dtype=float32), Array([0.30000007, 1. , 2.9 ], dtype=float32)) - - See Also: - - :func:`jax.numpy.floor_divide`: floor division function - - :func:`jax.numpy.remainder`: remainder function """ x1, x2 = promote_args_numeric("divmod", x1, x2) if dtypes.issubdtype(dtypes.dtype(x1), np.integer): @@ -689,8 +1951,61 @@ def _float_divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[Array, Array]: return lax.round(div), mod -@implements(np.power, module='numpy') def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculate element-wise base ``x1`` exponential of ``x2``. + + JAX implementation of :obj:`numpy.power`. + + Args: + x1: scalar or array. Specifies the bases. + x2: scalar or array. Specifies the exponent. ``x1`` and ``x2`` should either + have same shape or be broadcast compatible. 
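# Illustrative sketch (not part of the patch): floor_divide, remainder and divmod,
# documented above, stay mutually consistent, and x1 == x2 * (x1 // x2) + (x1 % x2).
import jax.numpy as jnp

x1 = jnp.array([10, -7, 23])
x2 = jnp.array([3, 4, -5])
q, r = jnp.divmod(x1, x2)
print(jnp.array_equal(q, x1 // x2), jnp.array_equal(r, x1 % x2))  # True True
print(jnp.array_equal(x2 * q + r, x1))                            # True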
+ + Returns: + An array containing the base ``x1`` exponentials of ``x2`` with same dtype + as input. + + Note: + - When ``x2`` is a concrete integer scalar, ``jnp.power`` lowers to + :func:`jax.lax.integer_pow`. + - When ``x2`` is a traced scalar or an array, ``jnp.power`` lowers to + :func:`jax.lax.pow`. + - ``jnp.power`` raises a ``TypeError`` for integer type raised to negative + integer power. + - ``jnp.power`` returns ``nan`` for negative value raised to the power of + non-integer values. + + See also: + - :func:`jax.lax.pow`: Computes element-wise power, :math:`x^y`. + - :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where + :math:`y` is a fixed integer. + - :func:`jax.numpy.float_power`: Computes the first array raised to the power + of second array, element-wise, by promoting to the inexact dtype. + - :func:`jax.numpy.pow`: Computes the first array raised to the power of second + array, element-wise. + + Examples: + Inputs with scalar integers: + + >>> jnp.power(4, 3) + Array(64, dtype=int32, weak_type=True) + + Inputs with same shape: + + >>> x1 = jnp.array([2, 4, 5]) + >>> x2 = jnp.array([3, 0.5, 2]) + >>> jnp.power(x1, x2) + Array([ 8., 2., 25.], dtype=float32) + + Inputs with broadcast compatibility: + + >>> x3 = jnp.array([-2, 3, 1]) + >>> x4 = jnp.array([[4, 1, 6], + ... [1.3, 3, 5]]) + >>> jnp.power(x3, x4) + Array([[16., 3., 1.], + [nan, 27., 1.]], dtype=float32) + """ check_arraylike("power", x1, x2) check_no_float0s("power", x1, x2) @@ -720,8 +2035,9 @@ def power(x1: ArrayLike, x2: ArrayLike, /) -> Array: # Handle cases #2 and #3 under a jit: return _power(x1, x2) -# Array API alias -pow = power +def pow(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.power`""" + return power(x1, x2) @partial(jit, inline=True) def _power(x1: ArrayLike, x2: ArrayLike) -> Array: @@ -758,21 +2074,30 @@ def _pow_int_int(x1, x2): return acc -@custom_jvp -@implements(np.logaddexp, module='numpy') @jit def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Compute ``log(exp(x1) + exp(x2))`` avoiding overflow. + + JAX implementation of :obj:`numpy.logaddexp` + + Args: + x1: input array + x2: input array + + Returns: + array containing the result. + + Examples: + + >>> x1 = jnp.array([1, 2, 3]) + >>> x2 = jnp.array([4, 5, 6]) + >>> result1 = jnp.logaddexp(x1, x2) + >>> result2 = jnp.log(jnp.exp(x1) + jnp.exp(x2)) + >>> print(jnp.allclose(result1, result2)) + True + """ x1, x2 = promote_args_inexact("logaddexp", x1, x2) - amax = lax.max(x1, x2) - if dtypes.issubdtype(x1.dtype, np.floating): - delta = lax.sub(x1, x2) - return lax.select(lax._isnan(delta), - lax.add(x1, x2), # NaNs or infinities of the same sign. 
- lax.add(amax, lax.log1p(lax.exp(lax.neg(lax.abs(delta)))))) - else: - delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2))) - out = lax.add(amax, lax.log1p(lax.exp(delta))) - return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi)) + return lax_other.logaddexp(x1, x2) def _wrap_between(x, _a): @@ -785,22 +2110,39 @@ def _wrap_between(x, _a): return lax.sub(rem, a) -@logaddexp.defjvp -def _logaddexp_jvp(primals, tangents): - x1, x2 = primals - t1, t2 = tangents - x1, x2, t1, t2 = promote_args_inexact("logaddexp_jvp", x1, x2, t1, t2) - primal_out = logaddexp(x1, x2) - tangent_out = lax.add(lax.mul(t1, exp(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))), - lax.mul(t2, exp(lax.sub(_replace_inf(x2), _replace_inf(primal_out))))) - return primal_out, tangent_out - - -@custom_jvp -@implements(np.logaddexp2, module='numpy') @jit def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow. + + JAX implementation of :obj:`numpy.logaddexp2`. + + Args: + x1: input array or scalar. + x2: input array or scalar. ``x1`` and ``x2`` should either have same shape or + be broadcast compatible. + + Returns: + An array containing the result, :math:`log_2(2^{x1}+2^{x2})`, element-wise. + + See also: + - :func:`jax.numpy.logaddexp`: Computes ``log(exp(x1) + exp(x2))``, element-wise. + - :func:`jax.numpy.log2`: Calculates the base-2 logarithm of ``x`` element-wise. + + Examples: + >>> x1 = jnp.array([[3, -1, 4], + ... [8, 5, -2]]) + >>> x2 = jnp.array([2, 3, -5]) + >>> result1 = jnp.logaddexp2(x1, x2) + >>> result2 = jnp.log2(jnp.exp2(x1) + jnp.exp2(x2)) + >>> jnp.allclose(result1, result2) + Array(True, dtype=bool) + """ x1, x2 = promote_args_inexact("logaddexp2", x1, x2) + return _logaddexp2(x1, x2) + + +@custom_jvp +def _logaddexp2(x1, x2): amax = lax.max(x1, x2) if dtypes.issubdtype(x1.dtype, np.floating): delta = lax.sub(x1, x2) @@ -814,7 +2156,7 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array: return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2))) -@logaddexp2.defjvp +@_logaddexp2.defjvp def _logaddexp2_jvp(primals, tangents): x1, x2 = primals t1, t2 = tangents @@ -827,9 +2169,9 @@ def _logaddexp2_jvp(primals, tangents): @partial(jit, inline=True) def log2(x: ArrayLike, /) -> Array: - """Calculates the base-2 logarithm of x element-wise + """Calculates the base-2 logarithm of ``x`` element-wise. - LAX-backend implementation of :func:`numpy.log2`. + JAX implementation of :obj:`numpy.log2`. Args: x: Input array @@ -851,7 +2193,7 @@ def log2(x: ArrayLike, /) -> Array: def log10(x: ArrayLike, /) -> Array: """Calculates the base-10 logarithm of x element-wise - LAX-backend implementation of :func:`numpy.log10`. + JAX implementation of :obj:`numpy.log10`. Args: x: Input array @@ -870,9 +2212,36 @@ def log10(x: ArrayLike, /) -> Array: return lax.div(lax.log(x), lax.log(_constant_like(x, 10))) -@implements(np.exp2, module='numpy') @partial(jit, inline=True) def exp2(x: ArrayLike, /) -> Array: + """Calculate element-wise base-2 exponential of input. + + JAX implementation of :obj:`numpy.exp2`. + + Args: + x: input array or scalar + + Returns: + An array containing the base-2 exponential of each element in ``x``, promotes + to inexact dtype. + + See also: + - :func:`jax.numpy.log2`: Calculates base-2 logarithm of each element of input. + - :func:`jax.numpy.exp`: Calculates exponential of each element of the input. 
+ - :func:`jax.numpy.expm1`: Calculates :math:`e^x-1` of each element of the + input. + + Examples: + ``jnp.exp2`` follows the properties of the exponential such as :math:`2^{a+b} + = 2^a * 2^b`. + + >>> x1 = jnp.array([2, -4, 3, -1]) + >>> x2 = jnp.array([-1, 3, -2, 3]) + >>> jnp.exp2(x1+x2) + Array([2. , 0.5, 2. , 4. ], dtype=float32) + >>> jnp.exp2(x1)*jnp.exp2(x2) + Array([2. , 0.5, 2. , 4. ], dtype=float32) + """ x, = promote_args_inexact("exp2", x) return lax.exp2(x) @@ -981,9 +2350,42 @@ def frexp(x: ArrayLike, /) -> tuple[Array, Array]: return _where(cond, x, x1), lax.convert_element_type(x2, np.int32) -@implements(np.remainder, module='numpy') @jit def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Returns element-wise remainder of the division. + + JAX implementation of :obj:`numpy.remainder`. + + Args: + x1: scalar or array. Specifies the dividend. + x2: scalar or array. Specifies the divisor. ``x1`` and ``x2`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the remainder of element-wise division of ``x1`` by + ``x2`` with same sign as the elements of ``x2``. + + Note: + The result of ``jnp.remainder`` is equivalent to ``x1 - x2 * jnp.floor(x1 / x2)``. + + See also: + - :func:`jax.numpy.mod`: Returns the element-wise remainder of the division. + - :func:`jax.numpy.fmod`: Calculates the element-wise floating-point modulo + operation. + - :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of + ``x1`` by ``x2``, element-wise. + + Examples: + >>> x1 = jnp.array([[3, -1, 4], + ... [8, 5, -2]]) + >>> x2 = jnp.array([2, 3, -5]) + >>> jnp.remainder(x1, x2) + Array([[ 1, 2, -1], + [ 0, 2, -2]], dtype=int32) + >>> x1 - x2 * jnp.floor(x1 / x2) + Array([[ 1., 2., -1.], + [ 0., 2., -2.]], dtype=float32) + """ x1, x2 = promote_args_numeric("remainder", x1, x2) zero = _constant_like(x1, 0) if dtypes.issubdtype(x2.dtype, np.integer): @@ -993,68 +2395,302 @@ def remainder(x1: ArrayLike, x2: ArrayLike, /) -> Array: do_plus = lax.bitwise_and( lax.ne(lax.lt(trunc_mod, zero), lax.lt(x2, zero)), trunc_mod_not_zero) return lax.select(do_plus, lax.add(trunc_mod, x2), trunc_mod) -mod = implements(np.mod, module='numpy')(remainder) -@implements(np.fmod, module='numpy') +def mod(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.remainder`""" + return remainder(x1, x2) + + @jit def fmod(x1: ArrayLike, x2: ArrayLike, /) -> Array: + """Calculate element-wise floating-point modulo operation. + + JAX implementation of :obj:`numpy.fmod`. + + Args: + x1: scalar or array. Specifies the dividend. + x2: scalar or array. Specifies the divisor. ``x1`` and ``x2`` should either + have same shape or be broadcast compatible. + + Returns: + An array containing the result of the element-wise floating-point modulo + operation of ``x1`` and ``x2`` with same sign as the elements of ``x1``. + + Note: + The result of ``jnp.fmod`` is equivalent to ``x1 - x2 * jnp.fix(x1 / x2)``. + + See also: + - :func:`jax.numpy.mod` and :func:`jax.numpy.remainder`: Returns the element-wise + remainder of the division. + - :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of + ``x1`` by ``x2``, element-wise. + + Examples: + >>> x1 = jnp.array([[3, -1, 4], + ... 
[8, 5, -2]]) + >>> x2 = jnp.array([2, 3, -5]) + >>> jnp.fmod(x1, x2) + Array([[ 1, -1, 4], + [ 0, 2, -2]], dtype=int32) + >>> x1 - x2 * jnp.fix(x1 / x2) + Array([[ 1., -1., 4.], + [ 0., 2., -2.]], dtype=float32) + """ check_arraylike("fmod", x1, x2) if dtypes.issubdtype(dtypes.result_type(x1, x2), np.integer): x2 = _where(x2 == 0, lax._ones(x2), x2) return lax.rem(*promote_args_numeric("fmod", x1, x2)) -@implements(np.square, module='numpy') @partial(jit, inline=True) def square(x: ArrayLike, /) -> Array: + """Calculate element-wise square of the input array. + + JAX implementation of :obj:`numpy.square`. + + Args: + x: input array or scalar. + + Returns: + An array containing the square of the elements of ``x``. + + Note: + ``jnp.square`` is equivalent to computing ``jnp.power(x, 2)``. + + See also: + - :func:`jax.numpy.sqrt`: Calculates the element-wise non-negative square root + of the input array. + - :func:`jax.numpy.power`: Calculates the element-wise base ``x1`` exponential + of ``x2``. + - :func:`jax.lax.integer_pow`: Computes element-wise power :math:`x^y`, where + :math:`y` is a fixed integer. + - :func:`jax.numpy.float_power`: Computes the first array raised to the power + of second array, element-wise, by promoting to the inexact dtype. + + Examples: + >>> x = jnp.array([3, -2, 5.3, 1]) + >>> jnp.square(x) + Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) + >>> jnp.power(x, 2) + Array([ 9. , 4. , 28.090002, 1. ], dtype=float32) + + For integer inputs: + + >>> x1 = jnp.array([2, 4, 5, 6]) + >>> jnp.square(x1) + Array([ 4, 16, 25, 36], dtype=int32) + + For complex-valued inputs: + + >>> x2 = jnp.array([1-3j, -1j, 2]) + >>> jnp.square(x2) + Array([-8.-6.j, -1.+0.j, 4.+0.j], dtype=complex64) + """ check_arraylike("square", x) x, = promote_dtypes_numeric(x) return lax.integer_pow(x, 2) -@implements(np.deg2rad, module='numpy') @partial(jit, inline=True) def deg2rad(x: ArrayLike, /) -> Array: + r"""Convert angles from degrees to radians. + + JAX implementation of :obj:`numpy.deg2rad`. + + The angle in degrees is converted to radians by: + + .. math:: + + deg2rad(x) = x * \frac{pi}{180} + + Args: + x: scalar or array. Specifies the angle in degrees. + + Returns: + An array containing the angles in radians. + + See also: + - :func:`jax.numpy.rad2deg` and :func:`jax.numpy.degrees`: Converts the angles + from radians to degrees. + - :func:`jax.numpy.radians`: Alias of ``deg2rad``. + + Examples: + >>> x = jnp.array([60, 90, 120, 180]) + >>> jnp.deg2rad(x) + Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32) + >>> x * jnp.pi / 180 + Array([1.0471976, 1.5707964, 2.0943952, 3.1415927], dtype=float32, weak_type=True) + """ x, = promote_args_inexact("deg2rad", x) return lax.mul(x, _lax_const(x, np.pi / 180)) -@implements(np.rad2deg, module='numpy') @partial(jit, inline=True) def rad2deg(x: ArrayLike, /) -> Array: + r"""Convert angles from radians to degrees. + + JAX implementation of :obj:`numpy.rad2deg`. + + The angle in radians is converted to degrees by: + + .. math:: + + rad2deg(x) = x * \frac{180}{pi} + + Args: + x: scalar or array. Specifies the angle in radians. + + Returns: + An array containing the angles in degrees. + + See also: + - :func:`jax.numpy.deg2rad` and :func:`jax.numpy.radians`: Converts the angles + from degrees to radians. + - :func:`jax.numpy.degrees`: Alias of ``rad2deg``. + + Examples: + >>> pi = jnp.pi + >>> x = jnp.array([pi/4, pi/2, 2*pi/3]) + >>> jnp.rad2deg(x) + Array([ 45. , 90. 
, 120.00001], dtype=float32) + >>> x * 180 / pi + Array([ 45., 90., 120.], dtype=float32) + """ x, = promote_args_inexact("rad2deg", x) return lax.mul(x, _lax_const(x, 180 / np.pi)) -degrees = rad2deg -radians = deg2rad +def degrees(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.rad2deg`""" + return rad2deg(x) + +def radians(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.deg2rad`""" + return deg2rad(x) -@implements(np.conjugate, module='numpy') @partial(jit, inline=True) def conjugate(x: ArrayLike, /) -> Array: + """Return element-wise complex-conjugate of the input. + + JAX implementation of :obj:`numpy.conjugate`. + + Args: + x: inpuat array or scalar. + + Returns: + An array containing the complex-conjugate of ``x``. + + See also: + - :func:`jax.numpy.real`: Returns the element-wise real part of the complex + argument. + - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the + complex argument. + + Examples: + >>> jnp.conjugate(3) + Array(3, dtype=int32, weak_type=True) + >>> x = jnp.array([2-1j, 3+5j, 7]) + >>> jnp.conjugate(x) + Array([2.+1.j, 3.-5.j, 7.-0.j], dtype=complex64) + """ check_arraylike("conjugate", x) return lax.conj(x) if np.iscomplexobj(x) else lax.asarray(x) -conj = conjugate -@implements(np.imag) +def conj(x: ArrayLike, /) -> Array: + """Alias of :func:`jax.numpy.conjugate`""" + return conjugate(x) + + @partial(jit, inline=True) def imag(val: ArrayLike, /) -> Array: + """Return element-wise imaginary of part of the complex argument. + + JAX implementation of :obj:`numpy.imag`. + + Args: + val: input array or scalar. + + Returns: + An array containing the imaginary part of the elements of ``val``. + + See also: + - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise + complex-conjugate of the input. + - :func:`jax.numpy.real`: Returns the element-wise real part of the complex + argument. + + Examples: + >>> jnp.imag(4) + Array(0, dtype=int32, weak_type=True) + >>> jnp.imag(5j) + Array(5., dtype=float32, weak_type=True) + >>> x = jnp.array([2+3j, 5-1j, -3]) + >>> jnp.imag(x) + Array([ 3., -1., 0.], dtype=float32) + """ check_arraylike("imag", val) return lax.imag(val) if np.iscomplexobj(val) else lax.full_like(val, 0) -@implements(np.real) @partial(jit, inline=True) def real(val: ArrayLike, /) -> Array: + """Return element-wise real part of the complex argument. + + JAX implementation of :obj:`numpy.real`. + + Args: + val: input array or scalar. + + Returns: + An array containing the real part of the elements of ``val``. + + See also: + - :func:`jax.numpy.conjugate` and :func:`jax.numpy.conj`: Returns the element-wise + complex-conjugate of the input. + - :func:`jax.numpy.imag`: Returns the element-wise imaginary part of the + complex argument. + + Examples: + >>> jnp.real(5) + Array(5, dtype=int32, weak_type=True) + >>> jnp.real(2j) + Array(0., dtype=float32, weak_type=True) + >>> x = jnp.array([3-2j, 4+7j, -2j]) + >>> jnp.real(x) + Array([ 3., 4., -0.], dtype=float32) + """ check_arraylike("real", val) return lax.real(val) if np.iscomplexobj(val) else lax.asarray(val) -@implements(np.modf, module='numpy', skip_params=['out']) + @jit def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: + """Return element-wise fractional and integral parts of the input array. + + JAX implementation of :obj:`numpy.modf`. + + Args: + x: input array or scalar. + out: Not used by JAX. + + Returns: + An array containing the fractional and integral parts of the elements of ``x``, + promoting dtypes inexact. 
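# Illustrative sketch (not part of the patch): real, imag and conjugate, documented
# above, recompose the original complex values.
import jax.numpy as jnp

z = jnp.array([3 - 2j, -1 + 4j, 0.5j])
print(jnp.allclose(jnp.real(z) + 1j * jnp.imag(z), z))  # True
print(jnp.allclose(z + jnp.conj(z), 2 * jnp.real(z)))   # True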
+ + See also: + - :func:`jax.numpy.divmod`: Calculates the integer quotient and remainder of + ``x1`` by ``x2`` element-wise. + + Examples: + >>> jnp.modf(4.8) + (Array(0.8000002, dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True)) + >>> x = jnp.array([-3.4, -5.7, 0.6, 1.5, 2.3]) + >>> jnp.modf(x) + (Array([-0.4000001 , -0.6999998 , 0.6 , 0.5 , 0.29999995], dtype=float32), Array([-3., -5., 0., 1., 2.], dtype=float32)) + """ check_arraylike("modf", x) x, = promote_dtypes_inexact(x) if out is not None: @@ -1173,7 +2809,7 @@ def sinc(x: ArrayLike, /) -> Array: def _sinc_maclaurin(k, x): # compute the kth derivative of x -> sin(x)/x evaluated at zero (since we # compute the monomial term in the jvp rule) - # TODO(mattjj): see https://github.com/google/jax/issues/10750 + # TODO(mattjj): see https://github.com/jax-ml/jax/issues/10750 if k % 2: return x * 0 else: @@ -1183,3 +2819,53 @@ def _sinc_maclaurin(k, x): def _sinc_maclaurin_jvp(k, primals, tangents): (x,), (t,) = primals, tangents return _sinc_maclaurin(k, x), _sinc_maclaurin(k + 1, x) * t + + +def _logical_and_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_and.reduce()") + result = reductions.all(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + + +def _logical_or_reduce(a: ArrayLike, axis: int = 0, dtype: DTypeLike | None = None, + out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, + where: ArrayLike | None = None): + if initial is not None: + raise ValueError("initial argument not supported in jnp.logical_or.reduce()") + result = reductions.any(a, axis=axis, out=out, keepdims=keepdims, where=where) + return result if dtype is None else result.astype(dtype) + +def _add_at(a: Array, indices: Any, b: ArrayLike): + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].add(b).astype(bool) + return a.at[indices].add(b) + +def _multiply_at(a: Array, indices: Any, b: ArrayLike): + if a.dtype == bool: + a = a.astype('int32') + b = lax.convert_element_type(b, bool).astype('int32') + return a.at[indices].mul(b).astype(bool) + else: + return a.at[indices].mul(b) + +# Generate ufunc interfaces for several common binary functions. +# We start with binary ufuncs that have well-defined identities.' +# TODO(jakevdp): wrap more ufuncs. Possibly define a decorator for convenience? +# TODO(jakevdp): optimize some implementations. 
+# - define add.at/multiply.at in terms of scatter_add/scatter_mul +# - define add.reduceat/multiply.reduceat in terms of segment_sum/segment_prod +# - define all monoidal reductions in terms of lax.reduce +add = ufunc(_add, name="add", nin=2, nout=1, identity=0, call=_add, reduce=reductions.sum, accumulate=reductions.cumsum, at=_add_at) +multiply = ufunc(_multiply, name="multiply", nin=2, nout=1, identity=1, call=_multiply, reduce=reductions.prod, accumulate=reductions.cumprod, at=_multiply_at) +bitwise_and = ufunc(_bitwise_and, name="bitwise_and", nin=2, nout=1, identity=-1, call=_bitwise_and) +bitwise_or = ufunc(_bitwise_or, name="bitwise_or", nin=2, nout=1, identity=0, call=_bitwise_or) +bitwise_xor = ufunc(_bitwise_xor, name="bitwise_xor", nin=2, nout=1, identity=0, call=_bitwise_xor) +logical_and = ufunc(_logical_and, name="logical_and", nin=2, nout=1, identity=True, call=_logical_and, reduce=_logical_and_reduce) +logical_or = ufunc(_logical_or, name="logical_or", nin=2, nout=1, identity=False, call=_logical_or, reduce=_logical_or_reduce) +logical_xor = ufunc(_logical_xor, name="logical_xor", nin=2, nout=1, identity=False, call=_logical_xor) +negative = ufunc(_negative, name="negative", nin=1, nout=1, call=_negative) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index 09ff99cb40a1..9c9bc5d389e1 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -111,19 +111,11 @@ def _parse_parameters(body: str) -> dict[str, str]: return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} -def _parse_extra_params(extra_params: str) -> dict[str, str]: - """Parse the extra parameters passed to implements()""" - parameters = _parameter_break.split(extra_params.strip('\n')) - return {p.partition(' : ')[0].partition(', ')[0]: p for p in parameters} - - def implements( original_fun: Callable[..., Any] | None, update_doc: bool = True, - lax_description: str = "", sections: Sequence[str] = ('Parameters', 'Returns', 'References'), skip_params: Sequence[str] = (), - extra_params: str | None = None, module: str | None = None, ) -> Callable[[_T], _T]: """Decorator for JAX functions which implement a specified NumPy function. @@ -139,15 +131,10 @@ def implements( update_doc: whether to transform the numpy docstring to remove references of parameters that are supported by the numpy version but not the JAX version. If False, include the numpy docstring verbatim. - lax_description: a string description that will be added to the beginning of - the docstring. sections: a list of sections to include in the docstring. The default is ["Parameters", "Returns", "References"] skip_params: a list of strings containing names of parameters accepted by the function that should be skipped in the parameter list. - extra_params: an optional string containing additional parameter descriptions. - When ``update_doc=True``, these will be added to the list of parameter - descriptions in the updated doc. module: an optional string specifying the module from which the original function is imported. This is useful for objects such as ufuncs, where the module cannot be determined from the original function itself. 
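To make the effect of the new ufunc wrappers in ufuncs.py above concrete, here is a small usage sketch (not part of the diff). It assumes the wrapped objects are exposed as jnp.add, jnp.multiply, and jnp.logical_and, and that jax.numpy's functional ufunc.at requires inplace=False; exact dtypes depend on the default 32-bit configuration.

    import jax.numpy as jnp

    x = jnp.array([1, 2, 3, 4])

    # reduce/accumulate dispatch to the monoidal reductions wired up above
    # (reductions.sum/cumsum for add, reductions.prod/cumprod for multiply).
    jnp.add.reduce(x)          # equivalent to jnp.sum(x)    -> 10
    jnp.add.accumulate(x)      # equivalent to jnp.cumsum(x) -> [1, 3, 6, 10]
    jnp.multiply.reduce(x)     # equivalent to jnp.prod(x)   -> 24

    # .at dispatches to _add_at, i.e. a functional x.at[indices].add(b) update;
    # duplicate indices accumulate and a new array is returned.
    jnp.add.at(x, jnp.array([0, 0, 1]), 1, inplace=False)    # -> [3, 3, 3, 4]

    # logical_and/logical_or gain .reduce via the _logical_*_reduce helpers.
    jnp.logical_and.reduce(jnp.array([True, True, False]))   # -> False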
@@ -156,8 +143,6 @@ def decorator(wrapped_fun): wrapped_fun.__np_wrapped__ = original_fun # Allows this pattern: @implements(getattr(np, 'new_function', None)) if original_fun is None: - if lax_description: - wrapped_fun.__doc__ = lax_description return wrapped_fun docstr = getattr(original_fun, "__doc__", None) name = getattr(original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))) @@ -176,8 +161,6 @@ def decorator(wrapped_fun): code = getattr(getattr(wrapped_fun, "__wrapped__", wrapped_fun), "__code__", None) # Remove unrecognized parameter descriptions. parameters = _parse_parameters(parsed.sections['Parameters']) - if extra_params: - parameters.update(_parse_extra_params(extra_params)) parameters = {p: desc for p, desc in parameters.items() if (code is None or p in code.co_varnames) and p not in skip_params} @@ -193,8 +176,6 @@ def decorator(wrapped_fun): docstr = parsed.summary.strip() + "\n" if parsed.summary else "" docstr += f"\nLAX-backend implementation of :func:`{name}`.\n" - if lax_description: - docstr += "\n" + lax_description.strip() + "\n" docstr += "\n*Original docstring below.*\n" # We remove signatures from the docstrings, because they redundant at best and diff --git a/jax/_src/numpy/vectorize.py b/jax/_src/numpy/vectorize.py index dc368367e14e..e7a0e2142327 100644 --- a/jax/_src/numpy/vectorize.py +++ b/jax/_src/numpy/vectorize.py @@ -215,48 +215,49 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None): Returns: Vectorized version of the given function. - Here are a few examples of how one could write vectorized linear algebra - routines using :func:`vectorize`: - - >>> from functools import partial - - >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') - ... def cross_product(a, b): - ... assert a.shape == b.shape and a.ndim == b.ndim == 1 - ... return jnp.array([a[1] * b[2] - a[2] * b[1], - ... a[2] * b[0] - a[0] * b[2], - ... a[0] * b[1] - a[1] * b[0]]) - - >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') - ... def matrix_vector_product(matrix, vector): - ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape - ... return matrix @ vector - - These functions are only written to handle 1D or 2D arrays (the ``assert`` - statements will never be violated), but with vectorize they support - arbitrary dimensional inputs with NumPy style broadcasting, e.g., - - >>> cross_product(jnp.ones(3), jnp.ones(3)).shape - (3,) - >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape - (2, 3) - >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape - (2, 2, 3) - >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - ValueError: input with shape (3,) does not have enough dimensions for all - core dimensions ('n', 'k') on vectorized function with excluded=frozenset() - and signature='(n,k),(k)->(k)' - >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape - (2,) - >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape - (4, 2) - - Note that this has different semantics than `jnp.matmul`: - - >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL - Traceback (most recent call last): - TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4]. + Examples: + Here are a few examples of how one could write vectorized linear algebra + routines using :func:`vectorize`: + + >>> from functools import partial + + >>> @partial(jnp.vectorize, signature='(k),(k)->(k)') + ... 
def cross_product(a, b): + ... assert a.shape == b.shape and a.ndim == b.ndim == 1 + ... return jnp.array([a[1] * b[2] - a[2] * b[1], + ... a[2] * b[0] - a[0] * b[2], + ... a[0] * b[1] - a[1] * b[0]]) + + >>> @partial(jnp.vectorize, signature='(n,m),(m)->(n)') + ... def matrix_vector_product(matrix, vector): + ... assert matrix.ndim == 2 and matrix.shape[1:] == vector.shape + ... return matrix @ vector + + These functions are only written to handle 1D or 2D arrays (the ``assert`` + statements will never be violated), but with vectorize they support + arbitrary dimensional inputs with NumPy style broadcasting, e.g., + + >>> cross_product(jnp.ones(3), jnp.ones(3)).shape + (3,) + >>> cross_product(jnp.ones((2, 3)), jnp.ones(3)).shape + (2, 3) + >>> cross_product(jnp.ones((1, 2, 3)), jnp.ones((2, 1, 3))).shape + (2, 2, 3) + >>> matrix_vector_product(jnp.ones(3), jnp.ones(3)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ValueError: input with shape (3,) does not have enough dimensions for all + core dimensions ('n', 'k') on vectorized function with excluded=frozenset() + and signature='(n,k),(k)->(k)' + >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones(3)).shape + (2,) + >>> matrix_vector_product(jnp.ones((2, 3)), jnp.ones((4, 3))).shape + (4, 2) + + Note that this has different semantics than `jnp.matmul`: + + >>> jnp.matmul(jnp.ones((2, 3)), jnp.ones((4, 3))) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + TypeError: dot_general requires contracting dimensions to have the same shape, got [3] and [4]. """ if any(not isinstance(exclude, (str, int)) for exclude in excluded): raise TypeError("jax.numpy.vectorize can only exclude integer or string arguments, " diff --git a/jax/_src/pallas/BUILD b/jax/_src/pallas/BUILD index c0fa02131bc8..4ff7062ac1e8 100644 --- a/jax/_src/pallas/BUILD +++ b/jax/_src/pallas/BUILD @@ -21,7 +21,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/jax/_src/pallas/core.py b/jax/_src/pallas/core.py index e99510f8d499..e817369a50c5 100644 --- a/jax/_src/pallas/core.py +++ b/jax/_src/pallas/core.py @@ -23,7 +23,7 @@ import functools import itertools import threading -from typing import Any, Hashable, Union +from typing import Any, ClassVar, Hashable, Protocol, Union, runtime_checkable import warnings import jax @@ -31,6 +31,7 @@ from jax._src import config from jax._src import core as jax_core from jax._src import deprecations +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import mesh as mesh_lib from jax._src import state @@ -59,6 +60,22 @@ def __repr__(self): GridMappingGrid = tuple[int | DynamicGridDim, ...] OriginStr = str # The origin of a block spec, e.g. input[2]["field"] +# Datatype for semaphore values in interpret mode. +# For now, we choose a relatively uncommon datatype (i16) so it is more easily +# identifiable in kernels. +# TODO(justinfu): Handle semaphores with a custom extended dtype. +SEMAPHORE_INTERPRET_DTYPE = jnp.int16 +SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max + + +@runtime_checkable +class CompilerParams(Protocol): + """Base class for compiler parameters.""" + PLATFORM: ClassVar[str] + + # Subclasses must be dataclasses. 
+ __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + @dataclasses.dataclass(frozen=True) class NameAndSrcInfo: @@ -98,21 +115,112 @@ def from_pallas_call(pallas_call_name: str | None, " ".join(src_info_parts[1:])) -# Pytrees of jax.ShapeDtypeStruct -ShapeDtypeStructTree = tuple[jax.ShapeDtypeStruct, ...] - split_list = util.split_list map, unsafe_map = util.safe_map, map zip, unsafe_zip = util.safe_zip, zip +class ShapedArrayWithMemorySpace(jax_core.ShapedArray): + __slots__ = ["memory_space"] + + def __init__(self, shape, dtype, weak_type=False, sharding=None, + memory_space=None): + super().__init__(shape, dtype, weak_type=weak_type, sharding=sharding) + self.memory_space = memory_space + + def __eq__(self, other): + return super().__eq__(other) and self.memory_space == other.memory_space + + def __hash__(self): + return hash(( + self.shape, + self.dtype, + self.weak_type, + getattr(self, "sharding", None), + self.memory_space, + )) + + def at_least_vspace(self): + """Vector space method needed for AD.""" + raise NotImplementedError + + def join(self, other): + raise NotImplementedError + + def str_short(self, short_dtypes=False): + dt_str = \ + dtypes.short_dtype_name(self.dtype) if short_dtypes else self.dtype.name + dt_str = dt_str.replace("void", "float0") + shapestr = ",".join(map(str, self.shape)) + if hasattr(self, "sharding"): + sharding_str = f"{dt_str}[{shapestr}]({self.sharding})" + else: + sharding_str = "" + memoryspace_str = ( + "" if self.memory_space is None else f"<{self.memory_space}>" + ) + return f"{dt_str}{memoryspace_str}[{shapestr}]{sharding_str}" + + def update( + self, + shape=None, + dtype=None, + weak_type=None, + sharding=None, + memory_space=None, + ): + if shape is None: + shape = self.shape + if dtype is None: + dtype = self.dtype + if weak_type is None: + weak_type = self.weak_type + if sharding is None: + sharding = getattr(self, "sharding", None) + if memory_space is None: + memory_space = self.memory_space + return ShapedArrayWithMemorySpace( + shape, dtype, weak_type, sharding=sharding, memory_space=memory_space + ) +mlir.ir_type_handlers[ShapedArrayWithMemorySpace] = mlir._array_ir_types + + +@dataclasses.dataclass(frozen=True) +class MemoryRef: + """Like jax.ShapeDtypeStruct but with memory spaces.""" + shape: tuple[int, ...] + dtype: jnp.dtype + # TODO(b/368122763): Unify memory space types across backends + memory_space: Any + + def get_array_aval(self) -> jax_core.ShapedArray: + dtype = self.dtype + if not isinstance(dtype, (jnp.dtype, dtypes.ExtendedDType)): + dtype = jnp.dtype(dtype) + return ShapedArrayWithMemorySpace( + self.shape, dtype, memory_space=self.memory_space + ) + + def get_ref_aval(self) -> AbstractMemoryRef: + # TODO(sharadmv): Clean this up. ShapedArrayWithMemorySpace fails when we + # try to apply JAX ops to it. 
+ return AbstractMemoryRef( + jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) + + class AbstractMemoryRef(state.AbstractRef): __slots__ = ["inner_aval", "memory_space"] - def __init__(self, inner_aval: jax_core.AbstractValue, - memory_space: Any): - assert isinstance(inner_aval, jax_core.ShapedArray) + inner_aval: jax_core.ShapedArray + + def __init__(self, inner_aval: jax_core.ShapedArray, memory_space: Any): + if isinstance(inner_aval, ShapedArrayWithMemorySpace): + if inner_aval.memory_space is not None: + assert inner_aval.memory_space == memory_space, ( + f"Mismatched memory spaces: {inner_aval.memory_space=}," + f" {memory_space=}" + ) self.inner_aval = inner_aval self.memory_space = memory_space @@ -129,9 +237,9 @@ def update(self, inner_aval=None, memory_space=None): memory_space = self.memory_space if memory_space is None else memory_space return AbstractMemoryRef(inner_aval, memory_space) - def at_least_vspace(self): + def to_tangent_aval(self): return AbstractMemoryRef( - self.inner_aval.at_least_vspace(), self.memory_space) + self.inner_aval.to_tangent_aval(), self.memory_space) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval @@ -142,11 +250,12 @@ def __hash__(self): class MemorySpace(enum.Enum): - """ Logical, device-agnostic memory spaces. + """Logical, device-agnostic memory spaces. Each memory space will be translated to a device-specific memory type during lowering. """ + ANY = "any" # Unrestricted memory space (usually HBM) ERROR = "error" # Memory space for checkify errors. INDEX = "index" # Memory space for scalar prefetch arguments. @@ -167,9 +276,7 @@ class PallasGridContext: mapped_dims: tuple[int, ...] def size(self, axis: int) -> int | DynamicGridDim: - valid_grid = tuple( - s for i, s in enumerate(self.grid) if i not in self.mapped_dims - ) + valid_grid = tuple(self.grid) try: size = valid_grid[axis] except IndexError as e: @@ -182,6 +289,8 @@ def size(self, axis: int) -> int | DynamicGridDim: @dataclasses.dataclass class PallasTracingEnv(threading.local): grid_context: PallasGridContext | None = None + grid_env_stack: list[GridEnv] = dataclasses.field(default_factory=list) + is_interpret_mode: bool = False _pallas_tracing_env = PallasTracingEnv() @@ -202,22 +311,35 @@ class GridAxis: # Stores the kernel execution position and the size along grid axes. 
GridEnv = Sequence[GridAxis] -_grid_env_stack: list[GridEnv] = [] - - @contextlib.contextmanager def grid_env(env: GridEnv) -> Iterator[None]: - _grid_env_stack.append(env) + _pallas_tracing_env.grid_env_stack.append(env) try: yield finally: - _grid_env_stack.pop() + _pallas_tracing_env.grid_env_stack.pop() def current_grid_env() -> GridEnv | None: - if not _grid_env_stack: + if not _pallas_tracing_env.grid_env_stack: return None - return _grid_env_stack[-1] + return _pallas_tracing_env.grid_env_stack[-1] + + +@contextlib.contextmanager +def interpret_mode_env(interpret_mode: bool) -> Iterator[None]: + prev_interpret = _pallas_tracing_env.is_interpret_mode + if interpret_mode: + _pallas_tracing_env.is_interpret_mode = True + try: + yield + finally: + if interpret_mode: + _pallas_tracing_env.is_interpret_mode = prev_interpret + +def is_interpret_mode() -> bool: + """Returns whether the kernel is executing in interpret mode.""" + return _pallas_tracing_env.is_interpret_mode class Mapped: @@ -287,6 +409,105 @@ def __init__( self.memory_space = memory_space self.indexing_mode = indexing_mode + def to_block_mapping( + self, + origin: OriginStr, + array_aval: jax_core.ShapedArray, + *, + # Inputs for the index_map + index_map_avals: Sequence[jax_core.AbstractValue], + index_map_tree: tree_util.PyTreeDef, + grid: GridMappingGrid, + mapped_dims: tuple[int, ...], + ) -> BlockMapping: + if self.index_map is None: + index_map_func = lambda *args: (0,) * len(array_aval.shape) + else: + index_map_func = self.index_map + if self.block_shape is None: + block_shape = array_aval.shape + else: + block_shape = self.block_shape + if len(array_aval.shape) != len(block_shape): + raise ValueError( + f"Block shape for {origin} (= {block_shape}) " + "must have the same number of dimensions as the " + f"array shape {array_aval.shape}." + ) + + unmapped_block_shape = tuple(s for s in block_shape if s is not None) + block_array_aval = array_aval.update(shape=unmapped_block_shape) + if isinstance(array_aval, jax_core.DShapedArray): + # Get the "max" shape for the ragged array. + block_array_aval = jax_core.ShapedArray( + block_array_aval.shape, + block_array_aval.dtype, + block_array_aval.weak_type, + ) + block_aval = AbstractMemoryRef(block_array_aval, self.memory_space) + + if not jax_core.is_constant_shape(block_aval.shape): + raise ValueError( + "shape polymorphism for Pallas does not support " + "dynamically-shaped blocks. " + f"Block spec for {origin} has block_shape: {block_aval.shape}" + ) + + flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( + lu.wrap_init(index_map_func), index_map_tree + ) + debug = pe.debug_info( + index_map_func, + index_map_tree, + index_map_out_tree_thunk, + False, + "pallas_call index_map", + ) + index_map_src_info = NameAndSrcInfo.from_pallas_call( + None, debug.func_src_info + ) + with tracing_grid_env(grid, mapped_dims): + jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic( + flat_index_map_fun, index_map_avals, debug_info=debug + ) + mapped_block_shape = tuple(mapped if s is None else s for s in block_shape) + if len(out_avals) != len(block_shape): + raise ValueError( + f"Index map function {index_map_src_info} for " + f"{origin} must return " + f"{len(block_shape)} values to match {block_shape=}. " + f"Currently returning {len(out_avals)} values." + ) + for i, ov in enumerate(out_avals): + if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: + raise ValueError( + f"Index map function {index_map_src_info} for " + f"{origin} must return integer scalars. 
Output[{i}] has type " + f"{ov}." + ) + + if consts: + raise ValueError( + f"Index map function {index_map_src_info} for " + f"{origin} must not capture constants: {consts}" + ) + + array_aval_shape = _max_shape_from_aval(array_aval) + + mapping = BlockMapping( + block_shape=mapped_block_shape, + block_aval=block_aval, + index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), + index_map_src_info=index_map_src_info, + indexing_mode=self.indexing_mode, + array_shape_dtype=jax.ShapeDtypeStruct( + array_aval_shape, array_aval.dtype + ), + origin=origin, + ) + mapping.check_invariants() + return mapping + class NoBlockSpec: def __repr__(self): @@ -298,6 +519,14 @@ def __repr__(self): # BlockSpecTree = Sequence[BlockSpec | NoBlockSpec, ...] | NoBlockSpec BlockSpecTree = Any + +class MemoryRefTransform(Protocol): + """Transforms a memory reference on load or store.""" + + def __call__(self, block_aval: AbstractMemoryRef) -> AbstractMemoryRef: + raise NotImplementedError("Abstract evaluation not implemented.") + + @dataclasses.dataclass(frozen=True) class BlockMapping: """An internal canonicalized version of BlockSpec. @@ -311,6 +540,7 @@ class BlockMapping: indexing_mode: IndexingMode array_shape_dtype: jax.ShapeDtypeStruct # The whole array origin: OriginStr + transforms: Sequence[MemoryRefTransform] = () def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -323,7 +553,10 @@ def check_invariants(self) -> None: ) assert not self.index_map_jaxpr.consts - assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals) + assert len(self.block_shape) == len(self.index_map_jaxpr.out_avals), ( + self.block_shape, + self.index_map_jaxpr.out_avals, + ) assert all(ov.shape == () and (ov.dtype == jnp.int32 or ov.dtype == jnp.int64) for ov in self.index_map_jaxpr.out_avals), ( @@ -334,6 +567,14 @@ def replace(self, **kwargs): new_self.check_invariants() return new_self + @property + def ref_aval(self) -> AbstractMemoryRef: + """Returns the abstract value of the Ref after transformations.""" + block_aval = self.block_aval + for transform in self.transforms: + block_aval = transform(block_aval) + return block_aval + def compute_start_indices_interpret(self, loop_idx, *args): discharged_jaxpr, discharged_consts = state_discharge.discharge_state( self.index_map_jaxpr.jaxpr, self.index_map_jaxpr.consts @@ -407,6 +648,8 @@ class GridMapping: num_inputs: int num_outputs: int num_scratch_operands: int + get_grid_indices: Callable | None = None + local_grid_env: Callable | None = None def check_invariants(self) -> None: if not config.enable_checks.value: return @@ -427,8 +670,8 @@ def check_invariants(self) -> None: assert len(index_map_args) >= len(self.grid) for i in range(len(self.grid)): index_map_arg = index_map_args[i] - assert index_map_arg.shape == () - assert index_map_arg.dtype == jnp.int32 + assert index_map_arg.shape == (), f"index_map_arg: {index_map_arg}" + assert index_map_arg.dtype == jnp.int32, f"index_map_arg: {index_map_arg}" assert len(self.vmapped_dims) <= len(self.grid) for i in self.vmapped_dims: @@ -439,8 +682,11 @@ def check_invariants(self) -> None: for bm in self.block_mappings: bm.check_invariants() - assert tuple(self.index_map_avals) == tuple(bm.index_map_jaxpr.in_avals), ( + assert tuple(self.index_map_avals) == tuple( + bm.index_map_jaxpr.in_avals + ), ( self.index_map_avals, + "|", bm.index_map_jaxpr.in_avals, ) @@ -506,9 +752,10 @@ def slice_scratch_ops(self): @property def in_shapes(self) -> Iterable[jax.ShapeDtypeStruct]: """The shapes of *index, 
*inputs.""" - index_shapes = (jax.ShapeDtypeStruct(ia.inner_aval.shape, - ia.inner_aval.dtype) - for ia in self.index_map_avals[len(self.grid):]) + index_shapes = ( + jax.ShapeDtypeStruct(ia.shape, ia.dtype) + for ia in self.index_map_avals[len(self.grid) :] + ) inputs_shapes = ( bm.array_shape_dtype for bm in self.block_mappings[:self.num_inputs]) @@ -532,6 +779,25 @@ def _is_valid_grid_dim(dim: int | jax.Array) -> bool: return True return jax_core.is_dim(dim) + +def _max_shape_from_aval(array_aval: jax_core.ShapedArray): + array_aval_shape = list(array_aval.shape) + for i, s in enumerate(array_aval.shape): + try: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + array_aval_shape[i] = aval.dtype.bound + except OverflowError as e: + # Note - there are annoying cases where on 32 bit hardware, + # a flattened index space may overflow - for these cases, + # we just take the shape as is. + # In most places, this is totally sound to do. + # For ragged/jumble inputs, this will fail downstream. + return array_aval.shape + + return tuple(array_aval_shape) + + def _convert_block_spec_to_block_mapping( block_spec: BlockSpec, origin: OriginStr, @@ -545,77 +811,29 @@ def _convert_block_spec_to_block_mapping( ) -> BlockMapping: if block_spec is no_block_spec: block_spec = BlockSpec(None, None) - if block_spec.index_map is None: - index_map_func = lambda *args: (0,) * len(array_aval.shape) - else: - index_map_func = block_spec.index_map - if block_spec.block_shape is None: - block_shape = array_aval.shape - else: - block_shape = block_spec.block_shape - if len(array_aval.shape) != len(block_shape): - raise ValueError( - f"Block shape for {origin} (= {block_shape}) " - "must have the same number of dimensions as the " - f"array shape {array_aval.shape}.") - - unmapped_block_shape = tuple(s for s in block_shape if s is not None) - block_aval = AbstractMemoryRef(array_aval.update(shape=unmapped_block_shape), - block_spec.memory_space) - - if not jax_core.is_constant_shape(block_aval.shape): - raise ValueError( - "shape polymorphism for Pallas does not support " - "dynamically-shaped blocks. " - f"Block spec for {origin} has block_shape: {block_aval.shape}") - - flat_index_map_fun, index_map_out_tree_thunk = api_util.flatten_fun( - lu.wrap_init(index_map_func), index_map_tree) - debug = pe.debug_info(index_map_func, index_map_tree, index_map_out_tree_thunk, - False, "pallas_call index_map") - index_map_src_info = NameAndSrcInfo.from_pallas_call(None, - debug.func_src_info) - with tracing_grid_env(grid, mapped_dims): - jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(flat_index_map_fun, - index_map_avals, - debug_info=debug) - mapped_block_shape = tuple( - mapped if s is None else s for s in block_shape) - if len(out_avals) != len(block_shape): - raise ValueError( - f"Index map function {index_map_src_info} for " - f"{origin} must return " - f"{len(block_shape)} values to match {block_shape=}. " - f"Currently returning {len(out_avals)} values.") - for i, ov in enumerate(out_avals): - if ov.shape or ov.dtype not in [jnp.int32, jnp.int64]: - raise ValueError( - f"Index map function {index_map_src_info} for " - f"{origin} must return integer scalars. 
Output[{i}] has type " - f"{ov}.") + return block_spec.to_block_mapping( + origin, + array_aval, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid, + mapped_dims=mapped_dims, + ) +index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) - if consts: - raise ValueError( - f"Index map function {index_map_src_info} for " - f"{origin} must not capture constants: {consts}") +class ScratchShape(Protocol): + def get_array_aval(self) -> jax_core.AbstractValue: + ... + def get_ref_aval(self) -> state.AbstractRef: + ... - mapping = BlockMapping( - block_shape=mapped_block_shape, - block_aval=block_aval, - index_map_jaxpr=jax_core.ClosedJaxpr(jaxpr, consts), - index_map_src_info=index_map_src_info, - indexing_mode=block_spec.indexing_mode, - array_shape_dtype=jax.ShapeDtypeStruct(array_aval.shape, array_aval.dtype), - origin=origin, - ) - mapping.check_invariants() - return mapping -index_map_grid_aval = jax_core.ShapedArray((), jnp.int32) +ScratchShapeTree = Sequence[Union[ScratchShape, "ScratchShapeTree"]] + -@dataclasses.dataclass(init=False) +@dataclasses.dataclass(init=False, kw_only=True) class GridSpec: """Encodes the grid parameters for :func:`jax.experimental.pallas.pallas_call`. @@ -628,12 +846,14 @@ class GridSpec: grid_names: tuple[Hashable, ...] | None in_specs: BlockSpecTree out_specs: BlockSpecTree + scratch_shapes: ScratchShapeTree = () def __init__( self, grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), ): # Be more lenient for in/out_specs if isinstance(in_specs, list): @@ -645,6 +865,7 @@ def __init__( self.in_specs = in_specs self.out_specs = out_specs + self.scratch_shapes = tuple(scratch_shapes) grid_names = None if isinstance(grid, int): @@ -660,9 +881,6 @@ def __init__( self.grid = grid # type: ignore self.grid_names = grid_names - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - assert False # Not needed in GridSpec - def _make_scalar_ref_aval(self, aval): assert False # Not needed in GridSpec @@ -707,12 +925,10 @@ def get_grid_mapping( else: num_flat_scalar_prefetch = 0 jaxpr_scalar_ref_avals = () - - scratch_shapes: tuple[Any, ...] 
= getattr(grid_spec, "scratch_shapes", ()) - if scratch_shapes: + if grid_spec.scratch_shapes: flat_scratch_shapes, scratch_tree = tree_util.tree_flatten( - scratch_shapes) - flat_scratch_avals = map(grid_spec._make_scratch_aval, flat_scratch_shapes) + grid_spec.scratch_shapes) + flat_scratch_avals = map(lambda s: s.get_ref_aval(), flat_scratch_shapes) num_flat_scratch_operands = len(flat_scratch_avals) jaxpr_scratch_avals = tree_util.tree_unflatten( scratch_tree, flat_scratch_avals) @@ -779,11 +995,11 @@ def get_grid_mapping( num_scratch_operands=num_flat_scratch_operands, ) grid_mapping.check_invariants() - in_ref_avals = [bm.block_aval for bm in in_block_mappings] + in_ref_avals = [bm.ref_aval for bm in in_block_mappings] jaxpr_in_ref_avals = tree_util.tree_unflatten(in_tree, in_ref_avals) jaxpr_in_avals = (*jaxpr_scalar_ref_avals, *jaxpr_in_ref_avals) - out_ref_avals = [bm.block_aval for bm in out_block_mappings] + out_ref_avals = [bm.ref_aval for bm in out_block_mappings] jaxpr_out_avals = tree_util.tree_unflatten(out_tree, out_ref_avals) if not isinstance(jaxpr_out_avals, (tuple, list)): jaxpr_out_avals = (jaxpr_out_avals,) diff --git a/jax/_src/pallas/mosaic/BUILD b/jax/_src/pallas/mosaic/BUILD index 57dad7793116..ae76a00a6c17 100644 --- a/jax/_src/pallas/mosaic/BUILD +++ b/jax/_src/pallas/mosaic/BUILD @@ -20,7 +20,7 @@ load("//jaxlib:jax.bzl", "py_deps") package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) @@ -43,6 +43,16 @@ py_library( ], ) +py_library( + name = "error_handling", + srcs = ["error_handling.py"], + deps = [ + "//jax:compiler", + "//jax:traceback_util", + "//jax/_src/lib", + ], +) + py_library( name = "primitives", srcs = ["primitives.py"], @@ -71,10 +81,12 @@ py_library( srcs = ["lowering.py"], deps = [ ":core", + ":error_handling", ":primitives", "//jax", "//jax:ad_util", "//jax:core", + "//jax:dtypes", "//jax:mesh", "//jax:mlir", "//jax:mosaic", @@ -95,6 +107,7 @@ py_library( ":primitives", "//jax", "//jax:api_util", + "//jax:pallas", "//jax:util", "//jax/_src/pallas", ] + py_deps("numpy"), diff --git a/jax/_src/pallas/mosaic/core.py b/jax/_src/pallas/mosaic/core.py index 75e5101de142..4ff9d894da8f 100644 --- a/jax/_src/pallas/mosaic/core.py +++ b/jax/_src/pallas/mosaic/core.py @@ -19,14 +19,14 @@ import dataclasses import enum import functools -from typing import Any, Hashable +from typing import Any, ClassVar, Literal import jax from jax._src import core as jax_core from jax._src import dtypes from jax._src import util -import jax.numpy as jnp from jax._src.pallas import core as pallas_core +import jax.numpy as jnp import numpy as np map, unsafe_map = util.safe_map, map @@ -39,14 +39,47 @@ BlockSpecTree = pallas_core.BlockSpecTree GridMapping = pallas_core.GridMapping NoBlockSpec = pallas_core.NoBlockSpec +ScratchShapeTree = pallas_core.ScratchShapeTree AbstractMemoryRef = pallas_core.AbstractMemoryRef no_block_spec = pallas_core.no_block_spec _convert_block_spec_to_block_mapping = pallas_core._convert_block_spec_to_block_mapping split_list = util.split_list +@dataclasses.dataclass(frozen=True) +class TPUCompilerParams(pallas_core.CompilerParams): + """Mosaic TPU compiler parameters. + + Attributes: + dimension_semantics: A list of dimension semantics for each grid + dimension of the kernel. Either "parallel" for dimensions that can + execute in any order, or "arbitrary" for dimensions that must be + executed sequentially. 
+ allow_input_fusion: A list of booleans indicating whether input fusion is + allowed for each argument. + vmem_limit_bytes: Overrides the default VMEM limit for a kernel. Note + that this must be used in conjunction with the + --xla_tpu_scoped_vmem_limit_kib=N flag with N*1kib > vmem_limit_bytes. + collective_id: Indicates which barrier semaphore to use for the kernel. + Note that using the same collective_id does not guarantee that + the same barrier semaphore will be allocated between kernels. + internal_scratch_in_bytes: The size of the internal scratch space used by + Mosaic. + flags: A dictionary of command line flags for the kernel. + serialization_format: The serialization format for the kernel body. + device_type: The device type to compile for. + """ + PLATFORM: ClassVar[str] = "mosaic" + dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] | None = None + allow_input_fusion: Sequence[bool] | None = None + vmem_limit_bytes: int | None = None + collective_id: int | None = None + flags: dict[str, Any] | None = None + internal_scratch_in_bytes: int | None = None + serialization_format: int = 1 + device_type: str | None = None class TPUMemorySpace(enum.Enum): - ANY = "any" + ANY = "any" # TODO(b/368401328): Remove this and just use pl.ANY. VMEM = "vmem" SMEM = "smem" CMEM = "cmem" @@ -57,7 +90,7 @@ def __str__(self) -> str: def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. - return MemoryRef(shape, dtype, self) + return pallas_core.MemoryRef(shape, dtype, self) class semaphore_dtype(dtypes.extended): pass class semaphore(semaphore_dtype): pass @@ -67,7 +100,11 @@ class barrier_semaphore(semaphore_dtype): pass class AbstractSemaphoreTyRules: @staticmethod def pallas_interpret_element_aval(_) -> jax_core.ShapedArray: - return pallas_core.index_map_grid_aval + return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE) + + @staticmethod + def physical_element_aval(_) -> jax_core.ShapedArray: + return jax_core.ShapedArray((), jnp.int32) class AbstractSemaphoreTy(dtypes.ExtendedDType): name: str @@ -109,10 +146,15 @@ def __call__(self, shape: tuple[int, ...]): dtype = BarrierSemaphoreTy() else: dtype = SemaphoreTy() - return MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + if pallas_core.is_interpret_mode(): + dtype = pallas_core.SEMAPHORE_INTERPRET_DTYPE + return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE) + + def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace: + return self(()).get_array_aval() - def get_aval(self) -> AbstractMemoryRef: - return self(()).get_aval() + def get_ref_aval(self) -> AbstractMemoryRef: + return self(()).get_ref_aval() @dataclasses.dataclass(frozen=True) class AbstractSemaphore(jax_core.AbstractValue): @@ -128,26 +170,9 @@ def join(self, other): jax_core.raise_to_shaped_mappings[AbstractSemaphore] = lambda aval, _: aval -@dataclasses.dataclass(frozen=True) -class MemoryRef: - """Like jax.ShapeDtypeStruct but with memory spaces.""" - shape: tuple[int, ...] - dtype: jnp.dtype - memory_space: TPUMemorySpace = TPUMemorySpace.ANY - - def get_aval(self) -> AbstractMemoryRef: - return AbstractMemoryRef( - jax_core.ShapedArray(self.shape, self.dtype), self.memory_space) - - -@dataclasses.dataclass(init=False, unsafe_hash=True) +@dataclasses.dataclass(init=False, kw_only=True, unsafe_hash=True) class PrefetchScalarGridSpec(pallas_core.GridSpec): - grid: TupleGrid - grid_names: tuple[Hashable, ...] 
| None num_scalar_prefetch: int - in_specs: pallas_core.BlockSpecTree - out_specs: pallas_core.BlockSpecTree - scratch_shapes: tuple[Any, ...] def __init__( self, @@ -155,9 +180,9 @@ def __init__( grid: Grid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, - scratch_shapes: Any | Sequence[Any] = () + scratch_shapes: ScratchShapeTree = () ): - super().__init__(grid, in_specs, out_specs) + super().__init__(grid, in_specs, out_specs, scratch_shapes) self.num_scalar_prefetch = num_scalar_prefetch self.scratch_shapes = tuple(scratch_shapes) @@ -165,14 +190,6 @@ def _make_scalar_ref_aval(self, aval): return AbstractMemoryRef(jax_core.ShapedArray(aval.shape, aval.dtype), TPUMemorySpace.SMEM) - def _make_scratch_aval(self, obj: object) -> jax_core.AbstractValue: - if isinstance(obj, MemoryRef): - return obj.get_aval() - if isinstance(obj, SemaphoreType): - return obj.get_aval() - raise ValueError(f"No registered conversion for {type(obj)}. " - "Only VMEM and SemaphoreType are supported.") - @dataclasses.dataclass(frozen=True) class TensorCore: diff --git a/jax/_src/pallas/mosaic/error_handling.py b/jax/_src/pallas/mosaic/error_handling.py new file mode 100644 index 000000000000..f8231f5b24b6 --- /dev/null +++ b/jax/_src/pallas/mosaic/error_handling.py @@ -0,0 +1,158 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for raising more informative exceptions from Pallas.""" +from collections import namedtuple +import re +import types +from jax._src import compiler +from jax._src import traceback_util +from jax._src.lib import xla_client +from jax._src.lib.mlir import ir + +# This is a simple ir.Location parsing regex that assumes the string is properly +# formatted coming from Mosaic. +# It will assume everything from the first to last parentheses +# in the string is part of the frame, and does not account for unbalanced +# parentheses. +LOCATION_PATTERN = re.compile( + r'(?P<location>loc\((?P<eqn_str>\".*?\")(?P<frames>.*)\))' +) +FRAME_PATTERN = re.compile( + r'(?P<fun_name>\".*?\")\((?P<filename>\".*?\"):' + r'(?P<lineno>[0-9]+):(?P<colno>[0-9]+)\)' +) +MLIR_ERR_PREFIX = ( + 'Pallas encountered an internal verification error. ' + 'Please file a bug at https://github.com/jax-ml/jax/issues. ' + 'Error details: ' +) + +RawFrame = namedtuple('RawFrame', ['func_name', 'filename', 'lineno', 'colno']) + + +class MosaicError(Exception): + """Error thrown by Pallas when re-raising a Mosaic internal error.""" + + +class VerificationError(MosaicError): + """Error thrown by Pallas when re-raising a verification error.""" + + def __init__(self, message: str): + super().__init__(MLIR_ERR_PREFIX + message) + + +def _handle_xla_runtime_error( + base_err: xla_client.XlaRuntimeError, +) -> MosaicError | None: + """Reformats XLARuntimeError to include a Python traceback.""" + if 'Mosaic' not in str(base_err): + return None + try: + _, frames = parse_location_string(str(base_err)) + except ValueError: + # If no location string is found, skip handling and raise the original + # error.
+ return None + new_tb = traceback_from_raw_frames(frames) + err_msg = base_err.args[0] + err_msg = redact_locations(err_msg) + new_error = MosaicError(err_msg) + new_error.__traceback__ = traceback_util.filter_traceback(new_tb) + return new_error + + +compiler.register_xla_runtime_error_handler(_handle_xla_runtime_error) + + +def mlir_error_to_verification_error( + base_err: ir.MLIRError) -> VerificationError: + """Reformats MLIRError to include a Python traceback.""" + diagnostic = base_err.error_diagnostics[0] # pytype: disable=attribute-error + def _get_diagnostic_message(diagnostic) -> str: + current_msg = diagnostic.message + for d in diagnostic.notes: + current_msg += "\n " + _get_diagnostic_message(d) + return current_msg + + _, frames = parse_location_string(str(diagnostic.location.attr)) + new_tb = traceback_from_raw_frames(frames) + new_error = VerificationError(_get_diagnostic_message(diagnostic)) + new_error.__traceback__ = traceback_util.filter_traceback(new_tb) + return new_error + + +def redact_locations(err_msg: str) -> str: + """Removes location strings from an error message.""" + for mat in re.finditer(LOCATION_PATTERN, err_msg): + start, end = mat.span('location') + # Remove the entire line containing the location. + line_start = err_msg.rfind('\n', 0, end) + line_start = line_start if line_start >= 0 else start + line_end = err_msg.find('\n', start) + line_end = line_end if line_end >= 0 else end + return err_msg[:line_start] + err_msg[line_end+1:] + return err_msg + + +def parse_location_string(location_string: str) -> tuple[str, list[RawFrame]]: + """Parses a serialized MLIR location. + + Location strings have the format: + `loc("location_name"(<frames>))` + + Where <frames> is a nested callsite string representing the entire + call stack: + `callsite("fn_name"("filename":lineno:colno) at callsite(...))` + + Args: + location_string: A string serialization of an MLIR location. + + Returns: + A tuple (name, frames) where name is the name of the location and frames + is a list of RawFrame objects representing the Python call stack associated + with the location.
+ """ + frame_str = '' + loc_name = None + matches = list(re.finditer(LOCATION_PATTERN, location_string)) + if len(matches) > 1: + raise ValueError( + 'More than one location found in string: ', location_string) + for mat in matches: + loc_name = mat.group('eqn_str')[1:-1] + frame_str = mat.group('frames')[1:-1] + if loc_name is None: + raise ValueError(f'Could not find location in string {location_string}') + frames: list[RawFrame] = [] + for mat in re.finditer(FRAME_PATTERN, frame_str): + frames.append( + RawFrame( + mat.group('fun_name')[1:-1], + mat.group('filename')[1:-1], + int(mat.group('lineno')), + int(mat.group('colno')), + ) + ) + return loc_name, frames + + +def traceback_from_raw_frames(frames: list[RawFrame]) -> types.TracebackType: + """Constructs a traceback from a list of RawFrame objects.""" + xla_frames = [ + xla_client.Frame(frame.filename, frame.func_name, -1, frame.lineno + ) # type: ignore [call-arg] + for frame in frames + ] + return xla_client.Traceback.traceback_from_frames(xla_frames) diff --git a/jax/_src/pallas/mosaic/lowering.py b/jax/_src/pallas/mosaic/lowering.py index 082927677c73..46cbe8e4758b 100644 --- a/jax/_src/pallas/mosaic/lowering.py +++ b/jax/_src/pallas/mosaic/lowering.py @@ -40,7 +40,6 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax as lax_internal from jax._src.lax.control_flow import for_loop -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith from jax._src.lib.mlir.dialects import func @@ -48,15 +47,19 @@ from jax._src.lib.mlir.dialects import memref from jax._src.lib.mlir.dialects import scf from jax._src.lib.mlir.dialects import vector -from jax._src.pallas import pallas_call from jax._src.pallas import core as pallas_core +from jax._src.pallas import pallas_call from jax._src.pallas import primitives from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core +from jax._src.pallas.mosaic import error_handling from jax._src.pallas.mosaic import primitives as tpu_primitives from jax._src.state import discharge as state_discharge from jax._src.state import indexing from jax._src.state import primitives as state_primitives +from jax._src.state.types import RefBitcaster +from jax._src.state.utils import dtype_bitwidth +from jax._src.typing import DTypeLike from jax._src.util import safe_map from jax._src.util import safe_zip from jax._src.util import split_list @@ -84,12 +87,6 @@ map, unsafe_map = safe_map, map # pylint: disable=redefined-builtin zip, unsafe_zip = safe_zip, zip # pylint: disable=redefined-builtin -UNSIGNED_TO_SIGNED = { - np.dtype('uint8'): np.dtype('int8'), - np.dtype('uint16'): np.dtype('int16'), - np.dtype('uint32'): np.dtype('int32'), - np.dtype('uint64'): np.dtype('int64'), -} @dataclasses.dataclass class MeshContext: @@ -141,15 +138,29 @@ class LoweringRuleContext: replace = dataclasses.replace -def _memory_space_to_tpu_memspace(memory_space: MemorySpace | None - ) -> ir.Attribute: - if memory_space is None: - memory_space = VMEM - elif memory_space == pallas_core.MemorySpace.ERROR: - memory_space = SMEM - elif memory_space == pallas_core.MemorySpace.INDEX: - memory_space = SMEM - return ir.Attribute.parse(f"#tpu.memory_space<{memory_space}>") +def _memory_space_to_tpu_memory_space(memory_space: MemorySpace | None + ) -> TPUMemorySpace: + match memory_space: + case None: + # We pick VMEM as the default one when no memory space is + # specified + return 
TPUMemorySpace.VMEM + case pallas_core.MemorySpace.ANY: + # Map the general ANY memory space to TPU ANY memory space + return TPUMemorySpace.ANY + case pallas_core.MemorySpace.ERROR | pallas_core.MemorySpace.INDEX: + return TPUMemorySpace.SMEM + case TPUMemorySpace(): + # Leave the memory space unchanged + return memory_space + case _: + raise ValueError("Invalid memory space: {memory_space}") + + +def _memory_space_to_mosaic_attribute(memory_space: MemorySpace | None + ) -> ir.Attribute: + tpu_memory_space = _memory_space_to_tpu_memory_space(memory_space) + return ir.Attribute.parse(f"#tpu.memory_space<{tpu_memory_space}>") def _dtype_to_ir_type(dtype: jnp.dtype, is_kernel_boundary: bool = False) -> ir.Type: @@ -185,7 +196,7 @@ def aval_to_ir_type(aval, sem_type = ir.Type.parse("!tpu.semaphore") else: raise ValueError(f"Cannot allocate {aval.sem_type}.") - memspace = _memory_space_to_tpu_memspace(TPUMemorySpace.SEMAPHORE) + memspace = _memory_space_to_mosaic_attribute(TPUMemorySpace.SEMAPHORE) return ir.MemRefType.get((), sem_type, memory_space=memspace) if dtypes.issubdtype(aval.dtype, dtypes.prng_key): shape = aval.dtype._impl.key_shape @@ -193,13 +204,13 @@ def aval_to_ir_type(aval, memory_space = TPUMemorySpace.SMEM if memory_space != TPUMemorySpace.SMEM: raise ValueError(f"PRNG keys must be stored in SMEM. Got {memory_space}") - memspace = _memory_space_to_tpu_memspace(memory_space) + memspace = _memory_space_to_mosaic_attribute(memory_space) return ir.MemRefType.get(shape, _dtype_to_ir_type(np.dtype(np.uint32)), memory_space=memspace) if isinstance(aval, state.AbstractRef): if shape is None: shape = aval.shape - memspace = _memory_space_to_tpu_memspace(memory_space) + memspace = _memory_space_to_mosaic_attribute(memory_space) return ir.MemRefType.get(shape, _dtype_to_ir_type(aval.dtype, is_kernel_boundary=True), memory_space=memspace) @@ -297,6 +308,7 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, self.jaxpr = jaxpr self.block_mappings = grid_mapping.block_mappings self.mapped_dims = grid_mapping.vmapped_dims + # TODO(mvoz): Generalize to not need this user_grid = tuple( g for i, g in enumerate(self.grid) if i not in self.mapped_dims ) @@ -344,9 +356,19 @@ def __init__(self, jaxpr: jax_core.Jaxpr, grid_mapping: pallas_core.GridMapping, for _ in range(len(self.grid)) ]) self._prepare_mesh_info(mesh) - def _get_grid_indices(indices): - return indices - self.get_grid_indices = _get_grid_indices + + if grid_mapping.get_grid_indices is None: + + def _get_grid_indices(indices, maybe_include_mapped_dims: bool): + if maybe_include_mapped_dims: + return indices + return tuple( + idx for i, idx in enumerate(indices) if i not in self.mapped_dims + ) + + self.get_grid_indices = _get_grid_indices + else: + self.get_grid_indices = grid_mapping.get_grid_indices def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): if not self.has_communication: @@ -368,7 +390,8 @@ def _prepare_mesh_info(self, mesh: mesh_lib.Mesh | None): mesh_strides = pallas_utils.strides_from_shape(tuple( mesh.shape[a] for a in axis_names )) - self.mesh_info = MeshInfo(mesh.device_ids.shape, axis_names, mesh_strides) + mesh_shape = tuple(mesh.shape.values()) + self.mesh_info = MeshInfo(mesh_shape, axis_names, mesh_strides) def maybe_compress_grid(self): # If we have many leading parallel dimensions, we should "compress" them @@ -415,24 +438,23 @@ class MeshInfo: axis_names: list[str] mesh_strides: tuple[int, ...] 
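Putting the pieces above together, the sketch below shows how a kernel author might use the new scratch_shapes plumbing, the TPU memory spaces, and the TPUCompilerParams dataclass from the user-facing API. This is illustrative only and not part of the diff: it assumes pl.pallas_call forwards scratch_shapes to the GridSpec as the core.py changes suggest, that its compiler_params accepts the dataclass directly, and that TPUCompilerParams, TPUMemorySpace, and VMEM are re-exported from jax.experimental.pallas.tpu.

    import jax
    import jax.numpy as jnp
    from jax.experimental import pallas as pl
    from jax.experimental.pallas import tpu as pltpu

    def double_kernel(x_ref, o_ref, scratch_ref):
        # x_ref/o_ref are VMEM refs; scratch_ref comes from scratch_shapes below,
        # where TPUMemorySpace.__call__ builds a pallas_core.MemoryRef whose
        # get_ref_aval() supplies the kernel-visible AbstractMemoryRef.
        scratch_ref[...] = x_ref[...] * 2
        o_ref[...] = scratch_ref[...]

    x = jnp.ones((8, 128), jnp.float32)
    out = pl.pallas_call(
        double_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM)],
        out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM),
        scratch_shapes=[pltpu.VMEM((8, 128), jnp.float32)],
        # Hypothetical value; any option documented on TPUCompilerParams above
        # could be passed here.
        compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=64 * 1024 * 1024),
    )(x)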
-def lower_jaxpr_to_module( + +def _check_block_mappings( + block_mappings: tuple[pallas_core.BlockMapping, ...], lowering_context: mlir.LoweringRuleContext, - ctx: ir.Context, - grid_mapping: pallas_core.GridMapping, - jaxpr: jax_core.Jaxpr, - *, - dimension_semantics: tuple[str | None, ...] | None, name_and_src_info: pallas_core.NameAndSrcInfo, - mesh: mesh_lib.Mesh | None = None, - for_verification: bool = False, -) -> tuple[Module, tuple[Any, ...]]: - for bm in grid_mapping.block_mappings: +) -> None: + del lowering_context # originally needed for forward compat + for bm in block_mappings: rank = len(bm.block_shape) # TODO(necula): add tests for SMEM blocks with trivial windowing # We support scalars too if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SMEM and bm.has_trivial_window()): continue + if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.SEMAPHORE: + continue + def err_details(): return (f"Block spec for {bm.origin} in pallas_call {name_and_src_info} " "has block shape " @@ -441,20 +463,10 @@ def err_details(): f"and index_map returning {bm.index_map_jaxpr.jaxpr.outvars}, in " f"memory space {bm.block_aval.memory_space}." "\nSee details at https://jax.readthedocs.io/en/latest/pallas/grid_blockspec.html#pallas-blockspec") - if lowering_context.is_forward_compat() or jaxlib_version < (0, 4, 32): - # TODO(b/356116061): Remove the old rank condition - if rank < 2: - raise ValueError( - "The Pallas TPU lowering currently supports only blocks of " - "rank >= 2 for blocks, except those in the SMEM memory space " - "having the same block shape as the array shape and a " - "trivial index_map (returning all 0s). " + err_details()) - else: - if rank < 1: - raise ValueError( - "The Pallas TPU lowering currently supports only blocks of " - "rank >= 1. " + err_details()) - + if rank < 1: + raise ValueError( + "The Pallas TPU lowering currently supports only blocks of " + "rank >= 1. " + err_details()) if (bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY and not bm.has_trivial_window()): @@ -469,42 +481,42 @@ def err_details(): bs1, as1 = unmapped_bs[-2], bm.array_shape_dtype.shape[-2] else: bs1, as1 = 1, 1 - if lowering_context.is_forward_compat(): - # TODO(b/356116061): Remove the old divisibility condition - # With shape polymorphism block_shape is static, but the array shape may - # be symbolic. Write the divisibility comparisons to defer inequality - # comparisons on dimensions as much as possible. + + if rank >= 2: evenly_divisible = ( - (bs0 % 128 == 0 or (bs0 == as0 and as0 < 128)) and - (bs1 % 8 == 0 or (bs1 == as1 and as1 < 8)) + (bs0 == as0 or bs0 % 128 == 0) and + (bs1 == as1 or bs1 % 8 == 0) ) - if not evenly_divisible: - raise ValueError( - "The Pallas TPU lowering currently requires that the last two " - "dimensions of your block shape are divisible by 8 and 128 " - "respectively, if the respective dimensions of the overall array " - "are larger than the respective factors. If array dimensions are " - "smaller, the block should span the full array dimension. " - + err_details()) else: - if rank >= 2: - evenly_divisible = ( - (bs0 == as0 or bs0 % 128 == 0) and - (bs1 == as1 or bs1 % 8 == 0) - ) - else: - assert rank == 1 - # TODO(necula): test this for bool. What should it do? - tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) - evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) + assert rank == 1 + # TODO(necula): test this for bool. What should it do? 
+ tiling_size = 128 * (32 // lax_internal._bit_width(bm.array_shape_dtype.dtype)) + evenly_divisible = (bs0 == as0 or bs0 % tiling_size == 0) if not evenly_divisible: raise ValueError( - "The Pallas TPU lowering currently requires that the last two " - "dimensions of your block shape are divisible by 8 and 128 " - "respectively, or be equal to the respective dimensions of the " - "overall array. " - + err_details()) + "The Pallas TPU lowering currently requires that the last two " + "dimensions of your block shape are divisible by 8 and 128 " + "respectively, or be equal to the respective dimensions of the " + "overall array. " + + err_details() + ) + + +def lower_jaxpr_to_module( + lowering_context: mlir.LoweringRuleContext, + ctx: ir.Context, + grid_mapping: pallas_core.GridMapping, + jaxpr: jax_core.Jaxpr, + *, + dimension_semantics: tuple[str | None, ...] | None, + name_and_src_info: pallas_core.NameAndSrcInfo, + mesh: mesh_lib.Mesh | None = None, + for_verification: bool = False, +) -> tuple[Module, tuple[Any, ...]]: + # Verify that we have legal block mappings to catch errors early. + _check_block_mappings(grid_mapping.block_mappings, lowering_context, + name_and_src_info) mosaic_grid_mapping = MosaicGridMapping( jaxpr, grid_mapping, dimension_semantics, mesh) @@ -526,7 +538,9 @@ def err_details(): for i, bm in enumerate(grid_mapping.block_mappings): func_name = f"transform_{i}" # ANY operands don't support windowing and require empty window_params. - if bm.block_aval.memory_space == tpu_core.TPUMemorySpace.ANY: + tpu_memory_space = _memory_space_to_tpu_memory_space( + bm.block_aval.memory_space) + if tpu_memory_space == tpu_core.TPUMemorySpace.ANY: # We checked above that the block does not require windowing. window_params.append(ir.DictAttr.get()) continue @@ -593,7 +607,9 @@ def lower_jaxpr_to_transform_func( ] def body_func(*args): grid_indices, scalar_prefetch = split_list(args, [num_grid]) - jaxpr_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=True + ) arg_block_shapes = [ *[()] * len(jaxpr_indices), *mosaic_grid_mapping.scalar_prefetch_block_shapes, @@ -632,12 +648,8 @@ def body_func(*args): body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) try: body.func_op.verify() - except Exception as e: - raise LoweringException( - f"Body failed to verify: {body.func_op}.\nThis is an internal error." - " Please report a bug at:" - " https://github.com/google/jax/issues/new?assignees=sharadmv." - ) from e + except ir.MLIRError as e: + raise error_handling.mlir_error_to_verification_error(e) from e return body.func_op @@ -665,9 +677,9 @@ def lower_jaxpr_to_func( def body_func(*args): grid_indices, scalar_prefetch, operands_and_scratch = split_list( args, [num_grid, num_scalar_prefetch]) - grid_indices = mosaic_grid_mapping.get_grid_indices(grid_indices) - jaxpr_indices = tuple(idx for i, idx in enumerate(grid_indices) - if i not in mosaic_grid_mapping.mapped_dims) + jaxpr_indices = mosaic_grid_mapping.get_grid_indices( + grid_indices, maybe_include_mapped_dims=False + ) mesh_info = mosaic_grid_mapping.mesh_info if mesh_info is not None: mesh_context = MeshContext( @@ -694,12 +706,8 @@ def body_func(*args): body = func.FuncOp.from_py_func(*arg_types, name=name)(body_func) try: body.func_op.verify() - except Exception as e: - raise LoweringException( - f"Body failed to verify: {body.func_op}.\nThis is an internal error." 
- " Please report a bug at:" - " https://github.com/google/jax/issues/new?assignees=sharadmv." - ) from e + except ir.MLIRError as e: + raise error_handling.mlir_error_to_verification_error(e) from e return body.func_op @@ -781,9 +789,7 @@ def write_env(var: jax_core.Var, val): source_info = eqn.source_info.replace( name_stack=ctx.name_stack + eqn.source_info.name_stack ) - loc = mlir._source_info_to_location( - ctx, eqn.primitive, eqn.params, source_info - ) + loc = mlir._source_info_to_location(ctx, eqn.primitive, source_info) with source_info_util.user_context(eqn.source_info.traceback), loc: if eqn.primitive in lowering_rules: if eqn.primitive not in skip_mlir_conversions: @@ -827,7 +833,7 @@ def write_env(var: jax_core.Var, val): raise NotImplementedError( "Unimplemented primitive in Pallas TPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues.") + "Please file an issue on https://github.com/jax-ml/jax/issues.") if eqn.primitive.multiple_results: map(write_env, eqn.outvars, ans) else: @@ -986,11 +992,12 @@ def _indexer_to_start_size_stride( ) -def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, - indexer: NDIndexer, - ref_block_shape: tuple[int | pallas_core.Mapped, ...] - ) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...], - tuple[int | pallas_core.Mapped, ...]]: +def _slice_memref( + ref: ir.Value, + indexer: NDIndexer, + ref_dtype: DTypeLike, + ref_block_shape: tuple[int | pallas_core.Mapped, ...], +) -> tuple[ir.Value, tuple[int | pallas_core.Mapped, ...]]: assert ref_block_shape is not None target_shape = indexer.get_indexer_shape() starts, sizes, strides, squeeze_dims, ref_block_shape = ( @@ -1007,26 +1014,79 @@ def _slice_memref(ref: ir.Value, ref_aval: state.AbstractRef, static_sizes = tuple(s if not isinstance(s, ir.Value) else ir_dynamic_size for s in sizes) target_ref_ty = ir.MemRefType.get( - static_sizes, _dtype_to_ir_type(ref_aval.dtype), - memory_space=ref.type.memory_space) + static_sizes, + _dtype_to_ir_type(ref_dtype), + memory_space=ref.type.memory_space, + ) out = tpu.MemRefSliceOp(target_ref_ty, ref, starts, dynamic_sizes).result if any(squeeze_dims): # We need to squeeze out some dimensions static_sizes = tuple(s if not isinstance(s, ir.Value) else ir_dynamic_size for s in target_shape) squeezed_ref_ty = ir.MemRefType.get( - static_sizes, _dtype_to_ir_type(ref_aval.dtype), - memory_space=ref.type.memory_space) + static_sizes, + _dtype_to_ir_type(ref_dtype), + memory_space=ref.type.memory_space, + ) out = tpu.MemRefSqueezeOp(squeezed_ref_ty, out).result return out, ref_block_shape -def _index_ref(ref, ref_aval, ref_block_shape, indexers): - for indexer in indexers: - ref, ref_block_shape = _slice_memref(ref, ref_aval, indexer, - ref_block_shape) +def _bitcast_memref( + ref: ir.Value, + bitcaster: RefBitcaster, + ref_dtype: DTypeLike, + ref_block_shape: tuple[int | pallas_core.Mapped, ...], +) -> tuple[ir.Value, DTypeLike, tuple[int | pallas_core.Mapped, ...]]: + src_bitwidth = dtype_bitwidth(ref_dtype) + dst_bitwidth = dtype_bitwidth(bitcaster.dtype) + if src_bitwidth != dst_bitwidth: + if len(ref_block_shape) < 2: + raise NotImplementedError( + "Bitcast 1D ref with bitwidth change is not supported." + ) + if ref_block_shape[-2] is pallas_core.mapped: + raise NotImplementedError( + "Bitcast a ref whose 2nd minormost dimension is squeezed when" + " bitwidth changes." 
+ ) + new_ref_dtype = bitcaster.dtype + target_ref_ty = ir.MemRefType.get( + bitcaster.shape, + _dtype_to_ir_type(new_ref_dtype), + memory_space=ref.type.memory_space, + ) + new_ref_block_shape = list(ref_block_shape) + if ( + len(new_ref_block_shape) >= 2 + and new_ref_block_shape[-2] is not pallas_core.mapped + ): + new_ref_block_shape[-2] = ( + new_ref_block_shape[-2] * src_bitwidth // dst_bitwidth + ) + return ( + tpu.memref_bitcast(target_ref_ty, ref), + new_ref_dtype, + tuple(new_ref_block_shape), + ) + + +def _transform_ref(ref, ref_dtype, ref_block_shape, transforms): + for transform in transforms: + match transform: + case NDIndexer(): + ref, ref_block_shape = _slice_memref( + ref, transform, ref_dtype, ref_block_shape + ) + case RefBitcaster(): + ref, ref_dtype, ref_block_shape = _bitcast_memref( + ref, transform, ref_dtype, ref_block_shape + ) + case _: + raise NotImplementedError(f"Unsupported transform: {transform}") return ref, ref_block_shape + @dataclasses.dataclass(frozen=True) class KeyScalarBundle: """A container class for PRNG key data. @@ -1045,21 +1105,21 @@ class KeyScalarBundle: scalars: list[ir.OpResult] def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): - ref, indexers, mask, _ = args_tree.unflatten(args_flat) - ref_aval, indexers_avals, _, _ = args_tree.unflatten(ctx.avals_in) - (*slice_indexers, idx) = indexers + ref, transforms, mask, _ = args_tree.unflatten(args_flat) + ref_aval, transforms_avals, _, _ = args_tree.unflatten(ctx.avals_in) + (*prev_transforms, idx) = transforms # Select last aval, which is the one that will be used for the load. - (*_, idx_aval) = indexers_avals + (*_, idx_aval) = transforms_avals if mask is not None: raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, ref_block_shape = _index_ref( - ref, ref_aval, ref_block_shape, slice_indexers) + ref, ref_block_shape = _transform_ref( + ref, ref_aval.dtype, ref_block_shape, prev_transforms + ) ref_type = ir.MemRefType(ref.type) is_smem_load = str(ref_type.memory_space) == "#tpu.memory_space" - ref_aval, *_ = ctx.avals_in (aval_out,) = ctx.avals_out if isinstance(aval_out.dtype, prng.KeyTy): if not is_smem_load: @@ -1086,7 +1146,14 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): raise ValueError("Can only load scalars from SMEM") return _maybe_cast_load_to_bool( aval_out, memref.LoadOp(ref, starts).result) - load_aval = jax_core.ShapedArray(sizes, dtype=ref_aval.dtype) + elif str(ref_type.memory_space) != "#tpu.memory_space": + extra = "" + if str(ref_type.memory_space) == "#tpu.memory_space": + extra = " ANY memory space can only be accessed using async_copy." + raise ValueError( + "Loads are only allowed on VMEM and SMEM references." 
+ extra + ) + load_aval = jax_core.ShapedArray(sizes, dtype=aval_out.dtype) if need_stride: load_val = tpu.StridedLoadOp( aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts, strides @@ -1094,12 +1161,12 @@ def _load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree, **_): else: load_val = vector.LoadOp( aval_to_ir_type(load_aval, is_kernel_boundary=True), ref, starts).result - load_val = _maybe_cast_load_to_bool(aval_out, load_val) - if load_aval == aval_out: - return load_val - vec_type = ir.VectorType.get(aval_out.shape, - _dtype_to_ir_type(aval_out.dtype)) - return vector.ShapeCastOp(vec_type, load_val).result + if load_aval != aval_out: + vec_type = ir.VectorType.get(aval_out.shape, + _dtype_to_ir_type(aval_out.dtype, + is_kernel_boundary=True)) + load_val = vector.ShapeCastOp(vec_type, load_val).result + return _maybe_cast_load_to_bool(aval_out, load_val) def _prng_key_load_lowering_rule(ctx: LoweringRuleContext, *args_flat, args_tree) -> KeyScalarBundle: """Lowering rule for loading PRNG keys from SMEM. @@ -1154,15 +1221,20 @@ def _maybe_cast_load_to_bool( if out_aval.dtype != jnp.bool_: return val load_scalar_type = _dtype_to_ir_type(BOOL_MEMREF_TYPE) - if not out_aval.shape: - # For scalars, truncate the value to a bool. - pred = _cmpi_lowering_types[lax.ne_p] - predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) - const_zero = ir.IntegerAttr.get(load_scalar_type, 0) + pred = _cmpi_lowering_types[lax.ne_p] + predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) + const_zero = ir.IntegerAttr.get(load_scalar_type, 0) + if out_aval.shape: # Vector case. + load_vector_type = aval_to_ir_type(out_aval, is_kernel_boundary=True) + vector_zeros = arith.ConstantOp( + load_vector_type, + ir.DenseElementsAttr.get_splat(load_vector_type, const_zero) + ) + return arith.CmpIOp(predicate, val, vector_zeros).result + else: # Scalar case. 
const_zero = arith.ConstantOp(load_scalar_type, const_zero) return arith.CmpIOp(predicate, val, const_zero).result - else: - raise NotImplementedError("Boolean vector loads are not supported.") + def _maybe_cast_store_to_memref_type( expected_aval, val: ir.Value) -> ir.Value: @@ -1176,20 +1248,23 @@ def _maybe_cast_store_to_memref_type( def _masked_swap_lowering_rule( ctx: LoweringRuleContext, *args_flat, args_tree, **_ ): - ref, indexers, val, mask = args_tree.unflatten(args_flat) - ref_aval, indexers_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) - (*slice_indexers, idx) = indexers - (*_, idx_aval) = indexers_avals + ref, transforms, val, mask = args_tree.unflatten(args_flat) + ref_aval, transforms_avals, val_aval, _ = args_tree.unflatten(ctx.avals_in) + (*prev_transforms, idx) = transforms + (*_, idx_aval) = transforms_avals if mask is not None: raise NotImplementedError ref_block_shape, *_ = ctx.block_shapes - ref, ref_block_shape = _index_ref( - ref, ref_aval, ref_block_shape, slice_indexers) + ref, ref_block_shape = _transform_ref( + ref, ref_aval.dtype, ref_block_shape, prev_transforms + ) ref_type = ir.MemRefType(ref.type) - is_smem_store = str(ref_type.memory_space) == "#tpu.memory_space" + memory_space = str(ref_type.memory_space) + is_smem_store = memory_space == "#tpu.memory_space" + is_vmem_store = memory_space == "#tpu.memory_space" (aval_out,) = ctx.avals_out if not isinstance(val, ir.Value): val = ir_constant(val, mlir_type=_dtype_to_ir_type(val_aval.dtype)) @@ -1208,6 +1283,7 @@ def _masked_swap_lowering_rule( cast_to_index=True, ) need_stride = not all((s is None or s == 1) for s in strides) + if is_smem_store: if val_aval.shape: raise ValueError("Can only store scalars to SMEM") @@ -1216,6 +1292,19 @@ def _masked_swap_lowering_rule( val = _maybe_cast_store_to_memref_type(val_aval, val) memref.StoreOp(val, ref, starts) return result + + if not is_vmem_store: + extra = "" + if memory_space == "#tpu.memory_space": + extra = " ANY memory space can only be accessed using async_copy." + raise ValueError( + "Loads and stores are only allowed on VMEM and SMEM references." 
+ extra + ) + + # handling VMEM store below + if not val_aval.shape: + raise ValueError("Cannot store scalars to VMEM") + mem_slice_shape = list(aval_out.shape) for i, a in enumerate(idx_aval.indices): if not isinstance(a, primitives.Slice): @@ -1301,9 +1390,7 @@ def _proxy_fun(val, *, axes): kind, x, acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), a) for a in axes] - ), + axes, ) return op.result return _lowering_rule @@ -1468,9 +1555,7 @@ def _dot_general_lowering_rule( ir.Attribute.parse("#vector.kind"), arith.MulFOp(x, y), acc, - ir.ArrayAttr.get( - [ir.IntegerAttr.get(ir.IntegerType.get_signless(64), 1)] - ), + [1] ) return vector.ShapeCastOp(out_type, red).result @@ -1523,6 +1608,12 @@ def _convert_helper(x, *, to_dtype): if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: x = x.astype(jnp.float32) return x.astype(to_dtype) + if jnp.issubdtype(from_dtype, jnp.unsignedinteger): + if from_dtype.itemsize < 4: + x = x.astype(jnp.uint32) + if jnp.issubdtype(to_dtype, jnp.floating) and to_dtype.itemsize < 4: + x = x.astype(jnp.float32) + return x.astype(to_dtype) if jnp.issubdtype(from_dtype, jnp.floating): if jnp.issubdtype(to_dtype, jnp.signedinteger): if from_dtype.itemsize < 4: @@ -1544,13 +1635,10 @@ def _convert_element_type_lowering_rule( del weak_type del sharding out_aval = ctx.avals_out[0] - old_dtype = ctx.avals_in[0].dtype + in_aval = ctx.avals_in[0] + old_dtype = in_aval.dtype out_type = aval_to_ir_type(out_aval) - # TODO(justinfu): Remove after mosaic supports unsigned types. - # This conversion makes mosaic interpret all unsigned types as signed types. - if np.issubdtype(new_dtype, jnp.unsignedinteger): - new_dtype = UNSIGNED_TO_SIGNED[new_dtype] if old_dtype == new_dtype: return x if jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( @@ -1560,18 +1648,21 @@ def _convert_element_type_lowering_rule( return arith.ExtFOp(out_type, x).result elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: return arith.TruncFOp(out_type, x).result - elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype( - new_dtype, jnp.signedinteger + elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype( + new_dtype, jnp.integer ): if old_dtype.itemsize < new_dtype.itemsize and new_dtype.itemsize == 4: return arith.ExtSIOp(out_type, x).result elif old_dtype.itemsize > new_dtype.itemsize and old_dtype.itemsize == 4: return arith.TruncIOp(out_type, x).result + elif jnp.iinfo(old_dtype).bits == jnp.iinfo(new_dtype).bits: + # This case triggers when casting signed to unsigned or vice versa. 
+ return x elif jnp.issubdtype(old_dtype, jnp.floating) and jnp.issubdtype( new_dtype, jnp.signedinteger ) and old_dtype.itemsize == new_dtype.itemsize == 4: return arith.FPToSIOp(out_type, x).result - elif jnp.issubdtype(old_dtype, jnp.signedinteger) and jnp.issubdtype( + elif jnp.issubdtype(old_dtype, jnp.integer) and jnp.issubdtype( new_dtype, jnp.floating ) and old_dtype.itemsize == new_dtype.itemsize == 4: return arith.SIToFPOp(out_type, x).result @@ -1590,8 +1681,16 @@ def _convert_element_type_lowering_rule( predicate = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), pred) const_type = _dtype_to_ir_type(old_dtype) const_zero = ir.IntegerAttr.get(const_type, 0) - const_zero = arith.ConstantOp(const_type, const_zero) - return arith.CmpIOp(predicate, x, const_zero).result + if in_aval.shape: + in_type = aval_to_ir_type(in_aval, is_kernel_boundary=False) + vector_zeros = arith.ConstantOp( + in_type, + ir.DenseElementsAttr.get_splat(in_type, const_zero), + ) + return arith.CmpIOp(predicate, x, vector_zeros).result + return arith.CmpIOp( + predicate, x, arith.ConstantOp(const_type, const_zero) + ).result return lower_fun(functools.partial(_convert_helper, to_dtype=new_dtype), multiple_results=False)(ctx, x) @@ -1617,6 +1716,12 @@ def _squeeze_lowering_rule(ctx: LoweringRuleContext, x, dimensions): (aval_in,) = ctx.avals_in (aval_out,) = ctx.avals_out if not aval_out.shape: + if aval_out.dtype.itemsize != 4: + raise ValueError( + "Only arrays with 32-bit element types can be converted to scalars," + f" but got: {aval_out.dtype}. Try casting the input before squeezing" + " the scalar." + ) return vector.ExtractOp(x, [], [0] * len(aval_in.shape)).result return vector.ShapeCastOp(aval_to_ir_type(ctx.avals_out[0]), x).result @@ -1813,22 +1918,10 @@ def _neg_lowering_rule(ctx: LoweringRuleContext, x): skip_mlir_conversions.add(lax.neg_p) -def _sign_lowering_helper(x): - if jnp.issubdtype(x.dtype, jnp.unsignedinteger): - return (x != 0).astype(x.dtype) - - if jnp.issubdtype(x.dtype, jnp.integer): - return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype) - - if jnp.issubdtype(x.dtype, jnp.floating): - out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype) - return jnp.where(jnp.isnan(x), jnp.nan, out) - - raise NotImplementedError - - def _sign_lowering_rule(ctx: LoweringRuleContext, x): - return lower_fun(_sign_lowering_helper, multiple_results=False)(ctx, x) + return lower_fun( + pallas_utils.sign_lowering_helper, multiple_results=False, + )(ctx, x) lowering_rules[lax.sign_p] = _sign_lowering_rule @@ -2160,7 +2253,7 @@ def _run_body(i, args): def _scan_lowering_rule( ctx: LoweringRuleContext, *args, - jaxpr: jax_core.Jaxpr, + jaxpr: jax_core.ClosedJaxpr, linear: tuple[bool, ...], length: int, reverse: bool, @@ -2241,7 +2334,7 @@ def _while_lowering_rule( body_jaxpr, ): # First try to lower via a simpler fori loop, which may optimize better. 
- fori_jaxpr, err = pallas_utils.pattern_match_while_to_fori_loop( + fori_jaxpr, _ = pallas_utils.pattern_match_while_to_fori_loop( cond_jaxpr, cond_nconsts, body_jaxpr, body_nconsts ) if fori_jaxpr is not None: @@ -2262,19 +2355,12 @@ def _while_lowering_rule( cond_const_block_shapes, body_const_block_shapes, carry_block_shapes = ( split_list(ctx.block_shapes, [cond_nconsts, body_nconsts]) ) - cond_const_types = [a.type for a in cond_consts] - body_const_types = [a.type for a in body_consts] carry_types = [a.type for a in carry] - all_types = [*cond_const_types, *body_const_types, *carry_types] - while_op = scf.WhileOp(all_types, args) + while_op = scf.WhileOp(carry_types, carry) - before_block = while_op.before.blocks.append(*all_types) - cond_consts_, _, carry_ = split_list( - before_block.arguments, - [cond_nconsts, body_nconsts], - ) - cond_args = [*cond_consts_, *carry_] + before_block = while_op.before.blocks.append(*carry_types) with ir.InsertionPoint.at_block_begin(before_block): + cond_args = [*cond_consts, *before_block.arguments] [cond] = jaxpr_subcomp( ctx.lowering_context.replace( block_shapes=[*cond_const_block_shapes, *carry_block_shapes] @@ -2284,30 +2370,19 @@ def _while_lowering_rule( ) scf.condition(cond, before_block.arguments) - after_block = while_op.after.blocks.append(*all_types) - cond_consts_, body_consts_, carry_ = split_list( - after_block.arguments, - [cond_nconsts, body_nconsts], - ) - all_args = [*cond_consts_, *body_consts_, *carry_] - cond_const_args, body_const_args, carry_args = split_list( - all_args, [cond_nconsts, body_nconsts] - ) + after_block = while_op.after.blocks.append(*carry_types) with ir.InsertionPoint.at_block_begin(after_block): + body_args = [*body_consts, *after_block.arguments] loop_out = jaxpr_subcomp( ctx.lowering_context.replace( block_shapes=[*body_const_block_shapes, *carry_block_shapes], ), body_jaxpr.jaxpr, - *body_const_args, - *carry_args, + *body_args, ) - all_handles = [*cond_const_args, *body_const_args, *loop_out] - if all_handles: - scf.yield_(all_handles) - - all_out = list(while_op.results_) - return all_out[cond_nconsts + body_nconsts :] + if loop_out: + scf.yield_(loop_out) + return list(while_op.results) lowering_rules[lax.while_p] = _while_lowering_rule @@ -2382,6 +2457,7 @@ def _debug_callback_lowering_rule(ctx: LoweringRuleContext, *args, **kwargs): def _program_id_lowering_rule(ctx: LoweringRuleContext, *, axis: int): + if ctx.lowering_context.user_grid_indices is None: raise ValueError( f"program id: {axis} was passed, but user did not provide a grid." 
@@ -2489,39 +2565,10 @@ def _shift_right_logical_lowering_rules(ctx: LoweringRuleContext, x, d): skip_mlir_conversions.add(lax.shift_right_logical_p) -# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 -def _erf_inv_32_helper(x): - k_degree = 9 - w_lt_5_constants = [ - 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, - -4.39150654e-06, 0.00021858087, -0.00125372503, - -0.00417768164, 0.246640727, 1.50140941, - ] - w_gt_5_constants = [ - -0.000200214257, 0.000100950558, 0.00134934322, - -0.00367342844, 0.00573950773, -0.0076224613, - 0.00943887047, 1.00167406, 2.83297682, - ] - - w = -jnp.log1p(x * -x) - w_lt_5 = w < 5.0 - - w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0) - - p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0]) - for i in range(1, k_degree): - c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i]) - p = c + p * w - - return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) - - def _erf_inv_lowering_rule(ctx: LoweringRuleContext, x): - (x_aval,) = ctx.avals_in - if x_aval.dtype == jnp.float32: - return lower_fun(_erf_inv_32_helper, multiple_results=False)(ctx, x) - else: - raise NotImplementedError + return lower_fun( + pallas_utils.erf_inv_lowering_helper, multiple_results=False, + )(ctx, x) lowering_rules[lax.erf_inv_p] = _erf_inv_lowering_rule @@ -2538,14 +2585,16 @@ def _bitcast_convert_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype): (in_aval, ) = ctx.avals_in (out_aval,) = ctx.avals_out - if in_aval.dtype.itemsize != new_dtype.itemsize: + old_bitwidth = pallas_utils.dtype_bitwidth(in_aval.dtype) + new_bitwidth = pallas_utils.dtype_bitwidth(new_dtype) + if old_bitwidth != new_bitwidth: raise NotImplementedError("Changing bitwidths not supported.") return tpu.BitcastOp(aval_to_ir_type(out_aval), x).result lowering_rules[lax.bitcast_convert_type_p] = _bitcast_convert_type_lowering_rule def _alloc_value(aval: jax_core.AbstractValue) -> ir.Value: if isinstance(aval, pallas_core.AbstractMemoryRef): - memspace = ir.Attribute.parse(f"#tpu.memory_space<{aval.memory_space}>") + memspace = _memory_space_to_mosaic_attribute(aval.memory_space) if jnp.issubdtype(aval.dtype, tpu_core.semaphore_dtype): assert aval.memory_space == TPUMemorySpace.SEMAPHORE memref_type = aval_to_ir_type(aval, memory_space=TPUMemorySpace.SEMAPHORE) @@ -2610,8 +2659,8 @@ def _semaphore_read_lowering_rule( args_tree, ): sem_aval, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers = tree_util.tree_unflatten(args_tree, args) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, transforms = tree_util.tree_unflatten(args_tree, args) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.SemaphoreReadOp(sem).result @@ -2624,8 +2673,10 @@ def _semaphore_signal_lowering_rule( device_id_type: tpu_primitives.DeviceIdType, ): sem_aval, _, _, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers, value, device_id, core_index = tree_util.tree_unflatten(args_tree, args) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, transforms, value, device_id, core_index = tree_util.tree_unflatten( + args_tree, args + ) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.SemaphoreSignalOp( @@ -2639,8 +2690,8 @@ def _semaphore_signal_lowering_rule( def 
_semaphore_wait_lowering_rule(ctx: LoweringRuleContext, *args, args_tree): sem_aval, _, _ = tree_util.tree_unflatten(args_tree, ctx.avals_in) - sem, indexers, value = tree_util.tree_unflatten(args_tree, args) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, indexers) + sem, transforms, value = tree_util.tree_unflatten(args_tree, args) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, transforms) return tpu.SemaphoreWaitOp(sem, value).results lowering_rules[tpu_primitives.semaphore_wait_p] = _semaphore_wait_lowering_rule @@ -2648,13 +2699,13 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, device_id_type: tpu_primitives.DeviceIdType): ( src_ref, - src_indexers, + src_transforms, dst_ref, - dst_indexers, + dst_transforms, sem, - sem_indexers, + sem_transforms, src_sem, - src_sem_indexers, + src_sem_transforms, device_id, ) = tree_util.tree_unflatten(tree, args) (src_ref_aval, _, dst_ref_aval, _, sem_aval, _, src_sem_aval, _, _) = ( @@ -2664,16 +2715,17 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, raise NotImplementedError("DMAs with bool dtypes are not supported.") block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) src_ref_block_shape, dst_ref_block_shape = block_shapes[0], block_shapes[2] - src_ref, _ = _index_ref( - src_ref, src_ref_aval, src_ref_block_shape, src_indexers + src_ref, _ = _transform_ref( + src_ref, src_ref_aval.dtype, src_ref_block_shape, src_transforms ) if src_sem is not None: - src_sem, _ = _index_ref( - src_sem, src_sem_aval, src_sem_aval.shape, src_sem_indexers) - dst_ref, _ = _index_ref( - dst_ref, dst_ref_aval, dst_ref_block_shape, dst_indexers + src_sem, _ = _transform_ref( + src_sem, src_sem_aval.dtype, src_sem_aval.shape, src_sem_transforms + ) + dst_ref, _ = _transform_ref( + dst_ref, dst_ref_aval.dtype, dst_ref_block_shape, dst_transforms ) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) if device_id is not None: device_id = _device_id_to_logical(ctx, device_id, device_id_type) return tpu.EnqueueDMAOp(src_ref, dst_ref, sem, source_semaphore=src_sem, @@ -2684,14 +2736,12 @@ def _dma_start_lowering_rule(ctx: LoweringRuleContext, *args, tree, def _dma_wait_lowering_rule(ctx: LoweringRuleContext, *args, tree, device_id_type: tpu_primitives.DeviceIdType): del device_id_type - sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, args) + sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, args) sem_aval, _, ref_aval, _ = tree_util.tree_unflatten(tree, ctx.avals_in) block_shapes = tree_util.tree_unflatten(tree, ctx.block_shapes) ref_block_shape = block_shapes[2] - ref, _ = _index_ref( - ref, ref_aval, ref_block_shape, indexers - ) - sem, _ = _index_ref(sem, sem_aval, sem_aval.shape, sem_indexers) + ref, _ = _transform_ref(ref, ref_aval.dtype, ref_block_shape, transforms) + sem, _ = _transform_ref(sem, sem_aval.dtype, sem_aval.shape, sem_transforms) return tpu.WaitDMAOp(sem, ref).results lowering_rules[tpu_primitives.dma_wait_p] = _dma_wait_lowering_rule @@ -2735,6 +2785,9 @@ def _delay_rule(ctx: LoweringRuleContext, nanos: int): def _debug_print_rule( ctx: LoweringRuleContext, *args, fmt: str, has_placeholders: bool ): + if any(aval.shape for aval in ctx.avals_in): + raise NotImplementedError("Only scalar values are supported") + primitives.check_debug_print_format(fmt, *args) if has_placeholders: if not all( @@ -2886,9 +2939,10 @@ def body(*args): out = 
pallas_call.pallas_call( body, out_shape=in_avals, - in_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] + in_specs=[pallas_core.BlockSpec(memory_space=pallas_core.MemorySpace.ANY)] * len(in_avals), - out_specs=[pallas_core.BlockSpec(memory_space=tpu_core.TPUMemorySpace.ANY)] + out_specs=[pallas_core.BlockSpec( + memory_space=pallas_core.MemorySpace.ANY)] * len(in_avals), input_output_aliases={i: i for i in range(len(in_avals))}, grid=((core_axis_name, num_cores),), diff --git a/jax/_src/pallas/mosaic/pallas_call_registration.py b/jax/_src/pallas/mosaic/pallas_call_registration.py index c6edddca035b..b09d36a9d3b2 100644 --- a/jax/_src/pallas/mosaic/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic/pallas_call_registration.py @@ -30,19 +30,24 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.pallas import core +from jax._src.pallas.mosaic import core as tpu_core from jax._src.pallas.mosaic import lowering from jax._src.pallas.mosaic import verification +from jax._src import tpu_custom_call from jax.experimental import mosaic from jax.experimental.mosaic.dialects import tpu from jax.experimental.pallas import tpu as pltpu -def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray): +def _maybe_cast_to_int(x: jax.Array | jax_core.AbstractValue): """Casts boolean values to integers. We perform this cast because Mosaic does not directly support bool values for Memrefs. Instead, we load bools as integers and cast them to bools after loading from a memref inside of the kernel. """ + assert isinstance( + x, (jax.Array, jax_core.ShapedArray, jax_core.DShapedArray) + ), type(x) if isinstance(x, jax.Array): if dtypes.issubdtype(x.dtype, jax.numpy.bool_): return x.astype(lowering.BOOL_MEMREF_TYPE) @@ -63,6 +68,41 @@ def _maybe_cast_to_int(x: jax.Array | jax_core.ShapedArray): ) +def _get_memory_space_from_aval( + out_aval: jax_core.AbstractValue, +) -> tpu_custom_call.MemorySpace | None: + if not isinstance(out_aval, jax_core.ShapedArray): + raise ValueError('Memory spaces not defined for non-ShapedArrays') + if not isinstance(out_aval, core.ShapedArrayWithMemorySpace): + # If we are passed a regular old ShapedArray, we don't constrain the + # memory space + return None + # If we are passed an aval with an explicit memory space tag, we use it + # to constrain the memory space. + match out_aval.memory_space: + case None: + return None + case tpu_core.TPUMemorySpace.ANY: + return None + case tpu_core.TPUMemorySpace.VMEM: + return tpu_custom_call.MemorySpace.VMEM + case tpu_core.TPUMemorySpace.SEMAPHORE: + return tpu_custom_call.MemorySpace.SEMAPHORE_MEM + return None + + +def _get_memory_spaces_from_avals( + out_avals: tuple[jax_core.AbstractValue, ...], +) -> tuple[tpu_custom_call.MemorySpace | None, ...] 
| None: + output_memory_spaces = None + if any( + isinstance(out_aval, core.ShapedArrayWithMemorySpace) + for out_aval in out_avals + ): + output_memory_spaces = tuple(map(_get_memory_space_from_aval, out_avals)) + return output_memory_spaces + + def pallas_call_tpu_lowering_rule( ctx: mlir.LoweringRuleContext, *in_nodes, @@ -74,22 +114,14 @@ def pallas_call_tpu_lowering_rule( interpret: bool, compiler_params: dict[str, Any], cost_estimate: core.CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): """Lowers a pallas_call to a Mosaic TPU custom call.""" del interpret if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") print(jaxpr) - if "mosaic_params" in compiler_params: - # TODO(slebedev): Remove this branch after July 12th 2024. - warnings.warn( - "Passing Mosaic parameters via compiler_params=dict(mosaic_params=...)" - " is deprecated. Use compiler_params=dict(mosaic=...) instead.", - DeprecationWarning, - ) - assert "mosaic" not in compiler_params - mosaic_params = compiler_params["mosaic_params"] - elif "mosaic" in compiler_params: + if "mosaic" in compiler_params: mosaic_params = compiler_params["mosaic"] else: mosaic_params = {} @@ -138,9 +170,6 @@ def lower_module(for_verification: bool): (a[0] + num_dyn_bounds + num_extra_args, a[1]) for a in input_output_aliases ) - out_avals = [jax_core.ShapedArray(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype) - for bm in grid_mapping.block_mappings_output] if promela_dump_path := _DUMP_PROMELA_TO.value: num_devices = 1 if mesh is None else mesh.devices.size @@ -183,7 +212,7 @@ def lower_module(for_verification: bool): def _maybe_cast_inputs(*args): args = [_maybe_cast_to_int(x) for x in args] return args - kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in] # type: ignore + kernel_in_avals = [_maybe_cast_to_int(x) for x in ctx.avals_in] kernel_out_avals = [_maybe_cast_to_int(x) for x in out_avals] cast_ctx = ctx.replace(avals_out=kernel_in_avals) in_nodes = mlir.lower_fun(_maybe_cast_inputs)(cast_ctx, *in_nodes) @@ -191,6 +220,7 @@ def _maybe_cast_inputs(*args): # Dynamic grid bounds have to go at the front. 
dynamic_grid_args, args = in_nodes[:num_dyn_bounds], in_nodes[num_dyn_bounds:] kernel_ctx = ctx.replace(avals_in=kernel_in_avals, avals_out=kernel_out_avals) + output_memory_spaces = _get_memory_spaces_from_avals(out_avals) if cost_estimate is not None: mosaic_cost_estimate = pltpu.CostEstimate( flops=cost_estimate.flops, @@ -217,6 +247,7 @@ def _maybe_cast_inputs(*args): device_type=mosaic_params.get("device_type"), internal_scratch_in_bytes=mosaic_params.get("internal_scratch_in_bytes"), collective_id=mosaic_params.get("collective_id", None), + output_memory_spaces=output_memory_spaces, ) _maybe_cast_to_bool = lambda x, aval: x.astype( jax.numpy.bool_) if aval.dtype == jax.numpy.bool_ else x diff --git a/jax/_src/pallas/mosaic/pipeline.py b/jax/_src/pallas/mosaic/pipeline.py index 7fde6665d394..1514f67a9e33 100644 --- a/jax/_src/pallas/mosaic/pipeline.py +++ b/jax/_src/pallas/mosaic/pipeline.py @@ -16,12 +16,13 @@ from __future__ import annotations from collections.abc import Sequence +from contextlib import contextmanager import dataclasses import enum import functools import itertools import operator -from typing import Union, Any +from typing import Any, Union import jax from jax import lax @@ -39,7 +40,7 @@ SMEM = tpu_core.TPUMemorySpace.SMEM VMEM = tpu_core.TPUMemorySpace.VMEM DMA = tpu_core.SemaphoreType.DMA -REF = tpu_core.MemoryRef +REF = pallas_core.MemoryRef SemaphoreType = tpu_core.SemaphoreType SemaphoreTuple = jax.Array ArrayRef = Union[REF, jax.Array] @@ -72,6 +73,7 @@ def add_leaves(i, x): return tree_util.tree_unflatten(treedef, broadcast_leaves) +@jax_util.cache(trace_context_in_key=False) def _get_tpu_generation() -> int: kind = jax.devices()[0].device_kind if kind.endswith(' lite'): @@ -157,6 +159,8 @@ def _grid_size(grid): def _get_indices(step, grid, offsets): """Get indices for a given step and grid.""" + # TODO(enriqueps): Implement using bitwise ops, avoid div/rem since they are + # expensive. extended_grid = grid + (1,) strides = tuple( itertools.accumulate(extended_grid[::-1], func=operator.mul))[::-1] @@ -174,6 +178,8 @@ class BufferType(enum.Enum): ACCUMULATOR = 3 INPUT_OUTPUT = 4 + MANUAL = 5 + @tree_util.register_pytree_node_class @dataclasses.dataclass(frozen=True) @@ -185,7 +191,7 @@ class BufferedRef: dtype: dtype for buffers. buffer_type: enum indicating whether this is an input, output, or in/out accumulator buffered reference. - vmem_ref: a double-buffer to hold a working buffer and a dirty buffer used + window_ref: a double-buffer to hold a working buffer and a dirty buffer used to copy into and out of. In the case of a BufferedRef targeting a VMEM reference, this simply points to the existing ref. accum_ref: accumulating buffer used by accumulator BufferedRefs. 
@@ -206,7 +212,7 @@ class BufferedRef: spec: pl.BlockSpec # static metadata dtype: Any # static metadata buffer_type: BufferType # static metadata - vmem_ref: REF | None + window_ref: REF | None accum_ref: REF | None current_slot: ArrayRef | None next_slot: ArrayRef | None @@ -214,14 +220,26 @@ class BufferedRef: sem_sends: SemaphoreTuple | None def tree_flatten(self): - return ((self.vmem_ref, self.accum_ref, self.current_slot, - self.next_slot, self.sem_recvs, self.sem_sends), - (self.spec, self.dtype, self.buffer_type)) + return ( + ( + self.window_ref, + self.accum_ref, + self.current_slot, + self.next_slot, + self.sem_recvs, + self.sem_sends, + ), + (self.spec, self.dtype, self.buffer_type), + ) @classmethod def tree_unflatten(cls, meta, data): return cls(*meta, *data) + @staticmethod + def buffer_types() -> type[BufferType]: + return BufferType + @classmethod def create(cls, spec, dtype, buffer_type) -> BufferedRef: """Create a BufferedRef. @@ -235,7 +253,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: Returns: Initialized BufferedRef """ - block_shape = tuple([1 if x is None else x for x in spec.block_shape]) + block_shape = tuple(1 if x is None else x for x in spec.block_shape) if buffer_type is BufferType.ACCUMULATOR: accum_ref = VMEM(block_shape, dtype) else: @@ -248,7 +266,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: spec=spec, dtype=dtype, buffer_type=buffer_type, - vmem_ref=None, # to be bound to existing ref by the pipeline routine + window_ref=None, # to be bound to existing ref by the pipeline routine accum_ref=accum_ref, current_slot=None, next_slot=None, @@ -256,11 +274,12 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef: sem_sends=None, ) else: + memory_space = SMEM if spec.memory_space == SMEM else VMEM return cls( spec=spec, dtype=dtype, buffer_type=buffer_type, - vmem_ref=VMEM((2,) + block_shape, dtype), + window_ref=memory_space((2,) + block_shape, dtype), accum_ref=accum_ref, current_slot=SMEM((1,), jnp.int32), next_slot=SMEM((1,), jnp.int32), @@ -307,11 +326,11 @@ def memory_space(self): @property def current_ref(self): buffer_slice = tuple( - [0 if x is None else slice(None) for x in self.block_shape]) + 0 if x is None else slice(None) for x in self.block_shape) if self.memory_space == VMEM: - return self.vmem_ref.at[buffer_slice] + return self.window_ref.at[buffer_slice] else: - return self.vmem_ref.at[(self.current_slot[0], *buffer_slice)] + return self.window_ref.at[(self.current_slot[0], *buffer_slice)] @property def is_input(self): @@ -337,16 +356,17 @@ def is_accumulator(self): def is_input_output(self): return self.buffer_type == BufferType.INPUT_OUTPUT - def bind_existing_ref(self, vmem_ref, indices): + def bind_existing_ref(self, window_ref, indices): """For handling VMEM references, the pipeline aliases the existing ref.""" if self.memory_space == VMEM: return dataclasses.replace( - self, vmem_ref=vmem_ref.at[self.compute_slice(indices)]) + self, window_ref=window_ref.at[self.compute_slice(indices)] + ) return self def compute_slice(self, grid_indices): """Compute DMA slice from grid indices.""" - block_shape = tuple([1 if x is None else x for x in self.block_shape]) + block_shape = tuple(1 if x is None else x for x in self.block_shape) indices = self.compute_index(*grid_indices) return jax.tree.map(_make_ds, indices, block_shape) @@ -428,8 +448,9 @@ def copy_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) tpu_primitives.make_async_copy( src_ref.at[src_slice], - 
self.vmem_ref.at[next_slot].at[dst_slice], - self.sem_recvs.at[next_slot]).start() + self.window_ref.at[next_slot].at[dst_slice], + self.sem_recvs.at[next_slot], + ).start() def copy_out(self, dst_ref, grid_indices): """Starts copy of HBM dma slice from the current slot.""" @@ -440,9 +461,10 @@ def copy_out(self, dst_ref, grid_indices): dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[slot].at[src_slice], + self.window_ref.at[slot].at[src_slice], dst_ref.at[dst_slice], - self.sem_sends.at[slot]).start() + self.sem_sends.at[slot], + ).start() def wait_in(self, src_ref, grid_indices): """Waits for input copy to finish.""" @@ -452,9 +474,12 @@ def wait_in(self, src_ref, grid_indices): dst_slice = tuple(pl.ds(0, s.size) for s in src_slice) current_slot = self.current_slot[0] tpu_primitives.make_async_copy( - src_ref.at[src_slice], # nb: doesn't matter - self.vmem_ref.at[current_slot].at[dst_slice], # only dst shape is important - self.sem_recvs.at[current_slot]).wait() + src_ref.at[src_slice], # nb: doesn't matter + self.window_ref.at[current_slot].at[ + dst_slice + ], # only dst shape is important + self.sem_recvs.at[current_slot], + ).wait() def wait_out(self, dst_ref, grid_indices): """Waits for output copy to finish.""" @@ -464,9 +489,10 @@ def wait_out(self, dst_ref, grid_indices): dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices) src_slice = tuple(pl.ds(0, s.size) for s in dst_slice) tpu_primitives.make_async_copy( - self.vmem_ref.at[prev_slot].at[src_slice], # nb: doesn't matter - dst_ref.at[dst_slice], # only dst shape is important - self.sem_sends.at[prev_slot]).wait() + self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter + dst_ref.at[dst_slice], # only dst shape is important + self.sem_sends.at[prev_slot], + ).wait() # Accumulator methods # @@ -486,7 +512,7 @@ def set_accumulator(self, init=False): def _init(): self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...]) def _set(): - self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref) + self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype) lax.cond(init, _init, _set) def accumulate(self): @@ -494,14 +520,14 @@ def accumulate(self): assert self.is_accumulator if self.accum_ref is not None: accum_dtype = jnp.float32 - if self.vmem_ref.dtype == jnp.int32: + if self.window_ref.dtype == jnp.int32: accum_dtype = jnp.int32 # TODO(levskaya): we could generalize init and reduction functions, # could it ever be useful to support more generic monoids? self.current_ref[...] = ( - self.current_ref[...].astype(accum_dtype) + - self.accum_ref[...].astype(accum_dtype) - ).astype(self.vmem_ref.dtype) + self.current_ref[...].astype(accum_dtype) + + self.accum_ref[...].astype(accum_dtype) + ).astype(self.window_ref.dtype) # Helper to tree map over BufferedRefs as leaves. @@ -513,30 +539,34 @@ def accumulate(self): class Scheduler: """Sequences input and output copies and waits for a pipeline.""" - def __init__(self, - step: jax.Array, - grid: tuple[int | jax.Array, ...], - grid_offsets: tuple[int | jax.Array, ...], - first_cycle=None, - last_cycle=None, - init_accumulators=None, - ): + def __init__( + self, + step: jax.Array, + grid: tuple[int | jax.Array, ...], + grid_offsets: tuple[int | jax.Array, ...], + first_cycle=None, + last_cycle=None, + init_accumulators=None, + trace_scopes=True, + ): """Initializes scheduler. 
- Args: - step: inner step number. - grid: pallas grid for BufferedRefs. - grid_offsets: offsets for grid indices (used for megacore). - first_cycle: whether this is the first invocation of the pipeline. - last_cycle: whether this is the last invocation of the pipeline. - init_accumulators: do we zero-initialize accumulator state for this - invocation of the pipeline. + Args: + step: inner step number. + grid: pallas grid for BufferedRefs. + grid_offsets: offsets for grid indices (used for megacore). + first_cycle: whether this is the first invocation of the pipeline. + last_cycle: whether this is the last invocation of the pipeline. + init_accumulators: do we zero-initialize accumulator state for this + invocation of the pipeline. + trace_scopes: whether to use named_scope to trace blocks in the pipeline. """ self.step = step self.grid = grid self.first_cycle = first_cycle self.last_cycle = last_cycle self.init_accumulators = init_accumulators + self.trace_scopes = trace_scopes # Total number of linear steps. self.num_steps = _grid_size(grid) @@ -562,6 +592,14 @@ def __init__(self, self.next_step, grid, grid_offsets ) + @contextmanager + def _named_scope(self, name): + if self.trace_scopes: + with jax.named_scope(name): + yield + else: + yield + def grid_env(self): return pallas_core.grid_env( list(map(pallas_core.GridAxis, self.indices, self.grid))) @@ -589,7 +627,7 @@ def initialize(self, buffered_ref, src_ref, schedule=None): schedule = _default_schedule pred = schedule["prologue_copy_in"](self, buffered_ref, src_ref) - with jax.named_scope("ep_initialize"): + with self._named_scope("ep_initialize"): @pl.when(self.first_step_ever) def _init_slots(): buffered_ref.init_slots() @@ -608,7 +646,7 @@ def wait_in(self, buffered_ref, src_ref, schedule=None): schedule = _default_schedule pred = schedule["wait_in"](self, buffered_ref, src_ref) - @jax.named_scope("ep_wait_in") + @self._named_scope("ep_wait_in") def _wait(): if buffered_ref.is_input: buffered_ref.wait_in(src_ref, self.indices) @@ -616,7 +654,8 @@ def _wait(): # In most cases we won't be waiting when init_accumulators is True, # so this is usually just setting what we just copied. buffered_ref.set_accumulator(self.init_accumulators) - @jax.named_scope("ep_set_accum") + + @self._named_scope("ep_set_accum") def _no_wait(): if buffered_ref.is_accumulator: @@ -633,7 +672,7 @@ def copy_in(self, buffered_ref, src_ref, schedule=None): pred = schedule['copy_in'](self, buffered_ref, src_ref) @pl.when(pred) - @jax.named_scope("ep_copy_in") + @self._named_scope("ep_copy_in") def _send(): if buffered_ref.is_input: # We skip the last step because that's what prefetch is for. @@ -650,7 +689,7 @@ def prefetch(self, buffered_ref, src_ref, schedule=None): pred = schedule['prefetch'](self, buffered_ref, src_ref) @pl.when(pred) - @jax.named_scope("ep_prefetch") + @self._named_scope("ep_prefetch") def _send(): if buffered_ref.is_input: # Prefetch should only run on the last step. 
@@ -664,7 +703,7 @@ def wait_out(self, buffered_ref, dst_ref, schedule=None): pred = schedule['wait_out'](self, buffered_ref, dst_ref) @pl.when(pred) - @jax.named_scope("ep_wait_out") + @self._named_scope("ep_wait_out") def _wait(): if buffered_ref.is_output: buffered_ref.wait_out(dst_ref, self.prev_indices) @@ -677,13 +716,14 @@ def copy_out(self, buffered_ref, dst_ref, schedule=None): schedule = _default_schedule pred = schedule['copy_out'](self, buffered_ref, dst_ref) - @jax.named_scope("ep_copy_out") + @self._named_scope("ep_copy_out") def _copy_out_and_accumulate(): if buffered_ref.is_accumulator: buffered_ref.accumulate() if buffered_ref.is_output: buffered_ref.copy_out(dst_ref, self.indices) - @jax.named_scope("ep_accum") + + @self._named_scope("ep_accum") def _just_accumulate(): if buffered_ref.is_accumulator: # We accumulate on the last step because we will set the accumulator @@ -702,7 +742,7 @@ def finalize(self, buffered_ref, dst_ref, schedule=None): pred = schedule['epilogue_wait_out'](self, buffered_ref, dst_ref) @pl.when(pred) - @jax.named_scope("ep_finalize") + @self._named_scope("ep_finalize") def _end(): if buffered_ref.is_output: buffered_ref.swap_slots() # formally correct, not actually necessary. @@ -945,7 +985,8 @@ def emit_pipeline( out_specs=None, should_accumulate_out=False, core_axis: int | None = None, - dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None + dimension_semantics: tuple[GridDimensionSemantics, ...] | None = None, + trace_scopes: bool = True, ): """Creates a function to emit a manual pallas pipeline. @@ -968,6 +1009,8 @@ def emit_pipeline( along the core axis. dimension_semantics: optional tuple of GridDimensionSemantics (e.g. PARALLEL or ARBITRARY). + trace_scopes: optional bool, indicates whether to annotate each region in + the pipeline using named_scope. """ if any(not isinstance(d, (int, jax.Array)) for d in grid): grid_types = tuple(type(d) for d in grid) @@ -997,6 +1040,7 @@ def pipeline( prefetch=None, postyeet=None, schedule=None, + body_prologue=None, ): """ Run the pipeline. @@ -1019,6 +1063,8 @@ def pipeline( Called during the outputs phase in the first inner step. schedule: manually specified pipeline schedules for brefs, None indicates default schedule. + body_prologue: For running code within the grid environment before the + body is run. Useful for updating manual refs. """ if scratches is None: scratches = () @@ -1062,7 +1108,9 @@ def loop_body(step, _): grid_offsets=grid_offsets, first_cycle=first_cycle, last_cycle=last_cycle, - init_accumulators=init_accumulators) + init_accumulators=init_accumulators, + trace_scopes=trace_scopes, + ) # prepare any local VMEM aliases brefs = map_brefs(scheduler.alias_local_refs, allocations, refs) @@ -1073,15 +1121,18 @@ def loop_body(step, _): map_brefs(scheduler.wait_in, brefs, refs, schedule) # prefetch inputs for the *next* invocation of this pipeline - with jax.named_scope("ep_prefetch"): + with scheduler._named_scope("ep_prefetch"): if prefetch is not None: lax.cond(step == num_steps - 1, lambda: prefetch(*brefs, scheduler), lambda: None) # run the kernel! 
+ if body_prologue is not None: + with scheduler.grid_env(): + body_prologue() current_refs = map_brefs(lambda x: x.current_ref, brefs) - with jax.named_scope("ep_run_kernel"): + with scheduler._named_scope("ep_run_kernel"): with scheduler.grid_env(): body(*current_refs, *scratches) @@ -1089,7 +1140,7 @@ def loop_body(step, _): map_brefs(scheduler.copy_out, brefs, refs, schedule) map_brefs(scheduler.wait_out, brefs, refs, schedule) # handle writes for the *last* invocation of this pipeline's outputs - with jax.named_scope("ep_postyeet"): + with scheduler._named_scope("ep_postyeet"): if postyeet is not None: lax.cond(step == 0, lambda: postyeet(*brefs, scheduler), diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 7ddd8fb9d8c5..aab214a2d700 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -26,12 +26,14 @@ from jax._src import state from jax._src import tree_util from jax._src import util -from jax._src.state import indexing -from jax._src.state import primitives as sp from jax._src.interpreters import mlir from jax._src.pallas import core as pl_core +from jax._src.pallas import utils as pallas_utils from jax._src.pallas.mosaic import core as tpu_core from jax._src.state import discharge as state_discharge +from jax._src.state import indexing +from jax._src.state import primitives as sp +from jax._src.state.types import Transform from jax._src.typing import DTypeLike import jax.numpy as jnp @@ -65,7 +67,9 @@ def bitcast(x, ty: DTypeLike): ty = dtypes.canonicalize_dtype(ty) if len(x.shape) < 2: raise ValueError("Not implemented: bitcast 1D") - if x.shape[-2] * x.dtype.itemsize % ty.itemsize: + src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) + dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + if x.shape[-2] * src_bitwidth % dst_bitwidth: raise ValueError( "Not implemented: the 2nd minor dim can not be perfectly packed or" " unpacked" @@ -76,19 +80,23 @@ def bitcast(x, ty: DTypeLike): @bitcast_p.def_abstract_eval def _bitcast_abstract_eval(x, *, ty): shape = list(x.shape) - shape[-2] = shape[-2] * x.dtype.itemsize // ty.itemsize + src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) + dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + shape[-2] = shape[-2] * src_bitwidth // dst_bitwidth return jax_core.ShapedArray(shape, ty) def _bitcast_lowering_rule(ctx: mlir.LoweringRuleContext, x, *, ty): def _bitcast(x): - if x.dtype.itemsize < ty.itemsize: + src_bitwidth = pallas_utils.dtype_bitwidth(x.dtype) + dst_bitwidth = pallas_utils.dtype_bitwidth(ty) + if src_bitwidth < dst_bitwidth: *leading, m, n = x.shape - packing = ty.itemsize // x.dtype.itemsize + packing = dst_bitwidth // src_bitwidth x = x.reshape(*leading, m // packing, packing, n) x = jnp.swapaxes(x, -1, -2) return jax.lax.bitcast_convert_type(x, ty) - if x.dtype.itemsize > ty.itemsize: + if src_bitwidth > dst_bitwidth: y = jax.lax.bitcast_convert_type(x, ty) *leading, m, n, packing = y.shape return jnp.swapaxes(y, -1, -2).reshape(*leading, m * packing, n) @@ -157,14 +165,21 @@ class DeviceIdType(enum.Enum): LOGICAL = "logical" -def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types=None): +def check_sem_avals( + sem_aval, sem_transforms_avals, name, allowed_semaphore_types=None +): if allowed_semaphore_types is None: - allowed_semaphore_types = {tpu_core.semaphore, tpu_core.barrier_semaphore} + allowed_semaphore_types = { + tpu_core.semaphore, + tpu_core.barrier_semaphore, + # For interpret mode. 
+ pl_core.SEMAPHORE_INTERPRET_DTYPE, + } if not isinstance(sem_aval, state.AbstractRef): raise ValueError(f"Cannot {name} on a non-semaphore Ref: {sem_aval}") sem_shape = sem_aval.shape - if sem_indexers_avals: - sem_shape = sem_indexers_avals[-1].get_indexer_shape() + if sem_transforms_avals: + sem_shape = sem_transforms_avals[-1].get_indexer_shape() if sem_shape: raise ValueError(f"Cannot {name} on a non-()-shaped semaphore: {sem_shape}") sem_dtype = sem_aval.dtype @@ -174,7 +189,20 @@ def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types= ): raise ValueError( f"Must {name} semaphores of the following types:" - f" {allowed_semaphore_types}" + f" {allowed_semaphore_types}." + ) + + +def _transform_semaphore(ref_value, transforms, ref_aval): + """Helper function for indexing into a semaphore during state_discharge.""" + if ref_value.shape == ref_aval.shape: + return state_discharge.transform_array(ref_value, transforms) + elif len(ref_value.shape) == 0: + return ref_value + else: + raise ValueError( + f"Semaphore value shape {ref_value.shape} does not match aval shape" + f" {ref_aval.shape}" ) @@ -183,8 +211,8 @@ def check_sem_avals(sem_aval, sem_indexers_avals, name, allowed_semaphore_types= def semaphore_read(sem_or_view): - ref, indexers = _get_ref_and_indexers(sem_or_view) - args = [ref, indexers] + ref, transforms = _get_ref_and_transforms(sem_or_view) + args = [ref, transforms] flat_args, args_tree = tree_util.tree_flatten(args) return semaphore_read_p.bind(*flat_args, args_tree=args_tree) @@ -193,17 +221,33 @@ def _semaphore_read_abstract_eval( *avals, args_tree, ): - sem_aval, sem_indexers_avals = tree_util.tree_unflatten(args_tree, avals) + sem_aval, sem_transforms_avals = tree_util.tree_unflatten(args_tree, avals) check_sem_avals( sem_aval, - sem_indexers_avals, + sem_transforms_avals, "read", allowed_semaphore_types={ - tpu_core.dma_semaphore, tpu_core.semaphore, tpu_core.barrier_semaphore + tpu_core.dma_semaphore, + tpu_core.semaphore, + tpu_core.barrier_semaphore, + pl_core.SEMAPHORE_INTERPRET_DTYPE, }, ) return jax_core.ShapedArray((), jnp.dtype("int32")) +def _semaphore_read_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + sem_value = sem_value.astype(jnp.int32) + return (None,) * len(in_avals), sem_value +state_discharge.register_discharge_rule(semaphore_read_p)( + _semaphore_read_discharge_rule +) + semaphore_signal_p = jax_core.Primitive('semaphore_signal') semaphore_signal_p.multiple_results = True @@ -217,9 +261,9 @@ def semaphore_signal( device_id_type: DeviceIdType = DeviceIdType.MESH, core_index: int | jax.Array | None = None, ): - ref, indexers = _get_ref_and_indexers(sem_or_view) + ref, transforms = _get_ref_and_transforms(sem_or_view) inc = jnp.asarray(inc, dtype=jnp.int32) - args = [ref, indexers, inc, device_id, core_index] + args = [ref, transforms, inc, device_id, core_index] flat_args, args_tree = tree_util.tree_flatten(args) semaphore_signal_p.bind( *flat_args, @@ -235,10 +279,14 @@ def _semaphore_signal_abstract_eval( device_id_type: DeviceIdType, ): del device_id_type - sem_aval, sem_indexers_avals, value_aval, device_id_avals, core_index_aval = ( - tree_util.tree_unflatten(args_tree, avals) - ) - check_sem_avals(sem_aval, sem_indexers_avals, "signal") + ( + sem_aval, + sem_transforms_avals, + value_aval, + device_id_avals, + core_index_aval, + ) = 
tree_util.tree_unflatten(args_tree, avals) + check_sem_avals(sem_aval, sem_transforms_avals, "signal") if value_aval.dtype != jnp.dtype("int32"): raise ValueError("Must signal an int32 value.") if device_id_avals is not None: @@ -257,16 +305,16 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, tree = eqn.params["args_tree"] ( sem, - sem_indexers, + sem_transforms, value, device_ids, _, ) = tree_util.tree_unflatten(tree, invars) out = pp.concat([ - pp.text('semaphore_signal'), - pp.text(' '), - sp.pp_ref_indexers(context, sem, sem_indexers), - pp.text(' '), + pp.text("semaphore_signal"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), pp.text(jax_core.pp_var(value, context)), ]) if device_ids is not None: @@ -281,20 +329,45 @@ def _semaphore_signal_pp_eqn(eqn: jax_core.JaxprEqn, return out jax_core.pp_eqn_rules[semaphore_signal_p] = _semaphore_signal_pp_eqn + +def _semaphore_signal_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree, + device_id_type): + del out_avals, device_id_type + [ref, transforms, inc, device_id, core_index] = args_tree.unflatten(flat_args) + if device_id is not None: + raise NotImplementedError("Remote signal not implemented.") + if core_index is not None: + raise NotImplementedError("Multiple core support not implemented.") + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + inc = inc.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = state_discharge.transform_swap_array( + ref, transforms, sem_value + inc + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_signal_p)( + _semaphore_signal_discharge_rule +) + + semaphore_wait_p = jax_core.Primitive('semaphore_wait') semaphore_wait_p.multiple_results = True def semaphore_wait(sem_or_view, dec: int | jax.Array = 1): - ref, indexers = _get_ref_and_indexers(sem_or_view) + ref, transforms = _get_ref_and_transforms(sem_or_view) dec = jnp.asarray(dec, dtype=jnp.int32) - args = [ref, indexers, dec] + args = [ref, transforms, dec] flat_args, args_tree = tree_util.tree_flatten(args) semaphore_wait_p.bind(*flat_args, args_tree=args_tree) @semaphore_wait_p.def_abstract_eval def _semaphore_wait_abstract_eval(*avals, args_tree): - sem_aval, sem_indexers_avals, value_aval = tree_util.tree_unflatten(args_tree, avals) - check_sem_avals(sem_aval, sem_indexers_avals, "wait") + sem_aval, sem_transforms_avals, value_aval = tree_util.tree_unflatten( + args_tree, avals + ) + check_sem_avals(sem_aval, sem_transforms_avals, "wait") if value_aval.dtype != jnp.dtype("int32"): raise ValueError("Must wait an int32 value.") return [] @@ -307,29 +380,45 @@ def _semaphore_wait_pp_eqn(eqn: jax_core.JaxprEqn, tree = eqn.params["args_tree"] ( sem, - sem_indexers, + sem_transforms, value, ) = tree_util.tree_unflatten(tree, invars) return pp.concat([ - pp.text('semaphore_wait'), - pp.text(' '), - sp.pp_ref_indexers(context, sem, sem_indexers), - pp.text(' '), + pp.text("semaphore_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), + pp.text(" "), pp.text(jax_core.pp_var(value, context)), ]) jax_core.pp_eqn_rules[semaphore_wait_p] = _semaphore_wait_pp_eqn +def _semaphore_wait_discharge_rule(in_avals, + out_avals, + *flat_args, + args_tree): + del out_avals + [ref, transforms, dec] = args_tree.unflatten(flat_args) + sem_value = _transform_semaphore(ref, transforms, in_avals[0]) + dec = dec.astype(pl_core.SEMAPHORE_INTERPRET_DTYPE) + _, new_sem_value = 
state_discharge.transform_swap_array( + ref, transforms, sem_value - dec + ) + return (new_sem_value,) + (None,) * (len(in_avals) - 1), () +state_discharge.register_discharge_rule(semaphore_wait_p)( + _semaphore_wait_discharge_rule +) + @dataclasses.dataclass class AsyncCopyDescriptor: src_ref: Any - src_indexers: tuple[indexing.NDIndexer, ...] + src_transforms: tuple[Transform, ...] dst_ref: Any - dst_indexers: tuple[indexing.NDIndexer, ...] + dst_transforms: tuple[Transform, ...] dst_sem: int | jax.Array - dst_sem_indexers: tuple[indexing.NDIndexer, ...] + dst_sem_transforms: tuple[Transform, ...] src_sem: int | jax.Array | None - src_sem_indexers: tuple[indexing.NDIndexer, ...] | None + src_sem_transforms: tuple[Transform, ...] | None device_id: int | jax.Array | None device_id_type: DeviceIdType = DeviceIdType.MESH @@ -345,13 +434,13 @@ def is_remote(self): def start(self): flat_args, tree = tree_util.tree_flatten(( self.src_ref, - self.src_indexers, + self.src_transforms, self.dst_ref, - self.dst_indexers, + self.dst_transforms, self.dst_sem, - self.dst_sem_indexers, + self.dst_sem_transforms, self.src_sem, - self.src_sem_indexers, + self.src_sem_transforms, self.device_id, )) dma_start_p.bind(*flat_args, tree=tree, device_id_type=self.device_id_type) @@ -362,9 +451,12 @@ def wait(self): self.wait_recv() def wait_recv(self): - wait_args, tree = tree_util.tree_flatten( - (self.dst_sem, self.dst_sem_indexers, self.dst_ref, self.dst_indexers) - ) + wait_args, tree = tree_util.tree_flatten(( + self.dst_sem, + self.dst_sem_transforms, + self.dst_ref, + self.dst_transforms, + )) dma_wait_p.bind( *wait_args, tree=tree, device_id_type=self.device_id_type ) @@ -372,9 +464,12 @@ def wait_recv(self): def wait_send(self): if not self.is_remote: raise ValueError("Cannot `wait_send` on a local copy.") - wait_args, tree = tree_util.tree_flatten( - (self.src_sem, self.src_sem_indexers, self.src_ref, self.src_indexers) - ) + wait_args, tree = tree_util.tree_flatten(( + self.src_sem, + self.src_sem_transforms, + self.src_ref, + self.src_transforms, + )) dma_wait_p.bind( *wait_args, tree=tree, device_id_type=self.device_id_type ) @@ -387,32 +482,32 @@ def wait_send(self): def _dma_start_abstract_eval(*args, tree, device_id_type): ( src_ref_aval, - src_indexers_avals, + src_transforms_avals, dst_ref_aval, - dst_indexers_avals, + dst_transforms_avals, dst_sem_aval, - dst_sem_indexers_avals, + dst_sem_transforms_avals, src_sem_aval, - src_sem_indexers_avals, + src_sem_transforms_avals, device_id_aval, ) = tree_util.tree_unflatten(tree, args) dst_sem_shape = dst_sem_aval.shape - if dst_sem_indexers_avals: - dst_sem_shape = dst_sem_indexers_avals[-1].get_indexer_shape() + if dst_sem_transforms_avals: + dst_sem_shape = dst_sem_transforms_avals[-1].get_indexer_shape() if dst_sem_shape: raise ValueError( f"Cannot signal on a non-()-shaped semaphore: {dst_sem_shape}" ) if src_sem_aval is not None: src_sem_shape = src_sem_aval.shape - if src_sem_indexers_avals: - src_sem_shape = src_sem_indexers_avals[-1].get_indexer_shape() + if src_sem_transforms_avals: + src_sem_shape = src_sem_transforms_avals[-1].get_indexer_shape() if src_sem_shape: raise ValueError( f"Cannot signal on a non-()-shaped semaphore: {src_sem_shape}" ) - n_src_indexers = len(tree_util.tree_leaves(src_indexers_avals)) - return [], {state.ReadEffect(0), state.WriteEffect(n_src_indexers + 1)} + n_src_transforms = len(tree_util.tree_leaves(src_transforms_avals)) + return [], {state.ReadEffect(0), state.WriteEffect(n_src_transforms + 1)} def 
_dma_start_pp_eqn(eqn: jax_core.JaxprEqn, context: jax_core.JaxprPpContext, @@ -421,27 +516,27 @@ def _dma_start_pp_eqn(eqn: jax_core.JaxprEqn, tree = eqn.params["tree"] ( src_ref, - src_indexers, + src_transforms, dst_ref, - dst_indexers, + dst_transforms, dst_sem, - dst_sem_indexers, + dst_sem_transforms, src_sem, - src_sem_indexers, + src_sem_transforms, device_id, ) = tree_util.tree_unflatten(tree, invars) - del src_sem_indexers + del src_sem_transforms # TODO(sharadmv): pretty print source semaphores and device id if src_sem or device_id: return jax_core._pp_eqn(eqn, context, settings) return pp.concat([ - pp.text('dma_start'), - pp.text(' '), - sp.pp_ref_indexers(context, src_ref, src_indexers), - pp.text(' -> '), - sp.pp_ref_indexers(context, dst_ref, dst_indexers), - pp.text(' '), - sp.pp_ref_indexers(context, dst_sem, dst_sem_indexers), + pp.text("dma_start"), + pp.text(" "), + sp.pp_ref_transforms(context, src_ref, src_transforms), + pp.text(" -> "), + sp.pp_ref_transforms(context, dst_ref, dst_transforms), + pp.text(" "), + sp.pp_ref_transforms(context, dst_sem, dst_sem_transforms), ]) jax_core.pp_eqn_rules[dma_start_p] = _dma_start_pp_eqn @@ -450,36 +545,40 @@ def dma_start_discharge_rule(in_avals, out_avals, *args, tree, device_id_type): ( src_ref, - src_indexers, + src_transforms, dst_ref, - dst_indexers, + dst_transforms, dst_sem, - dst_sem_indexers, + dst_sem_transforms, src_sem, - src_sem_indexers, + src_sem_transforms, device_id, ) = tree_util.tree_unflatten(tree, args) ( _, - src_indexers_avals, + src_transforms_avals, + _, + dst_transforms_avals, + dst_sem_aval, + dst_sem_transforms_avals, + src_sem_aval, + src_sem_transforms_avals, _, - dst_indexers_avals, - *_ ) = tree_util.tree_unflatten(tree, in_avals) - del out_avals, dst_sem, dst_sem_indexers + del out_avals is_remote = device_id is not None if not is_remote: # Local async copies only use one semaphore. assert src_sem is None - assert src_sem_indexers is None + assert src_sem_transforms is None - num_src_index_vals = len(tree_util.tree_leaves(src_indexers_avals)) - num_dst_index_vals = len(tree_util.tree_leaves(dst_indexers_avals)) + num_src_sem_transforms = len(tree_util.tree_leaves(src_sem_transforms_avals)) + num_dst_sem_transforms = len(tree_util.tree_leaves(dst_sem_transforms_avals)) + num_src_transform_vals = len(tree_util.tree_leaves(src_transforms_avals)) + num_dst_transform_vals = len(tree_util.tree_leaves(dst_transforms_avals)) - if src_indexers: - updates = state_discharge.index_array(src_ref, src_indexers) - else: - updates = src_ref + updates = state_discharge.transform_array(src_ref, src_transforms) + local_src = updates if is_remote: # Note that this code only works in SPMD mode. If not all devices execute @@ -522,31 +621,56 @@ def dma_start_discharge_rule(in_avals, out_avals, global_updates, index, axis=0, keepdims=False) # Handle asymmetrical indexing when devices do not share the same - # dst_indexer. - global_dst_indexers = tree_util.tree_map( - lambda x: jax.lax.all_gather(x, shard_axis), dst_indexers) - dst_indexers = tree_util.tree_map( + # dst_transform. 
+ global_dst_transforms = tree_util.tree_map( + lambda x: jax.lax.all_gather(x, shard_axis), dst_transforms + ) + dst_transforms = tree_util.tree_map( lambda x: jax.lax.dynamic_index_in_dim( - x, index, axis=0, keepdims=False), global_dst_indexers) + x, index, axis=0, keepdims=False + ), + global_dst_transforms, + ) + + _, new_dst = state_discharge.transform_swap_array( + dst_ref, dst_transforms, updates + ) - if dst_indexers: - _, new_dst = state_discharge.index_swap_array( - dst_ref, dst_indexers, updates + # Update semaphore values. + # TODO(justinfu): Potentially handle asymmetric copy sizes. + recv_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) + recv_size = jnp.array(recv_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + dst_sem_value = _transform_semaphore( + dst_sem, dst_sem_transforms, dst_sem_aval + ) + _, new_dst_sem = state_discharge.transform_swap_array( + dst_sem, dst_sem_transforms, dst_sem_value + recv_size + ) + if is_remote: + send_size = jnp.minimum(local_src.size, pl_core.SEMAPHORE_MAX_VALUE) + send_size = jnp.array(send_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + src_sem_value = _transform_semaphore( + src_sem, src_sem_transforms, src_sem_aval + ) + _, new_src_sem = state_discharge.transform_swap_array( + src_sem, src_sem_transforms, src_sem_value + send_size ) else: - new_dst = updates - - # TODO(b/345505876): Implement semaphore counting. - new_avals = (None,) # src_aval - new_avals += (None,) * num_src_index_vals - new_avals += (new_dst,) # dst_aval - new_avals += (None,) * num_dst_index_vals - new_avals += (None,) # dst_sem_aval + new_src_sem = None + + new_vals = (None,) # src_val + new_vals += (None,) * num_src_transform_vals + new_vals += (new_dst,) # dst_val + new_vals += (None,) * num_dst_transform_vals + new_vals += (new_dst_sem,) # dst_sem + new_vals += (None,) * num_dst_sem_transforms if is_remote: - new_avals += (None, None) # src_sem_aval, device_id - assert (len(new_avals) == - len(in_avals)), f"{len(new_avals), new_avals} != {len(in_avals)}" - return new_avals, [] + new_vals += (new_src_sem,) # src_sem + new_vals += (None,) * num_src_sem_transforms + new_vals += (None,) # device_id + assert (len(new_vals) == + len(in_avals)), f"{len(new_vals), new_vals} != {len(in_avals)}" + return new_vals, [] state_discharge.register_discharge_rule(dma_start_p)(dma_start_discharge_rule) @@ -565,37 +689,67 @@ def _dma_wait_pp_eqn(eqn: jax_core.JaxprEqn, del settings invars = eqn.invars tree = eqn.params["tree"] - sem, sem_indexers, ref, indexers = tree_util.tree_unflatten(tree, invars) + sem, sem_transforms, ref, transforms = tree_util.tree_unflatten(tree, invars) return pp.concat([ - pp.text('dma_wait'), - pp.text(' '), - sp.pp_ref_indexers(context, ref, indexers), - pp.text(' '), - sp.pp_ref_indexers(context, sem, sem_indexers), + pp.text("dma_wait"), + pp.text(" "), + sp.pp_ref_transforms(context, ref, transforms), + pp.text(" "), + sp.pp_ref_transforms(context, sem, sem_transforms), ]) jax_core.pp_eqn_rules[dma_wait_p] = _dma_wait_pp_eqn def dma_wait_discharge_rule(in_avals, out_avals, *args, tree, device_id_type): - del out_avals, args, tree, device_id_type - # TODO(justinfu): Implement semaphore counting. 
- return (None,) * len(in_avals), [] + del out_avals, device_id_type + (sem, sem_transforms, ref, ref_transforms) = tree_util.tree_unflatten( + tree, args + ) + ( + sem_aval, + sem_transforms_avals, + _, + ref_transforms_avals, + ) = tree_util.tree_unflatten(tree, in_avals) + num_sem_transforms = len(tree_util.tree_leaves(sem_transforms_avals)) + num_transforms = len(tree_util.tree_leaves(ref_transforms_avals)) + updates = state_discharge.transform_array(ref, ref_transforms) + copy_size = jnp.minimum(updates.size, pl_core.SEMAPHORE_MAX_VALUE) + copy_size = jnp.array(copy_size, dtype=pl_core.SEMAPHORE_INTERPRET_DTYPE) + sem_value = _transform_semaphore(sem, sem_transforms, sem_aval) + _, new_sem = state_discharge.transform_swap_array( + sem, sem_transforms, sem_value - copy_size + ) + new_vals = (new_sem,) # sem + new_vals += (None,) * num_sem_transforms + new_vals += (None,) # ref + new_vals += (None,) * num_transforms + return new_vals, [] state_discharge.register_discharge_rule(dma_wait_p)(dma_wait_discharge_rule) -def _get_ref_and_indexers(ref): - if isinstance(ref, state.RefView): - return ref.ref, ref.indexers +def _get_ref_and_transforms(ref): + if isinstance(ref, state.TransformedRef): + return ref.ref, ref.transforms return ref, () def make_async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" - src_ref, src_indexers = _get_ref_and_indexers(src_ref) - dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref) - sem, sem_indexers = _get_ref_and_indexers(sem) - return AsyncCopyDescriptor(src_ref, src_indexers, dst_ref, dst_indexers, - sem, sem_indexers, None, None, None, - DeviceIdType.MESH) + src_ref, src_transforms = _get_ref_and_transforms(src_ref) + dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) + sem, sem_transforms = _get_ref_and_transforms(sem) + return AsyncCopyDescriptor( + src_ref, + src_transforms, + dst_ref, + dst_transforms, + sem, + sem_transforms, + None, + None, + None, + DeviceIdType.MESH, + ) def async_copy(src_ref, dst_ref, sem): """Issues a DMA copying from src_ref to dst_ref.""" @@ -623,13 +777,22 @@ def make_async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, Returns: An AsyncCopyDescriptor. 
""" - src_ref, src_indexers = _get_ref_and_indexers(src_ref) - send_sem, send_sem_indexers = _get_ref_and_indexers(send_sem) - dst_ref, dst_indexers = _get_ref_and_indexers(dst_ref) - recv_sem, recv_sem_indexers = _get_ref_and_indexers(recv_sem) + src_ref, src_transforms = _get_ref_and_transforms(src_ref) + send_sem, send_sem_transforms = _get_ref_and_transforms(send_sem) + dst_ref, dst_transforms = _get_ref_and_transforms(dst_ref) + recv_sem, recv_sem_transforms = _get_ref_and_transforms(recv_sem) return AsyncCopyDescriptor( - src_ref, src_indexers, dst_ref, dst_indexers, recv_sem, recv_sem_indexers, - send_sem, send_sem_indexers, device_id, device_id_type=device_id_type) + src_ref, + src_transforms, + dst_ref, + dst_transforms, + recv_sem, + recv_sem_transforms, + send_sem, + send_sem_transforms, + device_id, + device_id_type=device_id_type, + ) def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id, device_id_type: DeviceIdType = DeviceIdType.MESH): diff --git a/jax/_src/pallas/mosaic/random.py b/jax/_src/pallas/mosaic/random.py index c642d99578cd..68a4fe508917 100644 --- a/jax/_src/pallas/mosaic/random.py +++ b/jax/_src/pallas/mosaic/random.py @@ -14,8 +14,8 @@ from collections.abc import Callable +import functools import jax -import numpy as np from jax import numpy as jnp from jax import random as jax_api_random from jax._src import blocked_sampler @@ -37,15 +37,13 @@ def to_pallas_key(key: jax_prng.PRNGKeyArray) -> jax_prng.PRNGKeyArray: """Helper function for converting non-Pallas PRNG keys into Pallas keys.""" - batch_dims = key.shape - key_data = jax_api_random.key_data(key) - pallas_key_size = np.prod(tpu_key_impl.key_shape) - if np.prod(key_data.shape[len(batch_dims):]) < pallas_key_size: - raise ValueError(f"Key data must be at least {pallas_key_size} bytes.") - pallas_key_data = jnp.reshape( - key_data, batch_dims + (-1,))[..., :pallas_key_size] - pallas_key_data = jnp.reshape(pallas_key_data, - batch_dims + tpu_key_impl.key_shape) + generate_key = functools.partial( + jax.random.bits, shape=tpu_key_impl.key_shape, dtype=jnp.uint32 + ) + if len(key.shape) == 0: + pallas_key_data = generate_key(key) + else: + pallas_key_data = (jax.vmap(generate_key))(key) return jax_api_random.wrap_key_data(pallas_key_data, impl="pallas_tpu") def _seed_func(seed: jnp.int32): diff --git a/jax/_src/pallas/mosaic/verification.py b/jax/_src/pallas/mosaic/verification.py index df186d46373a..bae87226c664 100644 --- a/jax/_src/pallas/mosaic/verification.py +++ b/jax/_src/pallas/mosaic/verification.py @@ -550,13 +550,17 @@ def _pretend_abstract_eval(*_, **params): def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): if ctx.lowering_context.for_verification: - (base_read_refs, indexers) = tree_util.tree_unflatten(tree, flat_args) + (base_read_refs, transforms) = tree_util.tree_unflatten(tree, flat_args) read_ref_avals, _ = tree_util.tree_unflatten(tree, ctx.avals_in) block_shapes, _ = tree_util.tree_unflatten(tree, ctx.block_shapes) read_refs = [ lowering._index_ref(ref, aval, block_shape, indexer)[0] for ref, aval, block_shape, indexer in zip( - base_read_refs, read_ref_avals, block_shapes, indexers, strict=True, + base_read_refs, + read_ref_avals, + block_shapes, + transforms, + strict=True, ) ] ir.Operation.create("verification.pretend", operands=read_refs) @@ -565,8 +569,10 @@ def _pretend_lowering(ctx: lowering.LoweringRuleContext, *flat_args, tree): lowering.lowering_rules[pretend_p] = _pretend_lowering # type: ignore def pretend(read_refs): - refs, indexers 
= unzip2(primitives._get_ref_and_indexers(r) for r in read_refs) - flat_args, tree = tree_util.tree_flatten((refs, indexers)) + refs, transforms = unzip2( + primitives._get_ref_and_transforms(r) for r in read_refs + ) + flat_args, tree = tree_util.tree_flatten((refs, transforms)) return pretend_p.bind(*flat_args, tree=tree) diff --git a/jax/_src/pallas/mosaic_gpu/BUILD b/jax/_src/pallas/mosaic_gpu/BUILD index 038826a663e8..fd291b201fa1 100644 --- a/jax/_src/pallas/mosaic_gpu/BUILD +++ b/jax/_src/pallas/mosaic_gpu/BUILD @@ -24,7 +24,7 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) @@ -34,6 +34,7 @@ py_library( deps = [ ":core", ":pallas_call_registration", + ":primitives", ], ) @@ -59,6 +60,7 @@ pytype_strict_library( "//jax:mlir", "//jax:mosaic_gpu", "//jax:pallas", + "//jax:partial_eval", "//jax:util", "//jax/_src/lib", "//jax/_src/pallas", @@ -70,6 +72,26 @@ pytype_strict_library( srcs = ["core.py"], deps = [ "//jax", + "//jax:core", + "//jax:dtypes", + "//jax:mosaic_gpu", + "//jax:tree_util", + "//jax/_src/pallas", + ] + py_deps("numpy"), +) + +pytype_strict_library( + name = "primitives", + srcs = ["primitives.py"], + deps = [ + ":core", + ":lowering", + "//jax", + "//jax:core", + "//jax:effects", + "//jax:mlir", + "//jax:mosaic_gpu", + "//jax/_src/lib", "//jax/_src/pallas", ], ) diff --git a/jax/_src/pallas/mosaic_gpu/__init__.py b/jax/_src/pallas/mosaic_gpu/__init__.py index 862a661e24b9..187a84478c65 100644 --- a/jax/_src/pallas/mosaic_gpu/__init__.py +++ b/jax/_src/pallas/mosaic_gpu/__init__.py @@ -11,3 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +# TODO(slebedev): Move these imports to ``jax.experimental.pallas``. 
+ +from jax._src.pallas.mosaic_gpu.core import Barrier +from jax._src.pallas.mosaic_gpu.core import GPUBlockSpec +from jax._src.pallas.mosaic_gpu.core import GPUCompilerParams +from jax._src.pallas.mosaic_gpu.core import GPUMemorySpace +from jax._src.pallas.mosaic_gpu.core import TilingTransform +from jax._src.pallas.mosaic_gpu.core import TransposeTransform +from jax._src.pallas.mosaic_gpu.core import WGMMAAccumulatorRef as ACC +from jax._src.pallas.mosaic_gpu.primitives import async_copy_gmem_to_smem +from jax._src.pallas.mosaic_gpu.primitives import async_copy_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import wait_barrier +from jax._src.pallas.mosaic_gpu.primitives import wait_smem_to_gmem +from jax._src.pallas.mosaic_gpu.primitives import wgmma +from jax._src.pallas.mosaic_gpu.primitives import wgmma_wait + +GMEM = GPUMemorySpace.GMEM +SMEM = GPUMemorySpace.SMEM diff --git a/jax/_src/pallas/mosaic_gpu/core.py b/jax/_src/pallas/mosaic_gpu/core.py index fd06a9829644..fe8daf43e995 100644 --- a/jax/_src/pallas/mosaic_gpu/core.py +++ b/jax/_src/pallas/mosaic_gpu/core.py @@ -14,15 +14,42 @@ """Contains GPU-specific Pallas abstractions.""" +from collections.abc import Sequence import dataclasses import enum -from jax import core as jax_core +from typing import Any, ClassVar, Literal, Protocol + +from jax._src import core as jax_core +from jax._src import dtypes +from jax._src import tree_util from jax._src.pallas import core as pallas_core +import jax.experimental.mosaic.gpu as mgpu import jax.numpy as jnp + AbstractMemoryRef = pallas_core.AbstractMemoryRef +@dataclasses.dataclass(frozen=True, kw_only=True) +class GPUCompilerParams(pallas_core.CompilerParams): + """Mosaic GPU compiler parameters. + + Attributes: + approx_math: If True, the compiler is allowed to use approximate + implementations of some math operations, e.g. ``exp``. Defaults to False. + dimension_semantics: A list of dimension semantics for each grid + dimension of the kernel. Either "parallel" for dimensions that can + execute in any order, or "sequential" for dimensions that must be + executed sequentially. + max_concurrent_steps: The maximum number of sequential stages that are + active concurrently. Defaults to 1. + """ + PLATFORM: ClassVar[str] = "mosaic_gpu" + approx_math: bool = False + dimension_semantics: Sequence[Literal["parallel", "sequential"]] | None = None + max_concurrent_steps: int = 1 + + class GPUMemorySpace(enum.Enum): GMEM = "gmem" SMEM = "smem" @@ -33,23 +60,192 @@ def __str__(self) -> str: def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype): # A convenience function for constructing MemoryRef types. - return MemoryRef(shape, dtype, self) + return pallas_core.MemoryRef(shape, dtype, memory_space=self) + + +class MemoryRefTransform(pallas_core.MemoryRefTransform, Protocol): + def to_gpu_transform(self) -> mgpu.MemRefTransform: + ... + + +@dataclasses.dataclass(frozen=True) +class TilingTransform(MemoryRefTransform): + """Represents a tiling transformation for memory refs. + + A tiling of (X, Y) on an array of shape (M, N) will result in a transformed + shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a + tiling of (64, 32) will be tiled as (4, 8, 64, 32). + """ + + tiling: tuple[int, ...] 
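(Editorial aside, not part of the patch.) The TilingTransform docstring above describes a pure shape computation; a minimal standalone sketch of that arithmetic, with the helper name tiled_shape being hypothetical, is:

def tiled_shape(block_shape: tuple[int, ...], tiling: tuple[int, ...]) -> tuple[int, ...]:
  # Split off the trailing dims covered by the tiling.
  untiled, tiled = block_shape[:-len(tiling)], block_shape[-len(tiling):]
  if any(dim % t for dim, t in zip(tiled, tiling)):
    raise ValueError(f"Block shape {block_shape} is not divisible by tiling {tiling}")
  # (M, N) tiled by (X, Y) -> (M // X, N // Y, X, Y).
  return untiled + tuple(dim // t for dim, t in zip(tiled, tiling)) + tiling

# The docstring's example: a (256, 256) block with a (64, 32) tiling.
assert tiled_shape((256, 256), (64, 32)) == (4, 8, 64, 32)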
+ + def __call__( + self, block_aval: pallas_core.AbstractMemoryRef + ) -> pallas_core.AbstractMemoryRef: + block_shape = block_aval.shape + old_tiled_dims = block_shape[-len(self.tiling) :] + num_tiles = tuple( + block_dim // tiling_dim + for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) + ) + rem = ( + block_dim % tiling_dim + for block_dim, tiling_dim in zip(old_tiled_dims, self.tiling) + ) + if any(rem): + raise ValueError( + f"Block shape {block_shape} is not divisible by tiling {self.tiling}" + ) + new_block_shape = block_shape[: -len(self.tiling)] + num_tiles + self.tiling + return block_aval.update( + inner_aval=block_aval.inner_aval.update(shape=new_block_shape) + ) + + def to_gpu_transform(self) -> mgpu.MemRefTransform: + return mgpu.TileTransform(self.tiling) + + +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemoryRefTransform): + """Transpose a tiled memref.""" + + permutation: tuple[int, ...] + + def __call__( + self, block_aval: pallas_core.AbstractMemoryRef + ) -> pallas_core.AbstractMemoryRef: + shape = block_aval.shape # pytype: disable=attribute-error + return block_aval.update( + inner_aval=block_aval.inner_aval.update( + shape=self.to_gpu_transform().transform_shape(shape) + ) + ) + + def to_gpu_transform(self) -> mgpu.MemRefTransform: + return mgpu.TransposeTransform(self.permutation) -# TODO(b/354568887): Cosolidate this with TPU's MemoryRef. @dataclasses.dataclass(frozen=True) -class MemoryRef: - """Like jax.ShapeDtypeStruct but with memory spaces.""" +class GPUBlockMapping(pallas_core.BlockMapping): + swizzle: int | None = None + - shape: tuple[int, ...] - dtype: jnp.dtype - memory_space: GPUMemorySpace +@dataclasses.dataclass +class GPUBlockSpec(pallas_core.BlockSpec): + transforms: MemoryRefTransform | tuple[MemoryRefTransform, ...] = () + swizzle: int | None = None # TODO: apaszke - Swizzle is also a transform. 
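(Editorial aside, not part of the patch.) A hedged sketch of how a kernel author might populate the new GPUBlockSpec fields; the block shape, the index map, and the keyword names inherited from the pallas BlockSpec base class are assumptions here, not something this diff prescribes:

spec = GPUBlockSpec(
    block_shape=(128, 128),                # assumed block shape
    index_map=lambda i, j: (i, j),         # assumed index map
    transforms=TilingTransform((64, 32)),  # tile the trailing two dims
    swizzle=128,                           # bytes; per the TODO above this may itself become a transform
)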
- def get_aval(self) -> AbstractMemoryRef: - return AbstractMemoryRef( - jax_core.ShapedArray(self.shape, self.dtype), self.memory_space + def to_block_mapping( + self, + origin: pallas_core.OriginStr, + array_aval: jax_core.ShapedArray, + *, + index_map_avals: Sequence[jax_core.AbstractValue], + index_map_tree: tree_util.PyTreeDef, + grid: pallas_core.GridMappingGrid, + mapped_dims: tuple[int, ...], + ) -> GPUBlockMapping: + bm = super().to_block_mapping( + origin, + array_aval, + index_map_avals=index_map_avals, + index_map_tree=index_map_tree, + grid=grid, + mapped_dims=mapped_dims, ) + transforms = self.transforms + if not isinstance(transforms, tuple): + transforms = (transforms,) + return GPUBlockMapping( + block_shape=bm.block_shape, + block_aval=bm.block_aval, + origin=bm.origin, + index_map_jaxpr=bm.index_map_jaxpr, + index_map_src_info=bm.index_map_src_info, + indexing_mode=bm.indexing_mode, + array_shape_dtype=bm.array_shape_dtype, + transforms=transforms, + swizzle=self.swizzle, + ) + GMEM = GPUMemorySpace.GMEM SMEM = GPUMemorySpace.SMEM REGS = GPUMemorySpace.REGS + + +class barrier_dtype(dtypes.extended): + pass + + +@dataclasses.dataclass(frozen=True) +class BarrierType(dtypes.ExtendedDType): + type: ClassVar[Any] = barrier_dtype + name: ClassVar[str] = "barrier" + + num_arrivals: int + + def __str__(self): + return self.name + + +@dataclasses.dataclass(frozen=True) +class Barrier: + num_arrivals: int + num_barriers: int = 1 + + def get_ref_aval(self) -> AbstractMemoryRef: + aval = jax_core.ShapedArray( + [self.num_barriers], BarrierType(self.num_arrivals) + ) + return AbstractMemoryRef(aval, SMEM) + + +@dataclasses.dataclass(frozen=True) +class WGMMAAccumulatorRef: + shape: tuple[int, int] + dtype: jnp.dtype = jnp.float32 + + def get_ref_aval(self) -> AbstractMemoryRef: + return WGMMAAbstractAccumulatorRef( + jax_core.ShapedArray(shape=self.shape, dtype=self.dtype), GPUMemorySpace.REGS + ) + + +def _is_trivial_index(idx): + _is_deref1 = lambda i: i is Ellipsis or i == slice(None) + if isinstance(idx, tuple): + return all(_is_deref1(i) for i in idx) + + return _is_deref1(idx) + +class WGMMAAbstractAccumulatorRef(AbstractMemoryRef): + __slots__ = ["inner_aval", "memory_space"] + + def __repr__(self) -> str: + return f'Accumulator{{{self.inner_aval.str_short()}}}' + + def join(self, other): + return _as_accum(super().join(other)) + + def update(self, inner_aval=None, memory_space=None): + return _as_accum(super().update(inner_aval=None, memory_space=None)) + + def at_least_vspace(self): + return _as_accum(super().at_least_vspace()) + + def _getitem(self, tracer, idx): + if not _is_trivial_index(idx): + raise NotImplementedError(f"Can only dereference accumulators, not slice ({idx=}).") + from jax._src.pallas.mosaic_gpu.primitives import wgmma_accumulator_deref # pytype: disable=import-error + return wgmma_accumulator_deref(tracer) + +def _as_accum(ref) -> WGMMAAbstractAccumulatorRef: + return WGMMAAbstractAccumulatorRef( + inner_aval=ref.inner_aval, + memory_space=ref.memory_space, # pytype: disable=attribute-error + ) + +def _ref_raise_to_shaped(ref_aval, weak_type): + return _as_accum(jax_core.raise_to_shaped_mappings[AbstractMemoryRef](ref_aval, weak_type)) +jax_core.raise_to_shaped_mappings[WGMMAAbstractAccumulatorRef] = _ref_raise_to_shaped diff --git a/jax/_src/pallas/mosaic_gpu/lowering.py b/jax/_src/pallas/mosaic_gpu/lowering.py index be9168e07567..0d0ac41d11e3 100644 --- a/jax/_src/pallas/mosaic_gpu/lowering.py +++ b/jax/_src/pallas/mosaic_gpu/lowering.py @@ -19,23 
+19,31 @@ from collections.abc import Sequence import dataclasses import functools +import itertools as it import math from typing import Any, cast import jax +from jax import lax from jax._src import core as jax_core from jax._src import pjit from jax._src import util from jax._src.interpreters import mlir -from jax._src.lax import lax +from jax._src.interpreters import partial_eval as pe from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import arith as arith_dialect +from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.mlir.dialects import memref as memref_dialect +from jax._src.lib.mlir.dialects import scf as scf_dialect from jax._src.pallas import core as pallas_core from jax._src.pallas import primitives +from jax._src.pallas import utils as pallas_utils +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.state import discharge from jax._src.state import primitives as sp -from jax.experimental.mosaic import gpu as mosaic_gpu -from jax.experimental.mosaic.gpu import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu +from jax.experimental.mosaic.gpu import core as mgpu_core +from jax.experimental.mosaic.gpu import utils as mgpu_utils import jax.numpy as jnp import numpy as np @@ -48,14 +56,55 @@ zip, unsafe_zip = util.safe_zip, zip partial = functools.partial +SMEM = gpu_core.SMEM + +_smem_estimators = {} + + +def _regiter_smem_estimator(primitive: jax_core.Primitive): + def deco(fn): + _smem_estimators[primitive] = fn + return fn + + return deco + + +def _estimate_smem_scratch_bytes(jaxpr: jax_core.Jaxpr) -> int: + """Estimates the amount of SMEM scratch bytes required by the kernel.""" + max_used = 0 + for eqn in jaxpr.eqns: + # TODO(slebedev): Add support for other primitives, notably control flow. + rule = _smem_estimators.get(eqn.primitive) + if rule is None: + # Assume that unsupported primitives are neutral wrt SMEM usage. + continue + max_used = max( + max_used, rule(*(invar.aval for invar in eqn.invars), **eqn.params) + ) + return max_used + + +@_regiter_smem_estimator(primitives.run_scoped_p) +def _run_scoped_smem_estimator(*consts, jaxpr: jax_core.Jaxpr) -> int: + del consts # Unused. + in_avals = (v.aval.inner_aval for v in jaxpr.invars) + return sum(math.prod(aval.shape) * aval.dtype.itemsize for aval in in_avals) + + +@_regiter_smem_estimator(lax.reduce_sum_p) +def _reduce_sum_smem_estimator(x_aval: jax_core.ShapedArray, *, axes) -> int: + if axes != (0,): + raise NotImplementedError("No support for axes other than 0 yet") + return 4 * x_aval.dtype.itemsize @dataclasses.dataclass class ModuleContext: name: str grid_mapping: pallas_core.GridMapping + approx_math: bool runtime_smem: ir.Value # ir.MemRefType - smem_used_bytes: int + smem_used_bytes: int = 0 # TODO(cperivol): Only return the shapes and figure out the sizes when freeing. 
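(Editorial aside, not part of the patch.) The run_scoped SMEM estimator above just sums prod(shape) * itemsize over the scoped refs; a quick standalone check of that bookkeeping, with the example shapes chosen here as assumptions:

import math
import numpy as np

scoped_refs = [((128, 128), np.dtype(np.float32)), ((128,), np.dtype(np.float32))]
smem_bytes = sum(math.prod(shape) * dtype.itemsize for shape, dtype in scoped_refs)
assert smem_bytes == 128 * 128 * 4 + 128 * 4  # 66048 bytes of SMEM scratch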
def scratch_view( @@ -93,11 +142,11 @@ def scratch_view( for s in structs: scratch_ty = ir.MemRefType.get( s.shape, - mlir.dtype_to_ir_type(s.dtype), + mgpu_utils.dtype_to_ir_type(s.dtype), memory_space=smem, ) views.append( - memref_dialect.view(scratch_ty, self.runtime_smem, _index(off), []) + memref_dialect.view(scratch_ty, self.runtime_smem, _as_index(off), []) ) off += math.prod(s.shape) * jnp.dtype(s.dtype).itemsize @@ -112,35 +161,51 @@ def stack_free_smem(self, bytes: int): self.smem_used_bytes -= bytes -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class LoweringRuleContext: - module_context: ModuleContext + module_ctx: ModuleContext + launch_ctx: mgpu.LaunchContext avals_in: Sequence[jax_core.ShapedArray] avals_out: Sequence[jax_core.ShapedArray] - block_shapes: list[tuple[int | pallas_core.Mapped, ...]] | None replace = dataclasses.replace -@dataclasses.dataclass +@dataclasses.dataclass(frozen=True) class LoweringResult: module: ir.Module grid: tuple[int, ...] - gmem_scratch_bytes: int out_structs: tuple[jax.ShapeDtypeStruct, ...] -@dataclasses.dataclass -class BlockInfo: - full_shape_dtype: jax.ShapeDtypeStruct - start_indices: Sequence[Any] - block_shape: tuple[int, ...] - - class LoweringError(Exception): # pylint: disable=g-bad-exception-name pass +def _eval_index_map( + module_ctx: ModuleContext, + launch_ctx: mgpu.LaunchContext, + idx: Sequence[ir.Value], + block_mapping: pallas_core.BlockMapping, +) -> Sequence[ir.Value]: + block_indices = lower_jaxpr_to_mosaic_gpu( + module_ctx, launch_ctx, block_mapping.index_map_jaxpr.jaxpr, idx + ) + result = [] + for i, b in zip(block_indices, block_mapping.block_shape): + if b is pallas_core.mapped: + result.append(i) + else: + # TODO(slebedev): Use a type-agnostic multiplication wrapper. + result.append(arith_dialect.muli(_as_index(i), _as_index(b))) + return tuple(result) + + +def _uses_arguments(cjaxpr: jax_core.ClosedJaxpr) -> list[bool]: + jaxpr = cjaxpr.jaxpr + return pe.dce_jaxpr(jaxpr, used_outputs=[True] * len(jaxpr.outvars))[1] + + def lower_jaxpr_to_module( grid_mapping: pallas_core.GridMapping, jaxpr: jax_core.Jaxpr, @@ -149,74 +214,396 @@ def lower_jaxpr_to_module( cost_estimate: pallas_core.CostEstimate | None, ) -> LoweringResult: del cost_estimate # Unused. - in_structs = tuple(grid_mapping.in_shapes) - out_structs = grid_mapping.out_shapes + + block_mappings = grid_mapping.block_mappings + assert len(jaxpr.outvars) == 0 assert not grid_mapping.vmapped_dims - grid = grid_mapping.grid + if len(grid_mapping.grid) > 3: + raise NotImplementedError( + "Only <=3D grids are supported in Mosaic GPU lowering." + ) + if grid_mapping.num_dynamic_grid_bounds: + raise NotImplementedError( + "Dynamic grid bounds not supported in the Mosaic GPU lowering." + ) + if grid_mapping.num_index_operands: + raise NotImplementedError( + "Scalar prefetch not supported in Mosaic GPU lowering." + ) + if not all( + isinstance(bm.indexing_mode, pallas_core.Blocked) for bm in block_mappings + ): + raise NotImplementedError( + "Only Blocked indexing mode is supported in Mosaic GPU lowering." 
+ ) + + with grid_mapping.trace_env(): + jaxpr, _ = pe.dce_jaxpr( + jaxpr, [True] * len(jaxpr.outvars), instantiate=True + ) + + block = (128, 1, 1) + params = compiler_params.get("mosaic_gpu", {}) + approx_math = params.get("approx_math", False) + max_concurrent_steps = params.get("max_concurrent_steps", 1) + dimension_semantics = params.get("dimension_semantics") + if dimension_semantics is None: + dimension_semantics = ["parallel"] * len(grid_mapping.grid) + elif len(dimension_semantics) != len(grid_mapping.grid): + raise ValueError( + "dimension_semantics must have an entry for each grid dimension:" + f" {len(dimension_semantics)=}, but len(grid) is {grid_mapping.grid})." + ) + sequential_axes = tuple( + i for i, s in enumerate(dimension_semantics) if s == "sequential" + ) + + grid = [d for i, d in enumerate(grid_mapping.grid) if i not in sequential_axes] if len(grid) < 3: grid += (1,) * (3 - len(grid)) - block = (128,) + (1,) * (len(grid) - 1) + else: + raise NotImplementedError( + "Only <=3D grids are supported in Mosaic GPU lowering." + ) + # Compute the number of steps along each sequential axis. + if sequential_axes: + # TODO(slebedev): Support multiple sequential axes. + if len(sequential_axes) > 1: + raise NotImplementedError( + "Multiple sequential axes are not supported in Mosaic GPU lowering." + ) + [sequential_axis] = sequential_axes + num_steps = grid_mapping.grid[sequential_axis] + out_sequential_invariant = [ + not _uses_arguments(bm.index_map_jaxpr)[sequential_axis] + for bm in grid_mapping.block_mappings_output + ] + else: + num_steps = 1 + out_sequential_invariant = [True] * len(grid_mapping.out_shapes) + + in_in_smem, out_in_smem = util.split_list( + [ + bm.block_aval.memory_space in (None, gpu_core.SMEM) + for bm in block_mappings + ], + [grid_mapping.num_inputs], + ) - def body(launch_ctx: mosaic_gpu.LaunchContext, *buffers): - *buffers_gmem, (*buffers_smem, runtime_smem, barriers) = buffers + in_structs_gmem = [*grid_mapping.in_shapes] + in_block_mappings, out_block_mappings = util.split_list( + block_mappings, [grid_mapping.num_inputs] + ) + in_structs_smem = [ + jax.ShapeDtypeStruct( + [max_concurrent_steps, *bm.ref_aval.shape], bm.ref_aval.dtype + ) + if in_smem + else None + for bm, in_smem in zip( + block_mappings[: grid_mapping.num_inputs], in_in_smem + ) + ] + in_gmem_transforms = [ + cast(gpu_core.MemoryRefTransform, bm.transforms) + for bm in grid_mapping.block_mappings[: grid_mapping.num_inputs] + ] + in_swizzles = map( + lambda bm: bm.swizzle + if isinstance(bm, gpu_core.GPUBlockMapping) + else None, + grid_mapping.block_mappings[: grid_mapping.num_inputs], + ) + out_structs_gmem = [*grid_mapping.out_shapes] + # TODO(justinfu): Implement output Memref transforms + out_structs_smem = [ + jax.ShapeDtypeStruct([max_concurrent_steps, *bm.block_shape], s.dtype) + if in_smem + else None + for bm, in_smem, s in zip( + block_mappings[grid_mapping.num_inputs :], + out_in_smem, + grid_mapping.out_shapes, + ) + ] + + def body(launch_ctx: mgpu.LaunchContext, *buffers: ir.Value): + *buffers_gmem, ( + buffers_smem, + *scratch_buffers_smem, + runtime_smem, + barriers, + ) = buffers assert len(buffers_gmem) == len(buffers_smem) - in_buffers_gmem = buffers_gmem[: len(in_structs)] - in_buffers_smem = buffers_smem[: len(in_structs)] - out_buffers_gmem = buffers_gmem[len(in_structs) :] - out_buffers_smem = buffers_smem[len(in_structs) :] - - [barrier] = cast(mgpu.BarrierRef, barriers) - - with mgpu.single_thread(): - for b_gmem, b_smem in zip(in_buffers_gmem, 
in_buffers_smem): - # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. - launch_ctx.async_copy( - src_ref=b_gmem, - dst_ref=b_smem, - barrier=barrier, - swizzle=None, - arrive=True, - uniform=False, + in_buffers_gmem, out_buffers_gmem = util.split_list( + buffers_gmem, [grid_mapping.num_inputs] + ) + in_buffers_smem, out_buffers_smem = util.split_list( + buffers_smem, [grid_mapping.num_inputs] + ) + barriers, *extra_barriers = barriers + + parallel_count = it.count() + program_ids_template = [ + _program_id(next(parallel_count)) if i not in sequential_axes else None + for i in range(len(grid_mapping.grid)) + ] + module_ctx = ModuleContext( + name_and_src_info.name, grid_mapping, approx_math, runtime_smem + ) + + smem_scratch_it = iter(scratch_buffers_smem) + scratch_buffers_template = [] + should_discharge = [] + accs = [] + for aval in scratch_avals: + match aval: + case gpu_core.WGMMAAbstractAccumulatorRef(): + scratch_buffers_template.append(None) + should_discharge.append(True) + accs.append( + mgpu.WGMMAAccumulator.zero( + *aval.shape, dtype=mgpu_utils.dtype_to_ir_type(aval.dtype) + ) + ) + case gpu_core.AbstractMemoryRef() if isinstance( + aval.dtype, gpu_core.BarrierType + ): + pass + case gpu_core.AbstractMemoryRef() if aval.memory_space == SMEM: + scratch_buffers_template.append(next(smem_scratch_it)) + should_discharge.append(False) + case _: + raise NotImplementedError( + f"Unsupported scratch operand type: {aval}" + ) + assert not jaxpr.outvars + if any(should_discharge): + # User-visible WGMMA APIs use the effectful accumulator references, but we + # can't lower that directly to Mosaic GPU that uses pure dataflow for + # accumulators. So we have to discharge the effects first. + assert not jaxpr.constvars + should_discharge = ( + [False] * len(grid_mapping.block_mappings) + + should_discharge + + [False] * len(extra_barriers) + ) + with grid_mapping.trace_env(): + lowered_jaxpr, _ = discharge.discharge_state( + jaxpr, (), should_discharge=should_discharge ) + else: + lowered_jaxpr = jaxpr + + # Precompute the total number of bytes transferred from GMEM to SMEM, + # so that we can do a single arrive instruction for all of the inputs. + in_transfer_bytes = 0 + for in_smem, b_smem in zip(in_in_smem, in_buffers_smem): + if not in_smem: + continue + b_smem_type = ir.MemRefType(b_smem.type) + in_transfer_bytes += math.prod(b_smem_type.shape[1:]) * mgpu.bytewidth( + b_smem_type.element_type + ) - barrier.wait() + def gmem_slice( + step: ir.Value, + block_mapping: pallas_core.BlockMapping, + ) -> Sequence[mgpu.DynamicSlice]: + assert len(sequential_axes) <= 1 + program_ids = [step if i is None else i for i in program_ids_template] + idxs = _eval_index_map(module_ctx, launch_ctx, program_ids, block_mapping) + return tuple( + mgpu.ds(idx, dim) for idx, dim in zip(idxs, block_mapping.block_shape) + ) + + is_memory_thread = mgpu.single_thread_predicate(per_block=True) - module_ctx = ModuleContext(name_and_src_info.name, - grid_mapping, runtime_smem, smem_used_bytes=0) - _ = lower_jaxpr_to_mosaic_gpu(module_ctx, jaxpr, None, buffers_smem) + def fetch(idx: int, step: ir.Value, slot: ir.Value) -> None: + if not in_in_smem[idx]: + return - for b_gmem, b_smem in zip(out_buffers_gmem, out_buffers_smem): # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. 
- launch_ctx.async_copy(src_ref=b_smem, dst_ref=b_gmem, swizzle=None) + gmem_transforms = (x.to_gpu_transform() for x in in_gmem_transforms[idx]) + launch_ctx.async_copy( + src_ref=in_buffers_gmem[idx], + dst_ref=mgpu.memref_slice(in_buffers_smem[idx], slot), + gmem_slice=gmem_slice(step, in_block_mappings[idx]), + barrier=barriers[slot], + gmem_transform=tuple(gmem_transforms), + swizzle=in_swizzles[idx], + arrive=False, # The caller must do ``arrive_expect_tx`` manually! + uniform=False, + predicate=is_memory_thread, + ) + + def store( + idx: int, step: ir.Value, slot: ir.Value, prev_base_offset: ir.Value | None + ) -> ir.Value | None: + if not out_in_smem[idx]: + return _as_index(-1) + + store_slice = gmem_slice(step, out_block_mappings[idx]) + if out_sequential_invariant[idx]: + assert prev_base_offset is None + do_store = None # Lack of predicate defaults to True. + base_offset = None + else: + assert prev_base_offset is not None + # We have to do some work to make sure that consecutive stores are not + # going to be writing to the same location, or else we'll end up with + # multiple concurrent writes and a racy program. + # TODO(apaszke,slebedev): In most cases output index maps depend only on + # parallel grid axes and in that case we can simply move the store to + # happen after the loop. + # TODO(apaszke,slebedev): This still diverges significantly from the TPU + # semantics in that it will move on to the next SMEM output slice even if + # it's not storing the previous one. + strides, _ = ir.MemRefType(out_buffers_gmem[idx].type).get_strides_and_offset() + base_offset = _as_index(0) + for stride, slc in zip(strides, store_slice): + base_offset = arith_dialect.addi( + base_offset, arith_dialect.muli(slc.base, _as_index(stride)) + ) + base_offset_changed = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.ne, base_offset, prev_base_offset + ) + is_last_step = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.eq, step, _as_index(num_steps - 1) + ) + do_store = arith_dialect.andi( + is_memory_thread, arith_dialect.ori(base_offset_changed, is_last_step) + ) + # TODO(slebedev): Support 128-byte swizzling, once we can lower matmuls. + launch_ctx.async_copy( + src_ref=mgpu.memref_slice(out_buffers_smem[idx], slot), + dst_ref=out_buffers_gmem[idx], + gmem_slice=store_slice, + swizzle=None, + uniform=False, + predicate=do_store, + ) + return base_offset + + for slot in range(min(max_concurrent_steps, num_steps)): + barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread) + for idx in range(grid_mapping.num_inputs): + fetch(idx, _as_index(slot), _as_index(slot)) + + last_store_offsets = [None if inv else _as_index(-1) for inv in out_sequential_invariant] + @mgpu.fori(_as_index(num_steps), (accs, last_store_offsets)) + def _(step, carry): + accs, last_store_offsets = carry + slot = arith_dialect.remui(step, _as_index(max_concurrent_steps)) + if grid_mapping.num_inputs: + # Only wait if async copies were issued. + barriers[slot].wait() + # We need to make sure the output copy is complete before the kernel starts + # writing to the output window. 
+ launch_ctx.await_async_copy(max_concurrent_steps - 1, await_read_only=True) + + args = [ + mgpu.memref_slice(buffers_smem[idx], slot) + if in_smem + else buffers_gmem[idx] + for idx, in_smem in enumerate(it.chain(in_in_smem, out_in_smem)) + ] + accs_it = iter(accs) + scratch_buffers = [ + b if b is not None else next(accs_it) + for b in scratch_buffers_template + ] + args.extend(scratch_buffers) + # TODO(apaszke): This assumes barriers come after buffers in scratch args, + # but that's not necessarily true. + args.extend(extra_barriers) + new_accs = lower_jaxpr_to_mosaic_gpu( + module_ctx, launch_ctx, lowered_jaxpr, args + ) + + # TODO(apaszke): Elide this if we're not going to perform any stores + mgpu.commit_shared() + new_store_offsets = [] + for idx in range(grid_mapping.num_outputs): + last_offset = last_store_offsets[idx] + new_store_offsets.append( + store(idx, step, slot, last_offset) + if not out_sequential_invariant[idx] + else last_offset # Only store if the output can depend on the step. + ) + + next_step = arith_dialect.addi(step, _as_index(max_concurrent_steps)) + next_step_in_bounds = arith_dialect.cmpi( + arith_dialect.CmpIPredicate.ult, next_step, _as_index(num_steps) + ) + next_slot = slot # (x + y) % y == x % y + with mgpu.when(next_step_in_bounds): + barriers[slot].arrive_expect_tx(in_transfer_bytes, predicate=is_memory_thread) + for idx in range(grid_mapping.num_inputs): + fetch(idx, next_step, next_slot) + + return list(new_accs), new_store_offsets + + # Outputs invariant to the sequential axis are never written from inside the + # loop. This is the only place where we store them. + last_slot = _as_index((num_steps - 1) % max_concurrent_steps) + for idx in range(grid_mapping.num_outputs): + if out_sequential_invariant[idx]: + store(idx, _as_index(0), last_slot, None) launch_ctx.await_async_copy(0) - # TODO(b/354568888): Add a jaxpr traversal to calculate the precise - # amount of memory required. 
+ scratch_avals = [ + var.aval for var in jaxpr.invars[grid_mapping.slice_scratch_ops] + ] + local_spaces = (gpu_core.SMEM, gpu_core.REGS) + if not all( + isinstance(aval, pallas_core.AbstractMemoryRef) + and aval.memory_space in local_spaces + for aval in scratch_avals + ): + raise TypeError( + "All scratch operands must be SMEM references or accumulators (ACC)," + f" but got: {scratch_avals}" + ) + extra_barriers = [ + mgpu.Barrier(aval.dtype.num_arrivals, *aval.shape) + for aval in scratch_avals + if isinstance(aval.dtype, gpu_core.BarrierType) + ] extra_smem_scratch = [ - jax.ShapeDtypeStruct( - shape=[compiler_params.get("smem_scratch_bytes", 100000)], - dtype=np.int8, - ) + jax.ShapeDtypeStruct(aval.shape, aval.dtype) + for aval in scratch_avals + if not isinstance(aval.dtype, gpu_core.BarrierType) + and aval.memory_space == gpu_core.SMEM ] - module, out_structs, gmem_scratch_bytes, _ = mosaic_gpu._lower_as_gpu_kernel( + smem_scratch_bytes = compiler_params.get("smem_scratch_bytes") + if smem_scratch_bytes is None: + smem_scratch_bytes = _estimate_smem_scratch_bytes(jaxpr) + extra_smem_scratch.append( + jax.ShapeDtypeStruct(shape=[smem_scratch_bytes], dtype=np.int8) + ) + + module, out_structs_smem, _ = mgpu_core._lower_as_gpu_kernel( body, grid=grid, cluster=(), block=block, - in_shapes=in_structs, - out_shape=out_structs, + in_shapes=in_structs_gmem, + out_shape=out_structs_gmem, smem_scratch_shape=( - *in_structs, - *out_structs, + (*in_structs_smem, *out_structs_smem), *extra_smem_scratch, - mgpu.TMABarrier(), + ( + mgpu.Barrier(arrival_count=1, num_barriers=max_concurrent_steps), + *extra_barriers, + ), ), module_name=name_and_src_info.name, ) - return LoweringResult(module, grid, gmem_scratch_bytes, out_structs) + return LoweringResult(module, grid, out_structs_smem) mosaic_lowering_rules = {} @@ -231,30 +618,20 @@ def deco(fn): def lower_jaxpr_to_mosaic_gpu( - ctx: ModuleContext, + module_ctx: ModuleContext, + launch_ctx: mgpu.LaunchContext, jaxpr: jax_core.Jaxpr, - block_infos: Sequence[BlockInfo | None] | None, - args, + args: Sequence[ir.Value], consts=(), ) -> Sequence[ir.Value]: env = {} - block_info_env = {} def read_env(atom: jax_core.Atom): return atom.val if isinstance(atom, jax_core.Literal) else env[atom] - def read_block_info_env(atom: jax_core.Atom): - if isinstance(atom, jax_core.Literal): - return None - return block_info_env.get(atom, None) - def write_env(var: jax_core.Var, val): env[var] = val - if block_infos is None: - block_infos = [None] * len(jaxpr.invars) - for invar, block_info in zip(jaxpr.invars, block_infos): - block_info_env[invar] = block_info map(write_env, jaxpr.constvars, consts) map(write_env, jaxpr.invars, args) for eqn in jaxpr.eqns: @@ -263,14 +640,14 @@ def write_env(var: jax_core.Var, val): raise NotImplementedError( "Unimplemented primitive in Pallas Mosaic GPU lowering: " f"{eqn.primitive.name}. " - "Please file an issue on https://github.com/google/jax/issues." + "Please file an issue on https://github.com/jax-ml/jax/issues." 
) rule = mosaic_lowering_rules[eqn.primitive] rule_ctx = LoweringRuleContext( - ctx, + module_ctx, + launch_ctx, avals_in=[cast(jax_core.ShapedArray, v.aval) for v in eqn.invars], avals_out=[cast(jax_core.ShapedArray, v.aval) for v in eqn.outvars], - block_shapes=map(read_block_info_env, eqn.invars), ) try: outvals = rule(rule_ctx, *invals, **eqn.params) @@ -289,23 +666,55 @@ def write_env(var: jax_core.Var, val): return map(read_env, jaxpr.outvars) +@register_lowering_rule(primitives.program_id_p) +def _program_id_lowering_rule(ctx: LoweringRuleContext, axis): + # TODO(apaszke): Sequential axis should be handled specially!! + del ctx # Unused. + return _program_id(axis) + + +def _program_id(axis: int) -> ir.Value: + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_id(gpu_dialect.Dimension(axis)), + ) + + +@register_lowering_rule(primitives.num_programs_p) +def _num_programs_lowering_rule(ctx: LoweringRuleContext, axis): + del ctx # Unused. + return arith_dialect.index_cast( + ir.IntegerType.get_signless(32), + gpu_dialect.block_dim(gpu_dialect.Dimension(axis)), + ) + + @register_lowering_rule(sp.get_p) def _get_lowering_rule(ctx: LoweringRuleContext, x_smem, *indexers, tree): - del tree, ctx # Unused. + del tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") - - return mgpu.FragmentedArray.load_strided(x_smem) + [x_aval] = ctx.avals_in + return mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) @register_lowering_rule(sp.swap_p) def _swap_lowering_rule( ctx: LoweringRuleContext, x_smem, value, *indexers, tree ): - del tree, ctx # Unused. + del tree # Unused. if indexers: raise NotImplementedError("No support for indexers yet") - old_value = mgpu.FragmentedArray.load_strided(x_smem) + if not isinstance(value, mgpu.FragmentedArray): + raise TypeError(f"Can only store arrays (got {value}).") + if not isinstance(x_smem, ir.Value) and ir.MemRefType.isinstance(x_smem): + raise TypeError(f"Can only store to references (got {value}).") + x_aval, _ = ctx.avals_in + old_value = mgpu.FragmentedArray.load_strided( + x_smem, is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) value.store_untiled(x_smem) return old_value @@ -314,7 +723,25 @@ def _swap_lowering_rule( def _pjit_lowering_rule(ctx: LoweringRuleContext, *args, jaxpr, **_): if jaxpr.consts: raise NotImplementedError - return lower_jaxpr_to_mosaic_gpu(ctx.module_context, jaxpr.jaxpr, None, args) + return lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr.jaxpr, args + ) + + +@register_lowering_rule(lax.select_n_p) +def _select_n_lowering_rule(ctx: LoweringRuleContext, pred, *cases): + if len(cases) != 2: + raise NotImplementedError( + "Mosaic GPU lowering only supports select_n with 2 cases, got" + f" {len(cases)}" + ) + pred_aval, *cases_avals = ctx.avals_in + [out_aval] = ctx.avals_out + pred = _ensure_fa(pred, pred_aval.dtype) + cases = _bcast(*cases, *cases_avals, out_aval) + # ``select`` expects the first case to be the true branch, but ``select_n`` + # orders the cases in reverse. 
+ return pred.select(*reversed(cases)) @register_lowering_rule(lax.broadcast_in_dim_p) @@ -327,7 +754,8 @@ def _broadcast_in_dim_lowering_rule( ): if broadcast_dimensions: raise NotImplementedError - return _ensure_fa(x, ctx.avals_in[0]).broadcast(shape) + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).broadcast(shape) @register_lowering_rule(lax.convert_element_type_p) @@ -335,7 +763,10 @@ def _convert_element_type_lowering_rule( ctx: LoweringRuleContext, x, *, new_dtype, weak_type, sharding ): del weak_type, sharding - return _ensure_fa(x, *ctx.avals_in).astype(mlir.dtype_to_ir_type(new_dtype)) + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).astype( + mgpu_utils.dtype_to_ir_type(new_dtype), is_signed=mgpu_utils.is_signed(new_dtype) + ) def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): @@ -348,12 +779,23 @@ def _binary_op_lowering_rule(ctx: LoweringRuleContext, x, y, *, impl): lax.sub_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x - y), lax.mul_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x * y), lax.div_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x / y), + lax.rem_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x % y), + lax.and_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x & y), + lax.or_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x | y), + lax.xor_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x ^ y), + lax.gt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x > y), + lax.lt_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x < y), + lax.ge_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x >= y), + lax.le_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x <= y), + lax.eq_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x == y), + lax.ne_p: partial(_binary_op_lowering_rule, impl=lambda x, y: x != y), }) @register_lowering_rule(lax.integer_pow_p) def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): - x = _ensure_fa(x, *ctx.avals_in) + [x_aval] = ctx.avals_in + x = _ensure_fa(x, x_aval.dtype) if y == 2: return x * x return NotImplementedError @@ -361,7 +803,8 @@ def _integer_pow_lowering_rule(ctx: LoweringRuleContext, x, y): @register_lowering_rule(lax.rsqrt_p) def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x): - return _ensure_fa(x, *ctx.avals_in).rsqrt() + [x_aval] = ctx.avals_in + return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math) @register_lowering_rule(lax.reduce_sum_p) @@ -369,10 +812,12 @@ def _reduce_sum_lowering_rule(ctx: LoweringRuleContext, x, *, axes): if axes != (0,): raise NotImplementedError("No support for axes other than 0 yet") [x_aval] = ctx.avals_in - _, [scratch] = ctx.module_context.scratch_view( + _, [scratch] = ctx.module_ctx.scratch_view( [jax.ShapeDtypeStruct(shape=(4,), dtype=x_aval.dtype)] ) - return mgpu.FragmentedArray.splat(x.reduce_sum(scratch), ()) + return mgpu.FragmentedArray.splat( + x.reduce_sum(scratch), (), is_signed=mgpu_utils.is_signed(x_aval.dtype) + ) @register_lowering_rule(primitives.debug_print_p) @@ -382,10 +827,29 @@ def _debug_print_lowering_rule( fmt, has_placeholders: bool, ): - del ctx - del has_placeholders + del has_placeholders # Unused. 
primitives.check_debug_print_format(fmt, *args) - mgpu.debug_print(fmt, *args) + if not any(aval.shape for aval in ctx.avals_in): + mgpu.debug_print( + fmt, + *( + _ensure_ir_value(arg, aval.dtype) + for arg, aval in zip(args, ctx.avals_in) + ), + ) + elif len(ctx.avals_in) == 1: + [arg] = args + @arg.foreach + def _(val, idx): + idx_fmt = ", ".join(["{}"] * len(idx)) + fmt_str = fmt.format(f"[{idx_fmt}]/{list(arg.shape)}: {{}}") + mgpu.debug_print(fmt_str, *idx, val, uniform=False) + else: + raise NotImplementedError( + "debug_print only supports printing of scalar values, or a single array" + " value when using the Mosaic GPU backend." + ) + return () @@ -393,15 +857,161 @@ def _debug_print_lowering_rule( def _run_scoped_lowering_rule( ctx: LoweringRuleContext, *consts, jaxpr: jax_core.Jaxpr ): - in_avals = [v.aval.inner_aval for v in jaxpr.invars] - bytes_allocated, input_refs = ctx.module_context.scratch_view( - [jax.ShapeDtypeStruct(shape=aval.shape, dtype=aval.dtype) for aval in in_avals] + input_refs = [] + bytes_allocated = 0 + should_discharge = [] + for a in jaxpr.invars: + a = a.aval + if isinstance(a, gpu_core.WGMMAAbstractAccumulatorRef): + mlir_dtype = mlir.dtype_to_ir_type(a.dtype) + input_refs.append(mgpu.WGMMAAccumulator.zero(*a.shape, mlir_dtype)) + should_discharge.append(True) + elif a.memory_space == gpu_core.SMEM: + ref_bytes, [input_ref] = ctx.module_ctx.scratch_view( + [jax.ShapeDtypeStruct(shape=a.shape, dtype=a.dtype)] + ) + bytes_allocated += ref_bytes + input_refs.append(input_ref) + should_discharge.append(False) + else: + raise ValueError(f"Can't convert to ref: {a}") + + if any(should_discharge): + # We convert consts to args, because we only have ir.Values and + # not JAX values during lowering. discharge_state() produces JAX + # values for the arguments but expects them to be provided for the + # consts. We also don't want to wrap the values in refs. + no_const_jaxpr = pe.convert_constvars_jaxpr(jaxpr) + should_discharge = [False] * len(consts) + should_discharge + discharged_jaxpr, _ = discharge.discharge_state(no_const_jaxpr, (), should_discharge=should_discharge) + new_input_vals = consts + tuple(input_refs) + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, discharged_jaxpr, new_input_vals, () + ) + # Discharge appends to the output the refs that got discharged. + outs = outs[:-sum(should_discharge)] + else: + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, input_refs, consts + ) + + for o in outs: + # This is definitely one of the accumulators we produced. Each + # run_scoped call is responsible for dereferencing its own + # accumulators. + if isinstance(o, mgpu.WGMMAAccumulator) or ( + isinstance(o, ir.Value) and ir.MemRefType.isinstance(o.type) + ): + raise ValueError(f"No references are allowed to escape a scope. 
(got {o})") + + assert len(outs) == len(jaxpr.outvars), (jaxpr, outs) + if bytes_allocated: + ctx.module_ctx.stack_free_smem(bytes_allocated) + + return outs + + +def _lower_jaxpr_to_for_loop( + ctx: LoweringRuleContext, + jaxpr: jax_core.Jaxpr, + start: ir.Value, + length: ir.Value, + consts, + *args, + has_loop_index: bool, +): + + @mgpu.fori(length, [*args]) + def loop(loop_index, body_args): + if has_loop_index: + loop_index = arith_dialect.addi(loop_index, start) + jaxpr_args = [*consts, loop_index, *body_args] + else: + jaxpr_args = [*consts, *body_args] + return lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, jaxpr, jaxpr_args + ) + + return loop.results + + +@register_lowering_rule(lax.scan_p) +def _scan_lowering_rule( + ctx: LoweringRuleContext, + *args, + jaxpr: jax_core.ClosedJaxpr, + linear: tuple[bool, ...], + length: int, + reverse: bool, + unroll: bool | int, + num_consts: int, + num_carry: int, + _split_transpose: bool, +): + # Can only handle fori_loop-like scans. + if ( + (num_extensive := len(args) - num_consts - num_carry) + or reverse + or unroll != 1 + ): + raise NotImplementedError + del linear, num_extensive, reverse, unroll + + jaxpr, jaxpr_consts = jaxpr.jaxpr, jaxpr.consts + if jaxpr_consts: + raise NotImplementedError + del jaxpr_consts + + jaxpr, has_loop_index = pallas_utils.pattern_match_scan_to_fori_loop( + jaxpr, num_consts, num_carry ) - outs = lower_jaxpr_to_mosaic_gpu( - ctx.module_context, jaxpr, None, input_refs, consts + consts, args = util.split_list(args, [num_consts]) + _consts_avals, arg_avals = util.split_list(ctx.avals_in, [num_consts]) + if has_loop_index: + start, *args = args + index_aval, *arg_avals = arg_avals + start: ir.Value = _ensure_ir_value(start, index_aval.dtype) + length = _ir_constant(length, start.type) + else: + start = _i32_constant(0) + length = _i32_constant(length) + args = map(lambda arg, aval: _ensure_fa(arg, aval.dtype), args, arg_avals) + for_out = _lower_jaxpr_to_for_loop( + ctx, jaxpr, start, length, consts, *args, has_loop_index=has_loop_index + ) + if has_loop_index: + # Need to return the final loop index value if the outer scan expects + # it as an output. + return [length, *for_out] + return for_out + + +@register_lowering_rule(lax.cond_p) +def _cond_lowering_rule(ctx: LoweringRuleContext, index, *args, branches): + index_aval, *_arg_avals = ctx.avals_in + switch_op = scf_dialect.IndexSwitchOp( + map(mgpu_utils.dtype_to_ir_type, ctx.avals_out), + _as_index(_ensure_ir_value(index, index_aval.dtype)), + ir.DenseI64ArrayAttr.get(range(len(branches) - 1)), + num_caseRegions=len(branches) - 1, ) - ctx.module_context.stack_free_smem(bytes_allocated) - return outs + + # ``RegionSequence`` in MLIR does not support slicing, so the + # auto-generated Python bindings for ``caseRegions`` fail at runtime! + # We convert it to a list to work around that. + regions = list(switch_op.regions) + # Move the default region to the back. 
+ regions = regions[1:] + regions[:1] + for branch, region in zip(branches, regions): + with ir.InsertionPoint(region.blocks.append()): + outs = lower_jaxpr_to_mosaic_gpu( + ctx.module_ctx, ctx.launch_ctx, branch.jaxpr, args + ) + scf_dialect.yield_([ + _ensure_ir_value(out, aval.dtype) + for out, aval in zip(outs, ctx.avals_out) + ]) + return list(switch_op.results) def _bcast( @@ -410,23 +1020,17 @@ def _bcast( x_aval: jax_core.ShapedArray, y_aval: jax_core.ShapedArray, out_aval: jax_core.ShapedArray, -) -> ir.Value: - if isinstance(x, (np.ndarray, np.number, int, float)): +) -> tuple[mgpu.FragmentedArray, mgpu.FragmentedArray]: + if not isinstance(x, mgpu.FragmentedArray): x_dtype = x_aval.dtype if x_aval.weak_type: x_dtype = y_aval.dtype - x = mgpu.FragmentedArray.splat( - _ir_constant(x, mlir.dtype_to_ir_type(x_dtype)), () - ) - if isinstance(y, (np.ndarray, np.number, int, float)): + x = _ensure_fa(x, x_dtype) + if not isinstance(y, mgpu.FragmentedArray): y_dtype = y_aval.dtype if y_aval.weak_type: y_dtype = x_aval.dtype - y = mgpu.FragmentedArray.splat( - _ir_constant(y, mlir.dtype_to_ir_type(y_dtype)), () - ) - assert isinstance(x, mgpu.FragmentedArray) - assert isinstance(y, mgpu.FragmentedArray) + y = _ensure_fa(y, y_dtype) if x_aval.shape != out_aval.shape: x = x.broadcast(out_aval.shape) if y_aval.shape != out_aval.shape: @@ -434,19 +1038,38 @@ def _bcast( return x, y -def _ensure_fa(x: object, aval: jax_core.ShapedArray) -> mgpu.FragmentedArray: +def _ensure_fa(x: object, dtype: jnp.dtype) -> mgpu.FragmentedArray: if isinstance(x, mgpu.FragmentedArray): + assert x.mlir_dtype == mgpu_utils.dtype_to_ir_type(dtype) return x elif isinstance(x, (np.number, np.ndarray, int, float)): return mgpu.FragmentedArray.splat( - _ir_constant(x, mlir.dtype_to_ir_type(aval.dtype)), () + _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)), + (), + is_signed=mgpu_utils.is_signed(dtype), ) - raise NotImplementedError + elif isinstance(x, ir.Value): + if isinstance(x.type, (ir.IntegerType, ir.FloatType, ir.IndexType)): + assert x.type == mgpu_utils.dtype_to_ir_type(dtype) + return mgpu.FragmentedArray.splat(x, (), is_signed=mgpu_utils.is_signed(dtype)) + raise NotImplementedError(f"Unsupported type: {type(x)}") + + +def _ensure_ir_value(x: object, dtype: jnp.dtype) -> ir.Value: + if isinstance(x, ir.Value): + assert x.type == mgpu_utils.dtype_to_ir_type(dtype) + return x + elif isinstance(x, (np.number, np.ndarray, int, float)): + return _ir_constant(x, mgpu_utils.dtype_to_ir_type(dtype)) + elif isinstance(x, mgpu.FragmentedArray): + if isinstance(x.layout, mgpu.WGSplatFragLayout): + return x.registers.item() + raise NotImplementedError(f"Unsupported type: {type(x)}") def _ir_constant(v: object, t: ir.Type) -> ir.Value: if isinstance(v, (np.number, np.ndarray, int, float)): - if isinstance(t, ir.IntegerType): + if isinstance(t, (ir.IntegerType, ir.IndexType)): v = int(v) else: assert isinstance(t, ir.FloatType) @@ -455,5 +1078,17 @@ def _ir_constant(v: object, t: ir.Type) -> ir.Value: raise NotImplementedError(f"Unsupported constant: {v!r}") -def _index(i: int) -> ir.Value: - return arith_dialect.constant(ir.IndexType.get(), int(i)) +def _i32_constant(v: int) -> ir.Value: + return arith_dialect.constant(ir.IntegerType.get_signless(32), v) + + +def _i64_constant(v: int) -> ir.Value: + return arith_dialect.constant(ir.IntegerType.get_signless(64), v) + + +def _as_index(v: int | ir.Value) -> ir.Value: + if isinstance(v, int): + return arith_dialect.constant(ir.IndexType.get(), v) + if 
ir.IndexType.isinstance(v.type): + return v + return arith_dialect.index_cast(ir.IndexType.get(), v) diff --git a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py index 9f28fa7c2944..960fe7d71856 100644 --- a/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py +++ b/jax/_src/pallas/mosaic_gpu/pallas_call_registration.py @@ -23,7 +23,7 @@ from jax._src.interpreters import mlir from jax._src.pallas import core as pallas_core from jax._src.pallas.mosaic_gpu import lowering -from jax.experimental.mosaic import gpu as mosaic_gpu +import jax.experimental.mosaic.gpu.core as mosaic_core def pallas_call_lowering( @@ -37,16 +37,13 @@ def pallas_call_lowering( grid_mapping: pallas_core.GridMapping, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): - del interpret + del interpret, out_avals if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError( "dynamic grid bounds not supported in the Mosaic GPU backend" ) - if input_output_aliases: - raise NotImplementedError( - "input_output_aliases not supported in the Mosaic GPU backend" - ) if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") @@ -66,10 +63,10 @@ def pallas_call_lowering( print(lowering_result.module.operation) module = lowering_result.module - return mosaic_gpu._mosaic_gpu_lowering_rule( + return mosaic_core._mosaic_gpu_lowering_rule( ctx, *args, module=module.operation.get_asm(binary=True, enable_debug_info=True), - gmem_scratch_bytes=lowering_result.gmem_scratch_bytes, out_types=lowering_result.out_structs, + input_output_aliases=input_output_aliases, ) diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py new file mode 100644 index 000000000000..dcec631e389b --- /dev/null +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -0,0 +1,260 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU-specific Pallas primitives.""" + +from __future__ import annotations + +from jax._src import core as jax_core +from jax._src import effects +from jax._src import state +from jax._src.state import discharge +from jax._src.lib.mlir.dialects import nvvm as nvvm_dialect +from jax._src.pallas import core as pallas_core +from jax._src.pallas.mosaic_gpu import core as gpu_core +from jax._src.pallas.mosaic_gpu import lowering +import jax.experimental.mosaic.gpu as mgpu + +async_copy_p = jax_core.Primitive("async_copy") +async_copy_p.multiple_results = True + + +@async_copy_p.def_effectful_abstract_eval +def _async_copy_abstract_eval(*avals): + del avals # Unused. 
+  return (), {state.ReadEffect(0), state.WriteEffect(1)}
+
+
+@lowering.register_lowering_rule(async_copy_p)
+def _async_copy_lowering_rule(
+    ctx: lowering.LoweringRuleContext, src, dst, barrier=None
+):
+  ctx.launch_ctx.async_copy(src_ref=src, dst_ref=dst, barrier=barrier)
+  return ()
+
+
+def async_copy_smem_to_gmem(
+    src: pallas_core.AbstractMemoryRef, dst: pallas_core.AbstractMemoryRef
+) -> None:
+  if src.memory_space is not gpu_core.SMEM:
+    raise TypeError(f"src must be a SMEM reference, got {src.memory_space}")
+  if dst.memory_space is not gpu_core.GMEM:
+    raise ValueError(f"dst must be a GMEM reference, got {dst.memory_space}")
+  async_copy_p.bind(src, dst)
+  return None
+
+
+def async_copy_gmem_to_smem(
+    src: pallas_core.AbstractMemoryRef,
+    dst: pallas_core.AbstractMemoryRef,
+    *,
+    barrier: pallas_core.AbstractMemoryRef,
+) -> None:
+  if src.memory_space is not gpu_core.GMEM:
+    raise TypeError(f"src must be a GMEM reference, got {src.memory_space}")
+  if dst.memory_space is not gpu_core.SMEM:
+    raise ValueError(f"dst must be a SMEM reference, got {dst.memory_space}")
+  async_copy_p.bind(src, dst, barrier)
+  return None
+
+
+class WaitEffect(jax_core.Effect):
+  ...
+
+
+wait_effect = WaitEffect()
+
+
+wait_p = jax_core.Primitive("wait")
+wait_p.multiple_results = True
+
+
+@wait_p.def_effectful_abstract_eval
+def _wait_abstract_eval(*avals, **params):
+  del avals, params  # Unused.
+  return (), {wait_effect}
+
+
+@lowering.register_lowering_rule(wait_p)
+def _wait_lowering_rule(
+    ctx: lowering.LoweringRuleContext, barrier=None, allow_groups=None,
+):
+  if barrier is not None:
+    barrier.wait()
+  else:
+    assert allow_groups is not None
+    ctx.launch_ctx.await_async_copy(allow_groups=allow_groups)
+  return ()
+
+
+def wait_smem_to_gmem(allow_groups: int) -> None:
+  """Waits until there are no more than the given number of SMEM->GMEM copies in flight."""
+  wait_p.bind(allow_groups=allow_groups)
+
+
+def wait_barrier(barrier: pallas_core.AbstractMemoryRef) -> None:
+  """Waits on the given barrier."""
+  wait_p.bind(barrier)
+
+
+class _WGMMAPipelineEffect(effects.Effect):
+  pass
+
+
+_wgmma_pipeline_effect = _WGMMAPipelineEffect()
+effects.control_flow_allowed_effects.add_type(_WGMMAPipelineEffect)
+
+# WGMMA on an accumulator reference
+wgmma_ref_p = jax_core.Primitive("wgmma_ref")
+wgmma_ref_p.multiple_results = True
+
+def wgmma(acc, a, b, *, rhs_transpose: bool = False, swizzle: int = 128):
+  """Asynchronous warp group matmul.
+
+  The sm90 wgmma instruction, essentially acc[...] += a @ b. Requires
+  that the accumulator is an accumulation register reference.
+
+  Args:
+    acc: The accumulator register.
+    a: The left hand side operand.
+    b: The right hand side operand.
+    rhs_transpose: Whether to transpose b.
+    swizzle: The swizzle pattern.
+  """
+  if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef):
+    raise TypeError(f"Expected WGMMAAbstractAccumulatorRef, got {acc}")
+
+  ma, ka, tma, tka = a.shape
+  kb, nb, tkb, tnb = b.shape
+  mc, nc = acc.shape
+
+  if rhs_transpose:
+    kb, nb, tkb, tnb = nb, kb, tnb, tkb
+
+  if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb:
+    raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}")
+
+  return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose)
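Taken together, `async_copy_*`, `wait_*`, `wgmma`, `wgmma_wait`, and `wgmma_accumulator_deref` are meant to be called from inside a Pallas Mosaic GPU kernel body. The sketch below shows only the intended call sequence; the ref names (`acc_ref`, `a_smem`, `b_smem`, `o_smem`, `o_gmem`) and the `gpu_primitives` import alias are illustrative assumptions rather than names defined by this patch, and `a_smem`/`b_smem` are assumed to already be tiled SMEM references of compatible shapes.

# Hypothetical kernel body; the refs are assumed to be supplied by pallas_call
# according to GPU BlockSpecs/scratch shapes set up elsewhere.
def kernel_body(a_smem, b_smem, o_smem, o_gmem, acc_ref):
  # Issue the asynchronous matmul into the accumulator reference.
  gpu_primitives.wgmma(acc_ref, a_smem, b_smem, rhs_transpose=False, swizzle=128)
  # Drain the WGMMA pipeline: wait until no WGMMA groups are outstanding.
  gpu_primitives.wgmma_wait(0)
  # Flush and read back the accumulator as a regular array value.
  out = gpu_primitives.wgmma_accumulator_deref(acc_ref)
  o_smem[...] = out
  # Copy the result out to GMEM and wait until no copies remain in flight.
  gpu_primitives.async_copy_smem_to_gmem(o_smem, o_gmem)
  gpu_primitives.wait_smem_to_gmem(0)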
+ """ + if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): + raise TypeError(f"Expected WGMMAAbstractAccumulatorRef got {acc}") + + ma, ka, tma, tka = a.shape + kb, nb, tkb, tnb = b.shape + mc, nc = acc.shape + + if rhs_transpose: + kb, nb, tkb, tnb = nb, kb, tnb, tkb + + if tma * ma != mc or nb * tnb != nc or ka != kb or tka != tkb: + raise ValueError(f"Incompatible shapes: {a.shape=}, {b.shape=}, {acc.shape=}, {rhs_transpose=}") + + return wgmma_ref_p.bind(acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose) + + +@wgmma_ref_p.def_effectful_abstract_eval +def _wgmma_ref_effectful_abstract_eval(acc, *args, **kwargs): + del acc, args, kwargs + return [], { + _wgmma_pipeline_effect, + state.WriteEffect(0), + state.ReadEffect(0), + state.ReadEffect(1), + state.ReadEffect(2), + } + + +@discharge.register_discharge_rule(wgmma_ref_p) +def _wgmma_ref_discharge_rule( + in_avals, out_avals, + acc, + a, + b, + swizzle, + rhs_transpose, +): + del in_avals, out_avals + return ( + wgmma_p.bind( + acc, a, b, swizzle=swizzle, rhs_transpose=rhs_transpose + ), + None, + None, + ), [] + + +# Functional WGMMA, returns a shaped array. Internal. +wgmma_p = jax_core.Primitive("wgmma") + +@lowering.register_lowering_rule(wgmma_p) +def _wgmma_lowering_rule( + ctx: lowering.LoweringRuleContext, + acc, + a, + b, + swizzle, + rhs_transpose, +): + del ctx + new_acc = mgpu.wgmma( + acc, + a, + b, + swizzle=swizzle, + b_order=mgpu.WGMMALayout.COL_MAJOR + if rhs_transpose + else mgpu.WGMMALayout.ROW_MAJOR, + ) + nvvm_dialect.wgmma_commit_group_sync_aligned() + return new_acc + +@wgmma_p.def_effectful_abstract_eval +def _wgmma_effectful_abstract_eval(acc, *args, **kwargs): + del args, kwargs + return acc, { + _wgmma_pipeline_effect, + state.ReadEffect(1), + state.ReadEffect(2), + } + +wgmma_wait_p = jax_core.Primitive("wgmma_wait") +wgmma_wait_p.multiple_results = True + +def wgmma_wait(i: int): + """Wait until all but the last `i` WGMMA operations are done.""" + return wgmma_wait_p.bind(i) + + +@wgmma_wait_p.def_effectful_abstract_eval +def wgmma_wait_effectful_abstract_eval(_): + return [], {_wgmma_pipeline_effect} + +@lowering.register_lowering_rule(wgmma_wait_p) +def _wgmma_wait_lowering_rule(ctx: lowering.LoweringRuleContext, allow_groups): + del ctx + nvvm_dialect.wgmma_wait_group_sync_aligned(allow_groups) + return () + +wgmma_accumulator_deref_p = jax_core.Primitive("wgmma_accumulator_deref_p") +def wgmma_accumulator_deref(acc): + """Dereferences an accumulator register.""" + + if not isinstance(acc.aval, gpu_core.WGMMAAbstractAccumulatorRef): + raise TypeError(f"acc must be a WGMMAAccumulatorAbstractRef, got {acc.aval=}") + + return wgmma_accumulator_deref_p.bind(acc) + +@wgmma_accumulator_deref_p.def_effectful_abstract_eval +def _wgmma_accumulator_deref_abstract_eval(acc): + # Dereferencing implies flushing so we have a wgmma pipeline effect. 
+ ret = acc.inner_aval if isinstance(acc, gpu_core.WGMMAAbstractAccumulatorRef) else acc + assert isinstance(ret, jax_core.ShapedArray), acc + return ret, {_wgmma_pipeline_effect} + +@discharge.register_discharge_rule(wgmma_accumulator_deref_p) +def _wgmma_accumulator_deref_discharge_rule(in_avals, out_avals, acc): + del in_avals, out_avals + return (None,), wgmma_accumulator_deref_p.bind(acc) + +@lowering.register_lowering_rule(wgmma_accumulator_deref_p) +def _wgmma_accumulator_deref_lowering_rule(ctx: lowering.LoweringRuleContext, acc): + del ctx + nvvm_dialect.wgmma_wait_group_sync_aligned(0) + return acc.value diff --git a/jax/_src/pallas/pallas_call.py b/jax/_src/pallas/pallas_call.py index f6ee5381adc8..1c10d2bda9e9 100644 --- a/jax/_src/pallas/pallas_call.py +++ b/jax/_src/pallas/pallas_call.py @@ -16,6 +16,7 @@ from __future__ import annotations from collections.abc import Callable, Iterable, Sequence +import dataclasses from functools import partial, reduce import itertools from typing import Any @@ -35,7 +36,8 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.pallas import core as pallas_core -from jax._src.pallas.primitives import uninitialized_value +from jax._src.pallas import primitives +from jax._src.pallas import utils as pallas_utils from jax._src.state import discharge as state_discharge from jax._src.util import ( safe_map, @@ -60,6 +62,7 @@ BlockSpecTree = pallas_core.BlockSpecTree NoBlockSpec = pallas_core.NoBlockSpec no_block_spec = pallas_core.no_block_spec +ScratchShapeTree = pallas_core.ScratchShapeTree CostEstimate = pallas_core.CostEstimate # See the docstring for GridMapping for the calling convention @@ -110,13 +113,15 @@ def _pad_values_to_block_dimension(value, ) if padded_shape != value.shape: pad_width = tuple((0, a-b) for a, b in zip(padded_shape, value.shape)) - pad_value = uninitialized_value(shape=(), dtype=value.dtype) + pad_value = primitives.uninitialized_value(shape=(), dtype=value.dtype) value = jnp.pad(value, pad_width, constant_values=pad_value) return value def _initialize_scratch_vals(scratch_avals) -> tuple[jax.Array, ...]: scratch_avals = (jax_core.raise_to_shaped(x) for x in scratch_avals) - return tuple(uninitialized_value(a.shape, a.dtype) for a in scratch_avals) + return tuple( + primitives.uninitialized_value(a.shape, a.dtype) for a in scratch_avals + ) def _initialize_output_vals( block_mappings_output: Iterable[BlockMapping], @@ -127,8 +132,9 @@ def _initialize_output_vals( if i in oi_map: output_vals.append(input_args[oi_map[i]]) else: - output_vals.append(uninitialized_value(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype)) + output_vals.append(primitives.uninitialized_value( + bm.array_shape_dtype.shape, + bm.array_shape_dtype.dtype)) return output_vals def _logical_to_interpret_mode_dtype(dtype): @@ -162,8 +168,12 @@ def _get_next_indices(grid, indices): next_indices.append(jnp.where(carry, 0, i)) return tuple(reversed(next_indices)) -def _pallas_call_impl(*args, **kwargs): - assert False # We always jit a pallas call, we only need the lowering rule +def _pallas_call_impl(*args, **params): + # Call the lowering path + @partial(jax.jit, inline=True) + def _jit_run(*args): + return pallas_call_p.bind(*args, **params) + return _jit_run(*args) def _pallas_call_impl_interpret( @@ -175,9 +185,10 @@ def _pallas_call_impl_interpret( grid_mapping: GridMapping, compiler_params: Any, cost_estimate: CostEstimate, + out_avals: tuple[jax_core.AbstractValue, ...], ): - del 
compiler_params, cost_estimate - # If we're in interpreter mode, we *scan* over the grid and eval the + del compiler_params, cost_estimate, out_avals + # If we're in interpret mode, we *scan* over the grid and eval the # discharged jaxpr. dynamic_grid_args, args = split_list( # type: ignore args, [grid_mapping.num_dynamic_grid_bounds] @@ -192,7 +203,7 @@ def _pallas_call_impl_interpret( with grid_mapping.trace_env(): discharged_jaxpr, discharged_consts = state_discharge.discharge_state(jaxpr, ()) if debug: - print(f"\nJaxpr the the kernel in pallas_call {name_and_src_info}:") + print(f"\nJaxpr of the the kernel in pallas_call {name_and_src_info}:") print(discharged_jaxpr) out = _initialize_output_vals(grid_mapping.block_mappings_output, args, input_output_aliases) @@ -211,7 +222,7 @@ def _pallas_call_impl_interpret( if padding is not None and any(p != (0, 0) for p in padding): if input_output_aliases: raise NotImplementedError("Padding with aliasing not supported.") - pad_value = uninitialized_value(shape=(), dtype=x.dtype) + pad_value = primitives.uninitialized_value(shape=(), dtype=x.dtype) x = lax.pad(x, pad_value, [(*p, 0) for p in padding]) carry.append(x) @@ -228,6 +239,12 @@ def _pallas_call_impl_interpret( # Pad values to evenly divide into block dimensions. This matches the # behavior of the non-interpret mode. We pad with NaN, to make it easier # to catch OOB accesses. + for carry_element in carry: + aval = carry_element.aval + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype) + carry_element.aval = aval + carry = map(_pad_values_to_block_dimension, carry, block_shapes) carry.extend(scratch_values) @@ -247,11 +264,16 @@ def cond(carry): return i < num_iterations def body(carry): i, loop_idx, *carry_blocks = carry - local_grid_env = tuple( - pallas_core.GridAxis(idx, b) - for dim, (idx, b) in enumerate(zip(loop_idx, grid)) - if dim not in grid_mapping.vmapped_dims - ) + + if grid_mapping.local_grid_env is not None: + local_grid_env = grid_mapping.local_grid_env(loop_idx, grid) + else: + local_grid_env = tuple( + pallas_core.GridAxis(idx, b) + for dim, (idx, b) in enumerate(zip(loop_idx, grid)) + if dim not in grid_mapping.vmapped_dims + ) + carry_consts_ins, scratch = split_list(carry_blocks, [num_inout_blocks]) with pallas_core.grid_env(local_grid_env): start_indices = [ @@ -268,8 +290,14 @@ def body(carry): len(blocks), len(scratch_values), ) - blocks = jax_core.eval_jaxpr(discharged_jaxpr, discharged_consts, *scalars, - *blocks, *scratch) + for s in scalars: + aval = jax_core.get_aval(s) + if isinstance(aval, jax_core.DShapedArray): + s.aval = aval.update(dtype=jnp.int32) + + blocks = jax_core.eval_jaxpr( + discharged_jaxpr, discharged_consts, *scalars, *blocks, *scratch + ) _, out_inout, out_scratch = split_list( blocks, [grid_mapping.num_index_operands, num_inout_blocks]) @@ -301,10 +329,20 @@ def body(carry): pallas_call_p.def_impl(_pallas_call_impl) -def _pallas_call_abstract_eval(*avals, grid_mapping: GridMapping, **_): - return tuple(jax_core.ShapedArray(bm.array_shape_dtype.shape, - bm.array_shape_dtype.dtype) - for bm in grid_mapping.block_mappings_output) + +def _pallas_call_abstract_eval( + *avals, out_avals: tuple[jax_core.AbstractValue, ...], **_ +): + del avals + # Make sure we don't return ShapedArrayWithMemorySpace to the outside world. 
+ return [ + jax_core.ShapedArray(a.shape, a.dtype, a.weak_type) + if isinstance(a, pallas_core.ShapedArrayWithMemorySpace) + else a + for a in out_avals + ] + + pallas_call_p.def_abstract_eval(_pallas_call_abstract_eval) @@ -320,6 +358,7 @@ def _pallas_call_jvp_rule( interpret, compiler_params: Any, cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): if grid_mapping.num_dynamic_grid_bounds: raise NotImplementedError("interpret with dynamic grid bounds unsupported") @@ -383,6 +422,7 @@ def _pallas_call_jvp_rule( input_output_aliases=(), compiler_params=compiler_params, cost_estimate=jvp_cost_estimate, + out_avals=(*out_avals, *out_avals) ) out_primals, out_tangents = split_list(out_flat, [len(out_flat) // 2]) return out_primals, out_tangents @@ -390,19 +430,55 @@ def _pallas_call_jvp_rule( ad.primitive_jvps[pallas_call_p] = _pallas_call_jvp_rule -def _batch_block_mapping(grid_mapping: GridMapping, - axis_size: int, - aval: jax_core.ShapedArray, - dim: int | batching.NotMapped, - block_mapping: BlockMapping) -> BlockMapping: + +def _batch_block_mapping( + grid_mapping: GridMapping, + axis_size: int, + aval: jax_core.ShapedArray, + dim: int | batching.NotMapped, + block_mapping: BlockMapping, + for_ragged: bool, +) -> BlockMapping: def _block_map_function(new_idx, *args): - indices = jax_core.eval_jaxpr(block_mapping.index_map_jaxpr.jaxpr, - block_mapping.index_map_jaxpr.consts, - *args) + if for_ragged: + drop_last_args = args[:-1] + else: + drop_last_args = args + + indices = jax_core.eval_jaxpr( + block_mapping.index_map_jaxpr.jaxpr, + block_mapping.index_map_jaxpr.consts, + *drop_last_args, + ) if dim is not batching.not_mapped: - indices.insert(dim, new_idx) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + indices.insert(stacked_axis, new_idx) + else: + indices.insert(dim, new_idx) return tuple(indices) idx_avals = [pallas_core.index_map_grid_aval, *block_mapping.index_map_jaxpr.in_avals] + + if for_ragged: + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + _, _, ragged_axis_length = _ragged_axis_parts(dim) + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, lengths_aval] + else: + i32_aval_memref = pallas_core.AbstractMemoryRef( + jax_core.ShapedArray(([axis_size]), jnp.int32), + pallas_core.MemorySpace.INDEX, + ) + idx_avals = [*idx_avals, i32_aval_memref] + with grid_mapping.trace_env(): block_mapping_jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( lu.wrap_init(_block_map_function), idx_avals) @@ -411,12 +487,27 @@ def _block_map_function(new_idx, *args): new_block_shape = shape new_array_shape_dtype = block_mapping.array_shape_dtype else: - new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." 
+ new_block_shape = shape + stacked_axis = dim.stacked_axis + new_block_shape = tuple_insert( + new_block_shape, stacked_axis, pallas_core.mapped + ) + else: + new_block_shape = tuple_insert(shape, dim, pallas_core.mapped) + + array_shape = block_mapping.array_shape_dtype.shape + if isinstance(dim, batching.RaggedAxis): + assert for_ragged, "Ragged axis not supported for non-ragged batching." + stacked_axis = dim.stacked_axis + array_shape = tuple_insert(array_shape, stacked_axis, axis_size) + else: + array_shape = tuple_insert(array_shape, dim, axis_size) + new_array_shape_dtype = jax.ShapeDtypeStruct( - tuple_insert(block_mapping.array_shape_dtype.shape, - dim, - axis_size), - block_mapping.array_shape_dtype.dtype) + array_shape, block_mapping.array_shape_dtype.dtype + ) jaxpr = jax_core.ClosedJaxpr(block_mapping_jaxpr, consts) return block_mapping.replace(block_shape=new_block_shape, @@ -465,6 +556,7 @@ def _batch_with_explicit_loop( interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): """Batch the pallas_call by calling it in loop over the batch size. @@ -531,6 +623,7 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) for i, batch_out_array in enumerate(batch_out): state[i] = jax.lax.dynamic_update_index_in_dim( @@ -547,6 +640,16 @@ def body(batch_index: jax.Array, state: list[jax.Array]) -> list[jax.Array]: return result, (0,) * len(result) +def _ragged_axis_parts(dim: batching.RaggedAxis) -> tuple[int, int, int]: + stacked_axis = dim.stacked_axis + ragged_axes = dim.ragged_axes + if len(ragged_axes) != 1: + raise ValueError("Multiple ragged axes not yet implemented.") + ragged_axis_dim = ragged_axes[0][0] + ragged_axis_length = ragged_axes[0][1] + return stacked_axis, ragged_axis_dim, ragged_axis_length + + def _pallas_call_batching_rule( args, dims, @@ -559,6 +662,7 @@ def _pallas_call_batching_rule( interpret: bool, compiler_params: Any, cost_estimate: CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): def _maybe_squeeze_out_bdim( x: jax.Array, bdim: int | batching.NotMapped @@ -567,8 +671,26 @@ def _maybe_squeeze_out_bdim( return x return jnp.squeeze(x, axis=bdim) + all_ragged_axes = [d for d in dims if isinstance(d, batching.RaggedAxis)] + if len(all_ragged_axes) > 1: + raise ValueError("Multiple ragged dimensions not yet implemented.") + + if all_ragged_axes: + stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts( + all_ragged_axes[0] + ) + else: + stacked_axis, ragged_axis_dim, ragged_axis_length = None, None, None + + def get_size(i, x, d): + if not isinstance(d, batching.RaggedAxis): + return x.shape[d] + return x.aval.shape[i] + (axis_size,) = { - x.shape[d] for x, d in zip(args, dims) if d is not batching.not_mapped + get_size(i=i, x=x, d=d) + for i, (x, d) in enumerate(zip(args, dims)) + if d is not batching.not_mapped } if axis_size == 1: # Why are we even vmapping? 
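For intuition about the ragged-axis plumbing introduced above and used in the hunks that follow: `_ragged_axis_parts` supports exactly one ragged dimension per operand and splits a `batching.RaggedAxis` into the stacked axis, the ragged axis index, and the per-example lengths. A rough illustration, with a hypothetical `lengths` array that is not part of this patch:

# dims entry describing an operand stacked along axis 0 whose axis 1 is ragged,
# with per-example lengths carried in `lengths`.
dim = batching.RaggedAxis(stacked_axis=0, ragged_axes=((1, lengths),))
stacked_axis, ragged_axis_dim, ragged_axis_length = _ragged_axis_parts(dim)
# stacked_axis == 0, ragged_axis_dim == 1, ragged_axis_length is `lengths`.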
@@ -583,6 +705,7 @@ def _maybe_squeeze_out_bdim( interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) return [jnp.expand_dims(x, 0) for x in out], (0,) * len(out) @@ -615,6 +738,7 @@ def _maybe_squeeze_out_bdim( interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) else: pass # No dynamic grid dimensions @@ -648,6 +772,7 @@ def _maybe_squeeze_out_bdim( interpret=interpret, compiler_params=compiler_params, cost_estimate=cost_estimate, + out_avals=out_avals, ) if not dims: @@ -670,12 +795,27 @@ def _maybe_squeeze_out_bdim( num_index_operands = grid_mapping.num_index_operands num_scratch_operands = grid_mapping.num_scratch_operands + lengths_aval = None + if ragged_axis_length is not None: + aval = jax_core.get_aval(ragged_axis_length).update(dtype=jnp.int32) + if isinstance(aval, jax_core.DShapedArray): + aval = jax_core.ShapedArray(aval.shape, aval.dtype, aval.weak_type) + lengths_aval = pallas_core.AbstractMemoryRef( + aval, + pallas_core.MemorySpace.INDEX, + ) + # Only add a batch dimension for the avals that actually have a grid mapping. # This excludes scalar prefetch inputs (the first in the list) and scratch # operands (the last in the list). avals_to_batch = avals[num_index_operands:(len(avals) - num_scratch_operands)] batched_block_mappings = map( - partial(_batch_block_mapping, grid_mapping, axis_size), + partial( + _batch_block_mapping, + grid_mapping, + axis_size, + for_ragged=lengths_aval is not None, + ), avals_to_batch, all_dims[num_index_operands:], block_mappings, @@ -685,15 +825,23 @@ def _maybe_squeeze_out_bdim( grid_mapping.index_map_avals) assert not index_map_tree_kwargs batched_index_map_args = (pallas_core.index_map_grid_aval,) + index_map_tree_args + + if lengths_aval: + batched_index_map_args = batched_index_map_args + (lengths_aval,) + num_index_operands += 1 + batched_index_map_avals, batched_index_map_tree = tree_util.tree_flatten( (batched_index_map_args, {})) + batched_grid_mapping = grid_mapping.replace( grid=(axis_size, *grid_mapping.grid), block_mappings=tuple(batched_block_mappings), - index_map_avals=batched_index_map_avals, + index_map_avals=tuple(batched_index_map_avals), index_map_tree=batched_index_map_tree, + num_index_operands=num_index_operands, vmapped_dims=(0,) + tuple(a + 1 for a in grid_mapping.vmapped_dims), ) + if cost_estimate is not None: batched_cost_estimate = CostEstimate( flops=cost_estimate.flops * axis_size, @@ -702,6 +850,107 @@ def _maybe_squeeze_out_bdim( ) else: batched_cost_estimate = None + + if lengths_aval: + batched_grid_mapping = batched_grid_mapping.replace( + get_grid_indices=lambda indices, maybe_include_mapped_dims: indices, + local_grid_env=lambda loop_idx, grid: tuple( + pallas_core.GridAxis(idx, b) for (idx, b) in zip(loop_idx, grid) + ), + ) + + # Note - on zero filling counterfactuals + # A debug util to produce a counterfactual version of the when + # gating, where for all values that don't pass the @when check, + # we write 0s. This is useful for debugging, as certain lowering paths + # like mosaic will write the last data as passthrough, leading to + # potentially confusing results. 
+ debug_zero_fill_counterfactual = debug + + first_block_mapping = batched_grid_mapping.block_mappings[0] + for block_mapping in batched_grid_mapping.block_mappings: + # This invariant may already be checked elsewhere, but lets reaffirm it + assert block_mapping.block_shape == first_block_mapping.block_shape, ( + f"block_mapping.block_shape: {block_mapping.block_shape}, " + f"first_block_mapping.block_shape: {first_block_mapping.block_shape}" + ) + assert ( + block_mapping.array_shape_dtype + == first_block_mapping.array_shape_dtype + ), ( + f"block_mapping.array_shape_dtype: {block_mapping.array_shape_dtype}," + " first_block_mapping.array_shape_dtype:" + f" {first_block_mapping.array_shape_dtype}" + ) + + mapped_dim_idxs = [ + i + for i, d in enumerate(first_block_mapping.block_shape) + if d is pallas_core.mapped + ] + assert len(mapped_dim_idxs) == 1 + mapped_dim_idx = mapped_dim_idxs[0] + if stacked_axis != mapped_dim_idx: + raise ValueError( + f"Expected mapped dim to be {stacked_axis}, but got {mapped_dim_idx}" + ) + + assert ragged_axis_dim is not None, "Invariant violation" + # This is the blockspec size of the dimension + val_at_ragged_dim = first_block_mapping.block_shape[ragged_axis_dim] + + def when_wrapped_kernel(lengths_ref, *args, **kwargs): + b_idx = primitives.program_id(stacked_axis) + i_idx = ( + primitives.program_id(ragged_axis_dim) + * val_at_ragged_dim + ) + b_len = lengths_ref[b_idx] + + # TODO(mvoz): Unimplemented primitive in pallas + # b_len_mod = jnp.equal(jnp.mod(b_len, val_at_ragged_dim), 0) + # checkify.check(b_len_mod, "b_len % val_at_ragged_dim != 0") + + @pallas_utils.when(i_idx < b_len) + def f(): + # Important! This allows us to trace the inner kernel with the correct + # grid to preserve user program_id semantics. Ex: program_id(0) will + # always be analogous to program_id(1) in the outer kernel. + with pallas_core.tracing_grid_env(grid_mapping.grid, ()): + jax_core.eval_jaxpr(jaxpr, (), *args, **kwargs) + + if debug_zero_fill_counterfactual: + + @pallas_utils.when(i_idx >= b_len) + def g(): + for arg_ref in args: + arg_ref[...] = jnp.zeros_like(arg_ref) + + kernel_avals = [lengths_aval] + [v.aval for v in jaxpr.invars] + flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten( + list(kernel_avals) + ) + # Important! This allows us to trace the outer kernel with the correct grid + # to enable accessing the batch program_id. 
+ with pallas_core.tracing_grid_env(batched_grid_mapping.grid, ()): + kernel_src_info: pallas_core.SrcInfoStr = "" + + jaxpr = _trace_kernel_to_jaxpr( + when_wrapped_kernel, + kernel_src_info, + batched_grid_mapping, + tuple(flat_kernel_avals), + kernel_in_tree, + interpret=interpret, + ) + + assert ragged_axis_length is not None + args = (ragged_axis_length, *args) + assert all(isinstance(aval, jax_core.ShapedArray) for aval in out_avals) + batched_out_avals = tuple( + aval.update(shape=tuple_insert(aval.shape, 0, axis_size)) + for aval in out_avals + ) out = pallas_call_p.bind( *dynamic_grid_args, *args, @@ -715,6 +964,7 @@ def _maybe_squeeze_out_bdim( interpret=interpret, compiler_params=compiler_params, cost_estimate=batched_cost_estimate, + out_avals=batched_out_avals, ) return out, (0,) * len(out) @@ -744,6 +994,7 @@ def pallas_call_checkify_rule(error: checkify.Error, interpret: bool, input_output_aliases: tuple[tuple[int, int], ...], grid_mapping: GridMapping, + out_avals: tuple[jax_core.AbstractValue, ...], **kwargs): # We implement the checkify rule in 4 steps: # 1) First, trace the kernel body to get the expected error shapes. @@ -870,11 +1121,13 @@ def _ensure_2d_error_shape(arg): (i+num_scalars, i) for i in range(num_err_vals)) + input_output_aliases new_vals_in = [*scalars, *err_vals, *args] + new_out_avals = (*shaped_err_avals, *out_avals) result = pallas_call_p.bind(*dynamic_grid_bounds, *new_vals_in, jaxpr=final_jaxpr, interpret=interpret, grid_mapping=grid_mapping_with_error, input_output_aliases=input_output_aliases_with_error, + out_avals=new_out_avals, **kwargs) errors, results = split_list(result, [num_err_vals]) # TODO(b/350593266): Remove line below once we support ()-shaped scalars. @@ -889,7 +1142,7 @@ def _trace_kernel_to_jaxpr(fun: Callable, grid_mapping: GridMapping, kernel_avals: tuple[pallas_core.AbstractMemRef, ...], kernel_in_tree: tree_util.PyTreeDef, - interpret: bool + interpret: bool, ) -> jax_core.ClosedJaxpr: if interpret: kernel_avals = tuple(map(_logical_aval_to_interpret_mode_aval, @@ -1003,6 +1256,17 @@ def _pallas_call_typecheck_rule(*in_avals, grid_mapping, **params): ) jax_core.custom_typechecks[pallas_call_p] = _pallas_call_typecheck_rule +def _convert_out_shape_to_aval(out_shape: Any) -> jax_core.AbstractValue: + match out_shape: + case jax.ShapeDtypeStruct(): + return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) + case pallas_core.MemoryRef(): + return out_shape.get_array_aval() + case _: + if not (hasattr(out_shape, "shape") and hasattr(out_shape, "dtype")): + raise ValueError(f"Invalid out_shape type: {type(out_shape)}") + return jax_core.ShapedArray(shape=out_shape.shape, dtype=out_shape.dtype) + def pallas_call( kernel: Callable[..., None], @@ -1012,11 +1276,12 @@ def pallas_call( grid: TupleGrid = (), in_specs: BlockSpecTree = no_block_spec, out_specs: BlockSpecTree = no_block_spec, + scratch_shapes: ScratchShapeTree = (), input_output_aliases: dict[int, int] = {}, debug: bool = False, interpret: bool = False, name: str | None = None, - compiler_params: dict[str, Any] | None = None, + compiler_params: dict[str, Any] | pallas_core.CompilerParams | None = None, cost_estimate: CostEstimate | None = None, ) -> Callable[..., Any]: """Invokes a Pallas kernel on some inputs. @@ -1029,8 +1294,9 @@ def pallas_call( corresponding ``in_specs`` and ``out_specs``. out_shape: a PyTree of :class:`jax.ShapeDtypeStruct` describing the shape and dtypes of the outputs. 
- grid_spec: An alternative way to specify ``grid``, ``in_specs``, and - ``out_specs``. If given, those other parameters must not be also given. + grid_spec: An alternative way to specify ``grid``, ``in_specs``, + ``out_specs`` and ``scratch_shapes``. If given, those other parameters + must not be also given. grid: the iteration space, as a tuple of integers. The kernel is executed as many times as ``prod(grid)``. See details at :ref:`pallas_grid`. @@ -1044,6 +1310,9 @@ def pallas_call( The default value for ``out_specs`` specifies the whole array, e.g., as ``pl.BlockSpec(x.shape, lambda *indices: (0,) * x.ndim)``. See details at :ref:`pallas_blockspec`. + scratch_shapes: a PyTree of backend-specific temporary objects required + by the kernel, such as temporary buffers, synchronization primitives, + etc. input_output_aliases: a dictionary mapping the index of some inputs to the index of the output that aliases them. These indices are in the flattened inputs and outputs. @@ -1058,7 +1327,12 @@ def pallas_call( where the kernel function is defined, .e.g: `{name} for kernel function {kernel_name} at {file}:{line}`. If missing, then we use `{kernel_name} at {file}:{line}`. - compiler_params: TO BE DOCUMENTED. + compiler_params: Optional compiler parameters. If a dict is provided, it + should be of the form {platform: {param_name: param_value}}, where + platform is either 'mosaic' or 'triton'. It is also possible + to pass in `jax.experimental.pallas.tpu.TPUCompilerParams` for TPUs and + `jax.experimental.pallas.gpu.TritonCompilerParams` for Triton/GPUs. + Returns: A function that can be called on a number of positional array arguments to @@ -1070,9 +1344,16 @@ def pallas_call( name, kernel_src_info) if compiler_params is None: compiler_params = {} + if isinstance(compiler_params, pallas_core.CompilerParams): + if compiler_params.PLATFORM not in ["mosaic", "mosaic_gpu", "triton"]: + raise ValueError( + f"Unknown platform in compiler params: {compiler_params.PLATFORM}") + compiler_params = { + compiler_params.PLATFORM: dataclasses.asdict(compiler_params) + } if grid_spec is None: - grid_spec = GridSpec(grid, in_specs, out_specs) + grid_spec = GridSpec(grid, in_specs, out_specs, scratch_shapes) else: if grid: raise ValueError( @@ -1086,6 +1367,10 @@ def pallas_call( raise ValueError( "If `grid_spec` is specified, then `out_specs` must " f"be `no_block_spec`. It is {out_specs}") + if scratch_shapes: + raise ValueError( + "If `grid_spec` is specified, then `scratch_shapes` must " + f"be `()`. 
It is {scratch_shapes}") del grid, in_specs, out_specs grid_spec, dynamic_grid_bounds = pallas_core.unzip_dynamic_grid_bounds(grid_spec) # TODO(necula): this canonicalization may be convenient for some usage @@ -1095,15 +1380,15 @@ def pallas_call( out_shape = tuple(out_shape) flat_out_shapes_with_paths, out_tree = tree_util.tree_flatten_with_path(out_shape) out_paths, flat_out_shapes = unzip2(flat_out_shapes_with_paths) - flat_out_shapes = [jax.ShapeDtypeStruct(x.shape, x.dtype) # type: ignore - for x in flat_out_shapes] - @jax.jit + + @partial(jax.jit, inline=True) def wrapped(*args): flat_args_with_paths, in_tree = tree_util.tree_flatten_with_path(args) in_paths, flat_args = unzip2(flat_args_with_paths) flat_in_avals = tuple(jax_core.raise_to_shaped(jax_core.get_aval(a)) for a in flat_args) - flat_out_avals = tuple(jax_core.ShapedArray(v.shape, v.dtype) + + flat_out_avals = tuple(_convert_out_shape_to_aval(v) for v in flat_out_shapes) kernel_fun_sig = api_util.fun_signature(kernel) @@ -1126,10 +1411,11 @@ def wrapped(*args): flat_in_avals, in_tree, in_origins, flat_out_avals, out_tree, out_origins) flat_kernel_avals, kernel_in_tree = tree_util.tree_flatten(kernel_avals) - jaxpr = _trace_kernel_to_jaxpr( - kernel, kernel_src_info, - grid_mapping, tuple(flat_kernel_avals), kernel_in_tree, - interpret=interpret) + with pallas_core.interpret_mode_env(interpret): + jaxpr = _trace_kernel_to_jaxpr( + kernel, kernel_src_info, + grid_mapping, tuple(flat_kernel_avals), kernel_in_tree, + interpret=interpret) for i_idx, o_idx in input_output_aliases.items(): if i_idx not in range(len(flat_in_avals)): raise ValueError( @@ -1152,33 +1438,38 @@ def wrapped(*args): f"a different abstract value {out_aval}.") index_args, rest_args = split_list(flat_args, [grid_mapping.num_index_operands]) - out_flat = pallas_call_p.bind( - *dynamic_grid_bounds, - *index_args, - *rest_args, - jaxpr=jaxpr, - name_and_src_info=name_and_src_info, - debug=debug, - interpret=interpret, - grid_mapping=grid_mapping, - input_output_aliases=tuple(input_output_aliases.items()), - compiler_params=compiler_params, - cost_estimate=cost_estimate, - ) + with pallas_core.interpret_mode_env(interpret): + out_flat = pallas_call_p.bind( + *dynamic_grid_bounds, + *index_args, + *rest_args, + out_avals=flat_out_avals, + jaxpr=jaxpr, + name_and_src_info=name_and_src_info, + debug=debug, + interpret=interpret, + grid_mapping=grid_mapping, + input_output_aliases=tuple(input_output_aliases.items()), + compiler_params=compiler_params, + cost_estimate=cost_estimate, + ) out = tree_util.tree_unflatten(out_tree, out_flat) return out return wrapped -def in_path_to_input_origin(in_path: tree_util.KeyPath, - arg_names: tuple[str, ...] | None) -> pallas_core.OriginStr: +def in_path_to_input_origin( + in_path: tree_util.KeyPath, arg_names: tuple[str, ...] 
| None +) -> pallas_core.OriginStr: """Converts `args[k]` into `arg_k_name`.""" if arg_names is None: return f"args{tree_util.keystr(in_path)}" if len(in_path) == 0: return "args" arg_idx, *rest_path = in_path - if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len(arg_names): + if isinstance(arg_idx, tree_util.SequenceKey) and arg_idx.idx < len( + arg_names + ): return arg_names[arg_idx.idx] + tree_util.keystr(tuple(rest_path)) else: return f"args{tree_util.keystr(tuple(in_path))}" diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 7ba5fa27791f..40caae76bd8f 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -59,6 +59,8 @@ def program_id(axis: int) -> jax.Array: grid coordinates `(1, 2)`, `program_id(axis=0)` returns `1` and `program_id(axis=1)` returns `2`. + The returned value is an array of shape `()` and dtype `int32`. + Args: axis: the axis of the grid along which to count the program. """ @@ -177,8 +179,10 @@ def _atomic_abstract_eval(*avals_flat, args_tree, atomic_type: AtomicOpType): def _atomic_rmw(x_ref_or_view, idx, val, *, mask: Any | None = None, atomic_type: AtomicOpType): - x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "atomic_rmw") - args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask)) + x_ref, transforms = sp.get_ref_and_transforms( + x_ref_or_view, idx, "atomic_rmw" + ) + args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask)) return atomic_rmw_p.bind( *args_flat, args_tree=args_tree, atomic_type=atomic_type ) @@ -379,7 +383,7 @@ def _load_pp_rule(eqn, context, settings): result = [ lhs, pp.text(' <- '), - sp.pp_ref_indexers(context, x, indexers) + sp.pp_ref_transforms(context, x, indexers) ] if mask is not None: result += [ @@ -421,10 +425,17 @@ def _load_jvp(primals, tangents, args_tree, **params): def uninitialized_value(shape, dtype): if jnp.issubdtype(dtype, jnp.floating): return jnp.full(shape, jnp.nan, dtype) + # Note: Currently semaphore is i16[], meaning this case needs to be + # handled before the general case for integers. + # TODO(justinfu): Handle semaphores with a custom extended dtype. + elif jnp.issubdtype(dtype, pallas_core.SEMAPHORE_INTERPRET_DTYPE): + return jnp.full(shape, 0, dtype) elif jnp.issubdtype(dtype, jnp.integer): return jnp.full(shape, jnp.iinfo(dtype).min, dtype) elif jnp.issubdtype(dtype, jnp.bool): return jnp.full(shape, False, dtype) + elif jnp.issubdtype(dtype, pallas_core.semaphore_dtype): + return jnp.full(shape, 0, dtype) raise NotImplementedError(dtype) def _pad_values_to_avoid_dynamic_slice_oob_shift(value, @@ -464,7 +475,7 @@ def _load_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): raise NotImplementedError("Only one indexer supported in discharge rule.") idx = indexers[0] if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): - # TODO(b/329733289): support strided load/store in interpret mode. + # TODO(ayx): support strided load/store in interpret mode. 
for s in idx.indices: if isinstance(s, Slice) and s.stride > 1: raise NotImplementedError("Unimplemented stride support.") @@ -522,7 +533,7 @@ def _swap_pp_rule(eqn, context, settings): # Pretty prints `_ = swap x v i` as `x[i] <- v` y, = eqn.outvars x, indexers, val, mask = eqn.params["args_tree"].unflatten(eqn.invars) - x_i = sp.pp_ref_indexers(context, x, indexers) + x_i = sp.pp_ref_transforms(context, x, indexers) if isinstance(y, jax_core.DropVar): return pp.concat([ x_i, @@ -576,7 +587,7 @@ def _swap_discharge_rule(in_avals, out_avals, *args_flat, args_tree, **_): raise NotImplementedError("Only one indexer supported in discharge rule.") idx = indexers[0] if all((isinstance(s, Slice) or not s.shape) for s in idx.indices): - # TODO(b/329733289): support strided load/store in interpret mode. + # TODO(ayx): support strided load/store in interpret mode. for s in idx.indices: if isinstance(s, Slice) and s.stride > 1: raise NotImplementedError("Unimplemented stride support.") @@ -631,8 +642,10 @@ def load(x_ref_or_view, idx, *, mask=None, other=None, cache_modifier=None, eviction_policy: TO BE DOCUMENTED. volatile: TO BE DOCUMENTED. """ - x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, "load") - args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, mask, other)) + x_ref, transforms = sp.get_ref_and_transforms(x_ref_or_view, idx, "load") + args_flat, args_tree = tree_util.tree_flatten( + (x_ref, transforms, mask, other) + ) return load_p.bind( *args_flat, args_tree=args_tree, @@ -650,8 +663,10 @@ def swap(x_ref_or_view, idx, val, *, mask=None, eviction_policy=None, Returns: The value stored in the ref prior to the swap. """ - x_ref, indexers = sp.get_ref_and_indexers(x_ref_or_view, idx, _function_name) - args_flat, args_tree = tree_util.tree_flatten((x_ref, indexers, val, mask)) + x_ref, transforms = sp.get_ref_and_transforms( + x_ref_or_view, idx, _function_name + ) + args_flat, args_tree = tree_util.tree_flatten((x_ref, transforms, val, mask)) return swap_p.bind( *args_flat, args_tree=args_tree, eviction_policy=eviction_policy ) @@ -700,8 +715,8 @@ class PrintEffect(effects.Effect): debug_print_p.multiple_results = True -def debug_print(fmt: str, *args: jax.ArrayLike): - """Prints scalar values from inside a Pallas kernel. +def debug_print(fmt: str, *args: jax.typing.ArrayLike): + """Prints values from inside a Pallas kernel. Args: fmt: A format string to be included in the output. The restrictions on the @@ -711,11 +726,11 @@ def debug_print(fmt: str, *args: jax.ArrayLike): (``{...}``), since it is always printed before any of the values. * On GPU, when using the experimental Mosaic GPU backend, ``fmt`` must contain a placeholder for each value to be printed. Format specs and - conversions are not supported. + conversions are not supported. All values must be scalars. * In TPU, if ``fmt`` contains placeholders, all values must be 32-bit integers. If there are no placeholders, the values are printed after - the format string. - *args: The scalar values to print. + the format string. All values must be scalars. + *args: The values to print. 
""" # fmt: skip has_placeholders = False if fmt: @@ -725,7 +740,7 @@ def debug_print(fmt: str, *args: jax.ArrayLike): def check_debug_print_format( - fmt: str, *args: jax.ArrayLike + fmt: str, *args: jax.typing.ArrayLike ): n_placeholders = 0 for _, field, spec, conversion in string.Formatter().parse(fmt): @@ -758,9 +773,7 @@ def debug_print_impl(*args: Any, fmt: str, has_placeholders: bool): @debug_print_p.def_effectful_abstract_eval def debug_print_abstract_eval(*avals: Any, fmt: str, has_placeholders: bool): - del fmt, has_placeholders - if any(aval.shape for aval in avals): - raise ValueError("Only scalar values are supported") + del avals, fmt, has_placeholders # Unused. return [], {debug_print_effect} @@ -817,7 +830,7 @@ def run_scoped(f: Callable[..., Any], *types, **kw_types) -> Any: flat_types, in_tree = tree_util.tree_flatten((types, kw_types)) flat_fun, out_tree_thunk = api_util.flatten_fun(lu.wrap_init(f), in_tree) - avals = [t.get_aval() for t in flat_types] + avals = [t.get_ref_aval() for t in flat_types] # Turn the function into a jaxpr. The body of run_scoped may have # effects (IO) on constvars (i.e. variables inherited from the # parent scope). Jax can't reason about effects to references that @@ -843,3 +856,54 @@ def _run_scoped_abstract_eval(*args, jaxpr): ) } return [v.aval for v in jaxpr.outvars], nonlocal_effects + + +def _run_scoped_discharge_rule(in_avals, + out_avals, + *args_flat, + jaxpr, + **_): + del out_avals + num_consts = len(args_flat) + jaxpr_noconst = pe.convert_constvars_jaxpr(jaxpr) + num_return_values = len(jaxpr_noconst.outvars) + discharged_body, new_consts = state_discharge.discharge_state( + jaxpr_noconst, []) + if new_consts: + raise NotImplementedError( + "Cannot handle new consts created by state discharge.") + # Create inputs filled with uninitialized values to the body. + body_avals = [v.aval for v in discharged_body.invars[num_consts:]] + init_vals = [uninitialized_value( + aval.shape, aval.dtype) for aval in body_avals] + init_vals_with_consts = args_flat + tuple(init_vals) + out = jax_core.eval_jaxpr(discharged_body, [], *init_vals_with_consts) + # Order of outputs: + # (1) return values, (2) closed refs, (3) scoped refs. + return_values = out[:num_return_values] + ref_outputs = out[num_return_values:] + # We update all ref values with their updated values from the discharged + # body. For other values we leave them in place. + updates = [ + ref_outputs.pop(0) if isinstance(aval, pallas_core.AbstractMemoryRef) + else None for aval in in_avals] + assert len(ref_outputs) == len( + body_avals), f'{len(body_avals)}, != {len(ref_outputs)}' + assert len(updates) == len(in_avals), f'{len(updates)} != {len(in_avals)}' + return updates, return_values + + +state_discharge.register_discharge_rule(run_scoped_p)( + _run_scoped_discharge_rule) + + +@functools.partial(mlir.register_lowering, run_scoped_p) +def _run_scoped_lowering_rule(ctx, *args, jaxpr): + # This lowering rule gets triggered when run_scoped is not discharged. + # In this case there are no stateful effects to handle. + def _lower_fun(*lower_fun_args): + updates, out = _run_scoped_discharge_rule([], [], *lower_fun_args, + jaxpr=jaxpr) + assert len(updates) == 0, 'Cannot lower run_scoped with effects.' 
+ return out + return mlir.lower_fun(_lower_fun, multiple_results=True)(ctx, *args) diff --git a/jax/_src/pallas/triton/BUILD b/jax/_src/pallas/triton/BUILD index 01d2480983d5..a9babcba0577 100644 --- a/jax/_src/pallas/triton/BUILD +++ b/jax/_src/pallas/triton/BUILD @@ -23,10 +23,16 @@ load( package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) +pytype_strict_library( + name = "core", + srcs = ["core.py"], + deps = ["//jax/_src/pallas"], +) + pytype_strict_library( name = "primitives", srcs = ["primitives.py"], diff --git a/jax/_src/pallas/triton/core.py b/jax/_src/pallas/triton/core.py new file mode 100644 index 000000000000..a61dfd61b9b1 --- /dev/null +++ b/jax/_src/pallas/triton/core.py @@ -0,0 +1,38 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains Triton-specific Pallas abstractions.""" +from __future__ import annotations + +import dataclasses +from typing import ClassVar + +from jax._src.pallas import core as pallas_core + +@dataclasses.dataclass(frozen=True) +class TritonCompilerParams(pallas_core.CompilerParams): + """Compiler parameters for Triton. + + Attributes: + num_warps: The number of warps to use for the kernel. Each warp consists of + 32 threads. + num_stages: The number of stages the compiler should use for software + pipelining loops. + serialized_metadata: Additional compiler metadata. This field is unstable + and may be removed in the future. + """ + PLATFORM: ClassVar[str] = "triton" + num_warps: int | None = None + num_stages: int | None = None + serialized_metadata: str | None = None diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 852ac714d3c9..9db5e4081239 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -121,6 +121,12 @@ def _eval_index_map( _ensure_ir_value(i, jax_core.ShapedArray((), jnp.int32)) for i in block_indices ) + if isinstance(block_mapping.indexing_mode, pallas_core.Unblocked): + if block_mapping.indexing_mode.padding is not None: + raise NotImplementedError( + "Unblocked indexing with padding is not supported in Triton lowering." 
+ ) + return tuple(block_indices) return tuple( i if b is pallas_core.mapped else _mul(i, _ir_constant(b, i.type)) for i, b in zip(block_indices, block_mapping.block_shape) @@ -214,19 +220,18 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): if len(collapse_dims) == 0: prog_ids = [None] * len(prog_id_dims) for i in range(len(prog_id_dims)): - out_idx = launch_grid_to_pallas_grid[i] - prog_ids[out_idx] = _program_id(i) + prog_ids[launch_grid_to_pallas_grid[i]] = _program_id(i, prog_id_dims) return prog_id_dims, prog_ids - else: - new_grid = [math.prod(collapse_dims), *prog_id_dims] + + new_grid = [math.prod(collapse_dims), *prog_id_dims] assert new_grid[0] < 2**31 - 1, \ "Cannot fix pallas kernel launch grid within CUDA limits" out_indices = [None] * len(grid_mapping.grid) - grid0 = _program_id(0) + grid0 = _program_id(0, new_grid) for i, s in enumerate(collapse_dims): out_idx = launch_grid_to_pallas_grid[i] s = _i32_constant(s) @@ -235,7 +240,7 @@ def _process_grid_to_3d_grid(grid_mapping: GridMapping): for i in range(len(prog_id_dims)): out_idx = launch_grid_to_pallas_grid[num_collapse + i] - out_indices[out_idx] = _program_id(i + 1) + out_indices[out_idx] = _program_id(i + 1, new_grid) assert len(out_indices) == len(grid_mapping.grid) return new_grid, out_indices @@ -277,6 +282,10 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "scalar prefetch not implemented in the Triton backend" ) + if jaxpr.invars[grid_mapping.slice_scratch_ops]: + raise NotImplementedError( + "scratch memory not implemented in the Triton backend" + ) with grid_mapping.trace_env(): jaxpr, _ = pe.dce_jaxpr( jaxpr, [True] * len(jaxpr.outvars), instantiate=True @@ -320,11 +329,6 @@ def lower_jaxpr_to_triton_module( raise NotImplementedError( "Scalar prefetch not supported in Triton lowering." ) - if not all(isinstance(bm.indexing_mode, Blocked) - for bm in grid_mapping.block_mappings): - raise NotImplementedError( - "Only Blocked indexing mode is supported in Triton lowering." - ) start_indices = map( functools.partial(_eval_index_map, ctx, program_ids), grid_mapping.block_mappings, @@ -377,14 +381,12 @@ def write_env(var: jax_core.Var, val): raise NotImplementedError( "Unimplemented primitive in Pallas GPU lowering: " f"{eqn.primitive.name}. 
" - "Please file an issue on https://github.com/google/jax/issues.") + "Please file an issue on https://github.com/jax-ml/jax/issues.") rule = triton_lowering_rules[eqn.primitive] avals_in = [v.aval for v in eqn.invars] avals_out = [v.aval for v in eqn.outvars] eqn_block_infos = map(read_block_info_env, eqn.invars) - loc = mlir._source_info_to_location( - ctx, eqn.primitive, eqn.params, eqn.source_info - ) + loc = mlir._source_info_to_location(ctx, eqn.primitive, eqn.source_info) rule_ctx = LoweringRuleContext(ctx, avals_in, avals_out, eqn_block_infos) try: with source_info_util.user_context(eqn.source_info.traceback), loc: @@ -406,13 +408,30 @@ def write_env(var: jax_core.Var, val): return map(read_env, jaxpr.outvars) +def lower_fun( + fun: Callable[..., Any], *, multiple_results: bool +) -> Callable[..., Any]: + fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) + + def f_lowered(ctx: LoweringRuleContext, *args, **params): + wrapped_fun = lu.wrap_init(fn, params) + jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) + jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) + out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) + return out if multiple_results else out[0] + + return f_lowered + + # # Primitive lowering rules # ## Programming model primitives -def _program_id(axis: int) -> ir.Value: +def _program_id(axis: int, launch_grid: Sequence[int]) -> ir.Value: if axis not in range(3): raise ValueError(f"axis must be in [0, 3), but got: {axis}") + if launch_grid[axis] == 1: + return _i32_constant(0) return tt_dialect.get_program_id(axis) @@ -461,7 +480,7 @@ def _atomic_lowering_rule( raise NotImplementedError("Only single indexer is supported.") idx = indexers[0] ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape + ptr, ctx.block_infos[0], idx, ctx.avals_in[0] ) val = _ensure_ir_value(val, value_aval) if mask is not None: @@ -558,7 +577,7 @@ def _not_lowering_rule(ctx: LoweringRuleContext, x): @dataclasses.dataclass(frozen=True) class _Extern: - arg_types: Sequence[str] + arg_types: Sequence[jax.typing.DTypeLike] symbol: str result_type: str @@ -566,7 +585,8 @@ def matches(self, avals: Sequence[jax_core.ShapedArray]) -> bool: if len(avals) != len(self.arg_types): return False return all( - aval.weak_type or aval.dtype.name == arg_type + aval.dtype == jnp.dtype(arg_type) + or (aval.weak_type and aval.dtype.kind == jnp.dtype(arg_type).kind) for aval, arg_type in zip(avals, self.arg_types) ) @@ -587,7 +607,7 @@ def lower(self, ctx: LoweringRuleContext, *args: Sequence[ir.Value]): @dataclasses.dataclass(frozen=True) class _Fallback: - arg_types: Sequence[str] + arg_types: Sequence[jax.typing.DTypeLike] lower: Callable[..., ir.Value] matches = _Extern.matches @@ -601,7 +621,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: table = tables[ctx.context.platform] h = next((e for e in table if e.matches(ctx.avals_in)), None) if h is None: - arg_aval_dtypes = tuple(aval.dtype.name for aval in ctx.avals_in) + arg_aval_dtypes = tuple(aval.dtype for aval in ctx.avals_in) raise NotImplementedError( f"unsupported types for {name}: {arg_aval_dtypes}" ) @@ -610,7 +630,7 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: bcast_args = [] for aval, arg, arg_type in zip(ctx.avals_in, args, h.arg_types): bcast_arg = _bcast_to(_ensure_ir_value(arg, aval), out_aval.shape) - if aval.weak_type and aval.dtype.name != arg_type: + if aval.weak_type and aval.dtype != jnp.dtype(arg_type): 
bcast_arg = _cast(bcast_arg, aval.dtype, jnp.dtype(arg_type)) bcast_args.append(bcast_arg) return h.lower(ctx, *bcast_args) @@ -621,16 +641,16 @@ def inner(ctx: LoweringRuleContext, *args: ir.Value) -> ir.Value: _abs_dispatch_table = _make_dispatch_table( "abs", cuda=[ - _Extern(["int32"], "__nv_abs", "int32"), - _Extern(["int64"], "__nv_llabs", "int64"), - _Extern(["float32"], "__nv_fabsf", "float32"), - _Extern(["float64"], "__nv_fabs", "float64"), + _Extern([jnp.int32], "__nv_abs", jnp.int32), + _Extern([jnp.int64], "__nv_llabs", jnp.int64), + _Extern([jnp.float32], "__nv_fabsf", jnp.float32), + _Extern([jnp.float64], "__nv_fabs", jnp.float64), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.absi(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.absi(x)), - _Fallback(["float32"], lambda ctx, x: math_dialect.absf(x)), - _Fallback(["float64"], lambda ctx, x: math_dialect.absf(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.absi(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.absi(x)), + _Fallback([jnp.float32], lambda ctx, x: math_dialect.absf(x)), + _Fallback([jnp.float64], lambda ctx, x: math_dialect.absf(x)), ], ) @@ -654,330 +674,338 @@ def _abs_lowering_rule(ctx: LoweringRuleContext, x): lax.ceil_p: _make_dispatch_table( "ceil", cuda=[ - _Extern(["float32"], "__nv_ceilf", "float32"), - _Extern(["float64"], "__nv_ceil", "float64"), + _Extern([jnp.float32], "__nv_ceilf", jnp.float32), + _Extern([jnp.float64], "__nv_ceil", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_ceil_f32", "float32"), - _Extern(["float64"], "__ocml_ceil_f64", "float64"), + _Extern([jnp.float32], "__ocml_ceil_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_ceil_f64", jnp.float64), ], ), lax.floor_p: _make_dispatch_table( "floor", cuda=[ - _Extern(["float32"], "__nv_floorf", "float32"), - _Extern(["float64"], "__nv_floor", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + _Extern([jnp.float32], "__nv_floorf", jnp.float32), + _Extern([jnp.float64], "__nv_floor", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)), ], rocm=[ - _Extern(["float32"], "__ocml_floor_f32", "float32"), - _Extern(["float64"], "__ocml_floor_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.floor(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.floor(x)), + _Extern([jnp.float32], "__ocml_floor_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_floor_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.floor(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.floor(x)), ], ), lax.exp_p: _make_dispatch_table( "exp", cuda=[ - _Extern(["float32"], "__nv_expf", "float32"), - _Extern(["float64"], "__nv_exp", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + _Extern([jnp.float32], "__nv_expf", jnp.float32), + _Extern([jnp.float64], "__nv_exp", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)), ], rocm=[ - _Fallback(["float32"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["float64"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float32], 
lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float64], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp(x)), ], ), lax.exp2_p: _make_dispatch_table( "exp2", cuda=[ - _Extern(["float32"], "__nv_exp2f", "float32"), - _Extern(["float64"], "__nv_exp2", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + _Extern([jnp.float32], "__nv_exp2f", jnp.float32), + _Extern([jnp.float64], "__nv_exp2", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)), ], rocm=[ - _Extern(["float32"], "__ocml_exp2_f32", "float32"), - _Extern(["float64"], "__ocml_exp2_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.exp2(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.exp2(x)), + _Extern([jnp.float32], "__ocml_exp2_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_exp2_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.exp2(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.exp2(x)), ], ), lax.expm1_p: _make_dispatch_table( "expm1", cuda=[ - _Extern(["float32"], "__nv_expm1f", "float32"), - _Extern(["float64"], "__nv_expm1", "float64"), + _Extern([jnp.float32], "__nv_expm1f", jnp.float32), + _Extern([jnp.float64], "__nv_expm1", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_expm1_f32", "float32"), - _Extern(["float64"], "__ocml_expm1_f64", "float64"), + _Extern([jnp.float32], "__ocml_expm1_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_expm1_f64", jnp.float64), ], ), lax.log_p: _make_dispatch_table( "log", cuda=[ - _Extern(["float32"], "__nv_logf", "float32"), - _Extern(["float64"], "__nv_log", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + _Extern([jnp.float32], "__nv_logf", jnp.float32), + _Extern([jnp.float64], "__nv_log", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)), ], rocm=[ - _Extern(["float32"], "__ocml_log_f32", "float32"), - _Extern(["float64"], "__ocml_log_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.log(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.log(x)), + _Extern([jnp.float32], "__ocml_log_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_log_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.log(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.log(x)), ], ), lax.log1p_p: _make_dispatch_table( "log1p", cuda=[ - _Extern(["float32"], "__nv_log1pf", "float32"), - _Extern(["float64"], "__nv_log1p", "float64"), + _Extern([jnp.float32], "__nv_log1pf", jnp.float32), + _Extern([jnp.float64], "__nv_log1p", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_log1p_f32", "float32"), - _Extern(["float64"], "__ocml_log1p_f64", "float64"), + _Extern([jnp.float32], "__ocml_log1p_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_log1p_f64", jnp.float64), ], ), lax.sqrt_p: _make_dispatch_table( "sqrt", cuda=[ - _Extern(["float32"], "__nv_sqrtf", "float32"), - _Extern(["float64"], "__nv_sqrt", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + _Extern([jnp.float32], "__nv_sqrtf", 
jnp.float32), + _Extern([jnp.float64], "__nv_sqrt", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], rocm=[ - _Extern(["float32"], "__ocml_sqrt_f32", "float32"), - _Extern(["float64"], "__ocml_sqrt_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sqrt(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sqrt(x)), + _Extern([jnp.float32], "__ocml_sqrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sqrt_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sqrt(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sqrt(x)), ], ), lax.pow_p: _make_dispatch_table( "pow", cuda=[ - _Extern(["float32", "int32"], "__nv_powif", "float32"), - _Extern(["float64", "int32"], "__nv_powi", "float64"), - _Extern(["float32", "float32"], "__nv_powf", "float32"), - _Extern(["float64", "float64"], "__nv_pow", "float64"), + _Extern([jnp.float32, jnp.int32], "__nv_powif", jnp.float32), + _Extern([jnp.float64, jnp.int32], "__nv_powi", jnp.float64), + _Extern([jnp.float32, jnp.float32], "__nv_powf", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__nv_pow", jnp.float64), ], rocm=[ - _Extern(["float32", "int32"], "__ocml_pown_f32", "float32"), - _Extern(["float64", "int32"], "__ocml_pown_f64", "float64"), - _Extern(["float32", "float32"], "__ocml_pow_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_pow_f64", "float64"), + _Extern([jnp.float32, jnp.int32], "__ocml_pown_f32", jnp.float32), + _Extern([jnp.float64, jnp.int32], "__ocml_pown_f64", jnp.float64), + _Extern([jnp.float32, jnp.float32], "__ocml_pow_f32", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__ocml_pow_f64", jnp.float64), ], ), lax.cbrt_p: _make_dispatch_table( "cbrt", cuda=[ - _Extern(["float32"], "__nv_cbrtf", "float32"), - _Extern(["float64"], "__nv_cbrt", "float64"), + _Extern([jnp.float32], "__nv_cbrtf", jnp.float32), + _Extern([jnp.float64], "__nv_cbrt", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_cbrt_f32", "float32"), - _Extern(["float64"], "__ocml_cbrt_f64", "float64"), + _Extern([jnp.float32], "__ocml_cbrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cbrt_f64", jnp.float64), ], ), lax.rsqrt_p: _make_dispatch_table( "rsqrt", cuda=[ - _Extern(["float32"], "__nv_rsqrtf", "float32"), - _Extern(["float64"], "__nv_rsqrt", "float64"), + _Extern([jnp.float32], "__nv_rsqrtf", jnp.float32), + _Extern([jnp.float64], "__nv_rsqrt", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_rsqrt_f32", "float32"), - _Extern(["float64"], "__ocml_rsqrt_f64", "float64"), + _Extern([jnp.float32], "__ocml_rsqrt_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_rsqrt_f64", jnp.float64), ], ), lax.sin_p: _make_dispatch_table( "sin", cuda=[ - _Extern(["float32"], "__nv_sinf", "float32"), - _Extern(["float64"], "__nv_sin", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + _Extern([jnp.float32], "__nv_sinf", jnp.float32), + _Extern([jnp.float64], "__nv_sin", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)), ], rocm=[ - _Extern(["float32"], "__ocml_sin_f32", "float32"), - _Extern(["float64"], "__ocml_sin_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.sin(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.sin(x)), + _Extern([jnp.float32], "__ocml_sin_f32", 
jnp.float32), + _Extern([jnp.float64], "__ocml_sin_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.sin(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.sin(x)), ], ), lax.cos_p: _make_dispatch_table( "cos", cuda=[ - _Extern(["float32"], "__nv_cosf", "float32"), - _Extern(["float64"], "__nv_cos", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + _Extern([jnp.float32], "__nv_cosf", jnp.float32), + _Extern([jnp.float64], "__nv_cos", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)), ], rocm=[ - _Extern(["float32"], "__ocml_cos_f32", "float32"), - _Extern(["float64"], "__ocml_cos_f64", "float64"), - _Fallback(["float16"], lambda ctx, x: math_dialect.cos(x)), - _Fallback(["bfloat16"], lambda ctx, x: math_dialect.cos(x)), + _Extern([jnp.float32], "__ocml_cos_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cos_f64", jnp.float64), + _Fallback([jnp.float16], lambda ctx, x: math_dialect.cos(x)), + _Fallback([jnp.bfloat16], lambda ctx, x: math_dialect.cos(x)), ], ), lax.tan_p: _make_dispatch_table( "tan", cuda=[ - _Extern(["float32"], "__nv_tanf", "float32"), - _Extern(["float64"], "__nv_tan", "float64"), + _Extern([jnp.float32], "__nv_tanf", jnp.float32), + _Extern([jnp.float64], "__nv_tan", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_tan_f32", "float32"), - _Extern(["float64"], "__ocml_tan_f64", "float64"), + _Extern([jnp.float32], "__ocml_tan_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_tan_f64", jnp.float64), ], ), lax.asin_p: _make_dispatch_table( "asin", cuda=[ - _Extern(["float32"], "__nv_asinf", "float32"), - _Extern(["float64"], "__nv_asin", "float64"), + _Extern([jnp.float32], "__nv_asinf", jnp.float32), + _Extern([jnp.float64], "__nv_asin", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_asin_f32", "float32"), - _Extern(["float64"], "__ocml_asin_f64", "float64"), + _Extern([jnp.float32], "__ocml_asin_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_asin_f64", jnp.float64), ], ), lax.acos_p: _make_dispatch_table( "acos", cuda=[ - _Extern(["float32"], "__nv_acosf", "float32"), - _Extern(["float64"], "__nv_acos", "float64"), + _Extern([jnp.float32], "__nv_acosf", jnp.float32), + _Extern([jnp.float64], "__nv_acos", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_acos_f32", "float32"), - _Extern(["float64"], "__ocml_acos_f64", "float64"), + _Extern([jnp.float32], "__ocml_acos_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_acos_f64", jnp.float64), ], ), lax.atan_p: _make_dispatch_table( "atan", cuda=[ - _Extern(["float32"], "__nv_atanf", "float32"), - _Extern(["float64"], "__nv_atan", "float64"), + _Extern([jnp.float32], "__nv_atanf", jnp.float32), + _Extern([jnp.float64], "__nv_atan", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_atan_f32", "float32"), - _Extern(["float64"], "__ocml_atan_f64", "float64"), + _Extern([jnp.float32], "__ocml_atan_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_atan_f64", jnp.float64), ], ), lax.atan2_p: _make_dispatch_table( "atan2", cuda=[ - _Extern(["float32", "float32"], "__nv_atan2f", "float32"), - _Extern(["float64", "float64"], "__nv_atan2", "float64"), + _Extern([jnp.float32, jnp.float32], "__nv_atan2f", jnp.float32), + _Extern([jnp.float64, jnp.float64], "__nv_atan2", jnp.float64), ], rocm=[ - _Extern(["float32", "float32"], "__ocml_atan2_f32", "float32"), - _Extern(["float64", 
"float64"], "__ocml_atan2_f64", "float64"), + _Extern( + [jnp.float32, jnp.float32], "__ocml_atan2_f32", jnp.float32 + ), + _Extern( + [jnp.float64, jnp.float64], "__ocml_atan2_f64", jnp.float64 + ), ], ), lax.sinh_p: _make_dispatch_table( "sinh", cuda=[ - _Extern(["float32"], "__nv_sinhf", "float32"), - _Extern(["float64"], "__nv_sinh", "float64"), + _Extern([jnp.float32], "__nv_sinhf", jnp.float32), + _Extern([jnp.float64], "__nv_sinh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_sinh_f32", "float32"), - _Extern(["float64"], "__ocml_sinh_f64", "float64"), + _Extern([jnp.float32], "__ocml_sinh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_sinh_f64", jnp.float64), ], ), lax.cosh_p: _make_dispatch_table( "cosh", cuda=[ - _Extern(["float32"], "__nv_coshf", "float32"), - _Extern(["float64"], "__nv_cosh", "float64"), + _Extern([jnp.float32], "__nv_coshf", jnp.float32), + _Extern([jnp.float64], "__nv_cosh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_cosh_f32", "float32"), - _Extern(["float64"], "__ocml_cosh_f64", "float64"), + _Extern([jnp.float32], "__ocml_cosh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_cosh_f64", jnp.float64), ], ), lax.tanh_p: _make_dispatch_table( "tanh", cuda=[ - _Extern(["float32"], "__nv_tanhf", "float32"), - _Extern(["float64"], "__nv_tanh", "float64"), + _Extern([jnp.float32], "__nv_tanhf", jnp.float32), + _Extern([jnp.float64], "__nv_tanh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_tanh_f32", "float32"), - _Extern(["float64"], "__ocml_tanh_f64", "float64"), + _Extern([jnp.float32], "__ocml_tanh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_tanh_f64", jnp.float64), ], ), lax.asinh_p: _make_dispatch_table( "asinh", cuda=[ - _Extern(["float32"], "__nv_asinhf", "float32"), - _Extern(["float64"], "__nv_asinh", "float64"), + _Extern([jnp.float32], "__nv_asinhf", jnp.float32), + _Extern([jnp.float64], "__nv_asinh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_asinh_f32", "float32"), - _Extern(["float64"], "__ocml_asinh_f64", "float64"), + _Extern([jnp.float32], "__ocml_asinh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_asinh_f64", jnp.float64), ], ), lax.acosh_p: _make_dispatch_table( "acosh", cuda=[ - _Extern(["float32"], "__nv_acoshf", "float32"), - _Extern(["float64"], "__nv_acosh", "float64"), + _Extern([jnp.float32], "__nv_acoshf", jnp.float32), + _Extern([jnp.float64], "__nv_acosh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_acosh_f32", "float32"), - _Extern(["float64"], "__ocml_acosh_f64", "float64"), + _Extern([jnp.float32], "__ocml_acosh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_acosh_f64", jnp.float64), ], ), lax.atanh_p: _make_dispatch_table( "atanh", cuda=[ - _Extern(["float32"], "__nv_atanhf", "float32"), - _Extern(["float64"], "__nv_atanh", "float64"), + _Extern([jnp.float32], "__nv_atanhf", jnp.float32), + _Extern([jnp.float64], "__nv_atanh", jnp.float64), ], rocm=[ - _Extern(["float32"], "__ocml_atanh_f32", "float32"), - _Extern(["float64"], "__ocml_atanh_f64", "float64"), + _Extern([jnp.float32], "__ocml_atanh_f32", jnp.float32), + _Extern([jnp.float64], "__ocml_atanh_f64", jnp.float64), ], ), lax.population_count_p: _make_dispatch_table( "population_count", cuda=[ - _Extern(["int32"], "__nv_popc", "int32"), - _Extern(["int64"], "__nv_popcll", "int32"), + _Extern([jnp.int32], "__nv_popc", jnp.int32), + _Extern([jnp.int64], "__nv_popcll", jnp.int32), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.ctpop(x)), - _Fallback(["int64"], lambda ctx, x: 
math_dialect.ctpop(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctpop(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctpop(x)), ], ), lax.clz_p: _make_dispatch_table( "clz", cuda=[ - _Extern(["int32"], "__nv_clz", "int32"), - _Extern(["int64"], "__nv_clzll", "int32"), + _Extern([jnp.int32], "__nv_clz", jnp.int32), + _Extern([jnp.int64], "__nv_clzll", jnp.int32), ], rocm=[ - _Fallback(["int32"], lambda ctx, x: math_dialect.ctlz(x)), - _Fallback(["int64"], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback([jnp.int32], lambda ctx, x: math_dialect.ctlz(x)), + _Fallback([jnp.int64], lambda ctx, x: math_dialect.ctlz(x)), ], ), lax.nextafter_p: _make_dispatch_table( "nextafter", cuda=[ - _Extern(["float32", "float32"], "__nv_nextafterf", "float32"), - _Extern(["float64", "float64"], "__nv_nextafter", "float64"), + _Extern([jnp.float32, jnp.float32], "__nv_nextafterf", jnp.float32 ), + _Extern([jnp.float64, jnp.float64], "__nv_nextafter", jnp.float64), ], rocm=[ - _Extern(["float32", "float32"], "__ocml_nextafter_f32", "float32"), - _Extern(["float64", "float64"], "__ocml_nextafter_f64", "float64"), + _Extern( + [jnp.float32, jnp.float32], "__ocml_nextafter_f32", jnp.float32 + ), + _Extern( + [jnp.float64, jnp.float64], "__ocml_nextafter_f64", jnp.float64 + ), ], ), }) @@ -992,21 +1020,19 @@ def _minus(x: ir.Value) -> ir.Value: def _add(x: ir.Value, y: ir.Value): x_element_type = _element_type(x.type) y_element_type = _element_type(y.type) - if tt_dialect.PointerType.isinstance(y_element_type): - assert not tt_dialect.PointerType.isinstance(x_element_type) - x, y = y, x - x_element_type, y_element_type = y_element_type, x_element_type if tt_dialect.PointerType.isinstance(x_element_type): + assert not tt_dialect.PointerType.isinstance(y_element_type) return tt_dialect.addptr(x.type, x, y) + if tt_dialect.PointerType.isinstance(y_element_type): + return tt_dialect.addptr(y.type, y, x) assert x.type == y.type, (str(x.type), str(y.type)) if isinstance(x_element_type, ir.IntegerType): return arith_dialect.addi(x, y) - elif isinstance(x_element_type, ir.FloatType): + if isinstance(x_element_type, ir.FloatType): return arith_dialect.addf(x, y) - else: - raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") + raise NotImplementedError(f"unsupported dtypes: {x.type} and {y.type}") def _sub(x: ir.Value, y: ir.Value) -> ir.Value: @@ -1182,7 +1208,14 @@ def debug_print_lowering_rule( "pl.debug_print() does not support placeholders when lowering to Triton" ) - tt_dialect.print_(f" {fmt} ", hex=False, args=args) + tt_dialect.print_( + f" {fmt} ", + hex=False, + args=args, + is_signed=ir.DenseI32ArrayAttr.get([ + jnp.issubdtype(aval.dtype, jnp.signedinteger) for aval in ctx.avals_in + ]), + ) return () @@ -1259,21 +1292,6 @@ def _integer_pow_rule(ctx: LoweringRuleContext, x, *, y: int): return acc -def lower_fun( - fun: Callable[..., Any], *, multiple_results: bool -) -> Callable[..., Any]: - fn = fun if multiple_results else lambda *args, **kw: (fun(*args, **kw),) - - def f_lowered(ctx: LoweringRuleContext, *args, **params): - wrapped_fun = lu.wrap_init(fn, params) - jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(wrapped_fun, ctx.avals_in) - jaxpr = jax_core.ClosedJaxpr(jaxpr, consts) - out = _closed_call_lowering_rule(ctx, *args, call_jaxpr=jaxpr) - return out if multiple_results else out[0] - - return f_lowered - - _JAX_FN_MAPPING = { lax.clamp_p: lambda min, a, max: jnp.minimum(jnp.maximum(min, a), max), lax.logistic_p: lambda a: 1 / (1 + jnp.exp(-a)), @@ -1333,15 
+1351,14 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y): return _floordiv(x, y, signed=signed) -@register_lowering(lax.sign_p) -def _sign_lowering_rule(ctx: LoweringRuleContext, x): - [x_aval] = ctx.avals_in - signed = jnp.issubdtype(x_aval.dtype, jnp.signedinteger) - zero = _full(x.type, 0) - return _sub( - _cast(_greater_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype), - _cast(_less_than(x, zero, signed=signed), jnp.bool_, x_aval.dtype), - ) +register_lowering(lax.sign_p)( + lower_fun(pallas_utils.sign_lowering_helper, multiple_results=False) +) + + +register_lowering(lax.erf_inv_p)( + lower_fun(pallas_utils.erf_inv_lowering_helper, multiple_results=False) +) @register_lowering(lax.iota_p) @@ -1648,12 +1665,12 @@ def _compute_pointers_from_indices( root_ptr: ir.Value, block_info: BlockInfo | None, nd_indexer: NDIndexer, - array_shape: tuple[int, ...], + array_shape_dtype: Any, ) -> ir.Value: if block_info is None: # TODO(necula): is this branch dead? - full_shape = array_shape + full_shape = array_shape_dtype.shape num_mapped_dims = 0 - block_shape = array_shape + block_shape = array_shape_dtype.shape else: full_shape = block_info.full_shape_dtype.shape num_mapped_dims = sum( @@ -1666,7 +1683,6 @@ def _compute_pointers_from_indices( _check_tensor_size(indexer_shape) indices = nd_indexer.indices other_shape = indexer_shape[len(int_indexer_shape) :] - bcast_indices = [] other_shape_idx = 0 if block_info is None: start_index_offsets = [None] * len(indices) @@ -1674,82 +1690,73 @@ def _compute_pointers_from_indices( start_index_offsets = block_info.start_indices assert len(indices) + num_mapped_dims == len(full_shape) assert len(start_index_offsets) == len(full_shape) + + array_dtype = jnp.dtype(array_shape_dtype.dtype) + full_size = math.prod(full_shape) * array_dtype.itemsize + # Use 64-bit indexing when offset might be >= 2**32 bytes. + offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32) + if indexer_shape: + offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0) + else: + offsets = _ir_constant(0, offset_eltype) + indexer_iter = iter(indices) for dim_stride, dim_block_size, start_offset in zip( strides, block_shape, start_index_offsets ): if dim_block_size is pallas_core.mapped: - index = _i32_constant(0) + index = _ir_constant(0, offset_eltype) else: index = next(indexer_iter) + + if isinstance(index, slice): + index = primitives.Slice.from_slice(index, dim_block_size) + if isinstance(index, primitives.Slice): - if index.is_dynamic_start: - # Compute the offset as start + range(0, size). - ptr_dim_offset = _add( - _bcast_to(index.start, [index.size]), - _ir_cast(_make_range(0, index.size), index.start.type, signed=False), - ) - elif index.stride > 1: - # Compute the offset as start + range(0, size) * stride. 
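To make the rewritten offset arithmetic in `_compute_pointers_from_indices` easier to follow, here is a minimal NumPy sketch of the same idea (an editor's illustration with a hypothetical `element_offsets` helper, not code from the patch): per-dimension offsets are built as (start + arange(size) * slice_stride) * dim_stride and accumulated by broadcasting, and the offset element type is widened to 64 bits only when the full array spans more than 2**32 bytes.

    import math
    import numpy as np

    def element_offsets(full_shape, itemsize, starts, sizes, slice_strides):
        # Mirror of the patch: 64-bit offsets only when the array exceeds 2**32 bytes.
        offset_dtype = np.int64 if math.prod(full_shape) * itemsize > 2**32 else np.int32
        # Row-major strides of the full array, in elements.
        dim_strides = np.cumprod((1,) + tuple(full_shape[:0:-1]))[::-1].astype(offset_dtype)
        offsets = np.zeros((), dtype=offset_dtype)
        for start, size, sstride, dstride in zip(starts, sizes, slice_strides, dim_strides):
            dim_off = (start + np.arange(size, dtype=offset_dtype) * sstride) * dstride
            offsets = offsets[..., None] + dim_off  # broadcast against earlier dimensions
        return offsets

    # Elements touched by x[1:3, 0:8:2] in a (4, 8) float32 array:
    print(element_offsets((4, 8), 4, starts=(1, 0), sizes=(2, 4), slice_strides=(1, 2)))
    # [[ 8 10 12 14]
    #  [16 18 20 22]]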
- iota = _make_range(0, index.size) - ptr_dim_offset = _add( - _bcast_to(_i32_constant(index.start), [index.size]), - _mul(iota, _full(iota.type, index.stride)), - ) + if index.is_dynamic_start or (index.stride != 1): + start = index.start + if not index.is_dynamic_start: + start = _ir_constant(start, offset_eltype) + start = _ir_cast(start, offset_eltype, signed=False) + + iota = _ir_cast(_make_range(0, index.size), offset_eltype, signed=False) + if index.stride != 1: + iota = _mul(iota, _full(iota.type, index.stride)) + dim_offsets = _add(_bcast_to(start, [index.size]), iota) else: - ptr_dim_offset = _make_range(index.start, index.start + index.size) + iota = _make_range(index.start, index.start + index.size) + dim_offsets = _ir_cast(iota, offset_eltype, signed=False) - # We need to add broadcastable dimensions for the advanced int indexing - # and for previous slices - num_left_expand_dims = len(int_indexer_shape) + other_shape_idx - num_right_expand_dims = len(other_shape) - other_shape_idx - 1 - other_shape_idx += 1 - elif isinstance(index, slice): - if index != slice(None): - raise NotImplementedError("Only `slice(None)` allowed.") - ptr_dim_offset = _make_range(0, dim_block_size) - num_left_expand_dims = len(int_indexer_shape) + other_shape_idx - num_right_expand_dims = len(other_shape) - other_shape_idx - 1 other_shape_idx += 1 + for _ in other_shape[other_shape_idx:]: + rank = ir.RankedTensorType(dim_offsets.type).rank + dim_offsets = _expand_dims(dim_offsets, rank) else: # indexer is either a *scalar* or an array of size `int_indexer_shape` - ptr_dim_offset = _ensure_ir_value( - index, jax_core.ShapedArray((), jnp.int32) - ) - num_left_expand_dims = 0 - num_right_expand_dims = len(other_shape) - if not ir.RankedTensorType.isinstance(ptr_dim_offset.type): - num_left_expand_dims = max(len(indexer_shape) - 1, 0) - else: - num_right_expand_dims = len(other_shape) + dim_offsets = index + if not isinstance(dim_offsets, ir.Value): + dim_offsets = _ir_constant(dim_offsets, offset_eltype) + dim_offsets = _ir_cast(dim_offsets, offset_eltype, signed=False) + + if ir.RankedTensorType.isinstance(dim_offsets.type): + for _ in other_shape: + rank = ir.RankedTensorType(dim_offsets.type).rank + dim_offsets = _expand_dims(dim_offsets, rank) + + if ir.RankedTensorType.isinstance(dim_offsets.type): + rank = ir.RankedTensorType(dim_offsets.type).rank + for _ in range(len(indexer_shape) - rank): + dim_offsets = _expand_dims(dim_offsets, 0) + dim_offsets = _bcast_to(dim_offsets, indexer_shape) - if indexer_shape and not ir.RankedTensorType.isinstance(ptr_dim_offset.type): - ptr_dim_offset = _splat(ptr_dim_offset, [1] * len(indexer_shape)) - else: - for _ in range(num_left_expand_dims): - ptr_dim_offset = _expand_dims(ptr_dim_offset, 0) - for _ in range(num_right_expand_dims): - ndim = len(getattr(ptr_dim_offset.type, "shape", [])) - ptr_dim_offset = _expand_dims(ptr_dim_offset, ndim) - - ptr_dim_offset = _bcast_to(ptr_dim_offset, indexer_shape) - index_type = ir.IntegerType(_element_type(ptr_dim_offset.type)) if start_offset is not None: - start_offset = _ir_cast(start_offset, index_type, signed=False) - ptr_dim_offset = _add( - ptr_dim_offset, _bcast_to(start_offset, indexer_shape) - ) + start_offset = _ir_cast(start_offset, offset_eltype, signed=False) + dim_offsets = _add(dim_offsets, _bcast_to(start_offset, indexer_shape)) - if index_type.width == 32: - stride_size = _i32_constant(dim_stride) - else: - stride_size = _i64_constant(dim_stride) - stride_size = _splat(stride_size, indexer_shape) - 
bcast_indices.append(_mul(ptr_dim_offset, stride_size)) + dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride)) + offsets = _add(offsets, dim_offsets) - return functools.reduce( - _add, bcast_indices, _bcast_to(root_ptr, indexer_shape) - ) + return _add(_bcast_to(root_ptr, indexer_shape), offsets) @register_lowering(sp.get_p) @@ -1864,7 +1871,7 @@ def _masked_load_lowering_rule( assert len(ctx.avals_in) == 1 return ptr ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape + ptr, ctx.block_infos[0], idx, ctx.avals_in[0] ) if mask is not None: mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape()) @@ -1961,7 +1968,7 @@ def _masked_swap_lowering_rule( raise NotImplementedError("No support for multiple indexers yet.") idx = indexers[0] ptr = _compute_pointers_from_indices( - ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape + ptr, ctx.block_infos[0], idx, ctx.avals_in[0] ) other = None if value is not None: @@ -1986,10 +1993,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree): raise NotImplementedError("No support for multiple indexers yet.") indexer = indexers[0] ptr = _compute_pointers_from_indices( - ptr, - ctx.block_infos[0], - indexer, - ctx.avals_in[0].shape, + ptr, ctx.block_infos[0], indexer, ctx.avals_in[0] ) op = tt_dialect.RMWOp.FADD if isinstance(_element_type(value.type), ir.IntegerType): @@ -2090,8 +2094,10 @@ def _dot_general_lowering( dimension_numbers, precision, preferred_element_type, + algorithm, + transpose_algorithm, ): - del preferred_element_type # Unused. + del preferred_element_type, algorithm, transpose_algorithm # Unused. ((a_contract_dim,), (b_contract_dim,)), batch_dims = dimension_numbers assert batch_dims == ((), ()) @@ -2198,7 +2204,7 @@ def _argreduce_lowering( if i != axis: index = _expand_dims(index, i) index = _bcast_to(index, a_aval.shape) - ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype("int32"))]) + ctx = ctx.replace(avals_in=[a_aval, a_aval.update(dtype=jnp.dtype(jnp.int32))]) _, indices = _reduction_lowering(body, ctx, (a, index), axes=axes) return indices @@ -2363,10 +2369,8 @@ def _lower_jaxpr_to_for_loop( else: jaxpr_args = [*consts, *for_body_args] all_out = lower_jaxpr_to_triton_ir( - ctx.context, - jaxpr, - ctx.block_infos, - *jaxpr_args) + ctx.context, jaxpr, ctx.block_infos, *jaxpr_args + ) scf_dialect.yield_(all_out) return list(for_op.results_) @@ -2403,11 +2407,9 @@ def _scan_lowering_rule( args = map(_ensure_ir_value, args, ctx.avals_in) consts, args = util.split_list(args, [num_consts]) if has_loop_index: - lb, *args = args - lower_bound = lb - ub = _add(lb, _ir_constant(length, lb.type)) - upper_bound = ub - bound_type = ub.type + lower_bound, *args = args + upper_bound = _add(lower_bound, _ir_constant(length, lower_bound.type)) + bound_type = lower_bound.type else: lower_bound = _i32_constant(0) upper_bound = _i32_constant(length) diff --git a/jax/_src/pallas/triton/pallas_call_registration.py b/jax/_src/pallas/triton/pallas_call_registration.py index b94adfb8fb3f..67b0bd326616 100644 --- a/jax/_src/pallas/triton/pallas_call_registration.py +++ b/jax/_src/pallas/triton/pallas_call_registration.py @@ -49,8 +49,9 @@ def pallas_call_lowering( grid_mapping: pallas_core.GridMapping, compiler_params: dict[str, Any], cost_estimate: pallas_core.CostEstimate | None, + out_avals: tuple[jax_core.AbstractValue, ...], ): - del interpret + del interpret, out_avals if grid_mapping.num_dynamic_grid_bounds: raise 
NotImplementedError( "dynamic grid bounds not supported in the Triton backend" @@ -61,11 +62,14 @@ def pallas_call_lowering( ) triton_params = compiler_params.get("triton", compiler_params) num_warps = triton_params.pop("num_warps", 4) + num_warps = 4 if num_warps is None else num_warps [lowering_platform] = ctx.platforms or ctx.module_context.platforms if lowering_platform == "rocm": num_stages = triton_params.pop("num_stages", 1) + num_stages = 1 if num_stages is None else num_stages else: num_stages = triton_params.pop("num_stages", 3) + num_stages = 3 if num_stages is None else num_stages if debug: print(f"\nThe kernel jaxpr for pallas_call {name_and_src_info}:") @@ -101,9 +105,10 @@ def pallas_call_lowering( ) if "serialized_metadata" in (triton_params or {}): # This field is unstable and may be removed in the future. - backend_config["serialized_metadata"] = ir.StringAttr.get( - triton_params["serialized_metadata"] - ) + if triton_params["serialized_metadata"] is not None: + backend_config["serialized_metadata"] = ir.StringAttr.get( + triton_params["serialized_metadata"] + ) return mlir.custom_call( call_target_name="__gpu$xla.gpu.triton", result_types=out_types, diff --git a/jax/_src/pallas/triton/primitives.py b/jax/_src/pallas/triton/primitives.py index 8518a94ed9cf..23fce50dc4f9 100644 --- a/jax/_src/pallas/triton/primitives.py +++ b/jax/_src/pallas/triton/primitives.py @@ -20,6 +20,7 @@ import jax from jax import core as jax_core +from jax._src.lib.mlir.dialects import gpu as gpu_dialect from jax._src.lib.triton import dialect as tt_dialect from jax._src.pallas.triton import lowering from jax.interpreters import mlir @@ -120,3 +121,22 @@ def _elementwise_inline_asm_lowering( packed_element=pack, args=args, ).result + + +def debug_barrier() -> None: + """Synchronizes all kernel executions in the grid.""" + return debug_barrier_p.bind() + + +debug_barrier_p = jax_core.Primitive("debug_barrier_p") +debug_barrier_p.multiple_results = True + +@debug_barrier_p.def_abstract_eval +def _debug_barrier_abstract_eval() -> Sequence[jax_core.ShapedArray]: + return () + +@lowering.register_lowering(debug_barrier_p) +def _debug_barrier_lowering(ctx: lowering.LoweringRuleContext): + del ctx # Unused. 
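  # Editor's note: inside a Pallas Triton kernel this primitive is intended to
  # be invoked through the `debug_barrier()` wrapper defined above, e.g.
  #
  #   def kernel(x_ref, o_ref):
  #     o_ref[...] = x_ref[...] * 2
  #     debug_barrier()  # per its docstring, synchronizes all kernel executions in the grid
  #
  # (how the helper is re-exported to user-facing modules is not shown in this hunk).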
+ gpu_dialect.barrier() + return [] diff --git a/jax/_src/pallas/utils.py b/jax/_src/pallas/utils.py index 41466be0822d..e485537216ca 100644 --- a/jax/_src/pallas/utils.py +++ b/jax/_src/pallas/utils.py @@ -71,6 +71,10 @@ def next_power_of_2(x: int) -> int: raise ValueError("`next_power_of_2` requires a non-negative integer.") return 1 if x == 0 else 2 ** (x - 1).bit_length() +def dtype_bitwidth(dtype: np.dtype | jnp.dtype) -> int: + if jnp.issubdtype(dtype, jnp.integer): + return jnp.iinfo(dtype).bits + return np.dtype(dtype).itemsize * 8 def pattern_match_scan_to_fori_loop( jaxpr: jax_core.Jaxpr, num_consts: int, num_carry: int @@ -179,3 +183,121 @@ def pattern_match_while_to_fori_loop( outvars=new_outvars, ) return jaxpr, None + + +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L644-L802 +def _erf_inv_32_lowering_helper(x): + k_degree = 9 + w_lt_5_constants = [ + 2.81022636e-08, 3.43273939e-07, -3.5233877e-06, + -4.39150654e-06, 0.00021858087, -0.00125372503, + -0.00417768164, 0.246640727, 1.50140941, + ] + w_gt_5_constants = [ + -0.000200214257, 0.000100950558, 0.00134934322, + -0.00367342844, 0.00573950773, -0.0076224613, + 0.00943887047, 1.00167406, 2.83297682, + ] + + w = -jnp.log1p(x * -x) + w_lt_5 = w < 5.0 + + w = jnp.where(w_lt_5, w - 2.5, jnp.sqrt(w) - 3.0) + + p = jnp.where(w_lt_5, w_lt_5_constants[0], w_gt_5_constants[0]) + for i in range(1, k_degree): + c = jnp.where(w_lt_5, w_lt_5_constants[i], w_gt_5_constants[i]) + p = c + p * w + + return jnp.where(jnp.abs(x) == 1.0, jnp.inf * x, p * x) + + +# based on https://github.com/openxla/xla/blob/a7a09d56c3599123f8148bbf3e44c9ebc04624b9/xla/mlir_hlo/mhlo/transforms/chlo_legalize_to_hlo/chlo_legalize_to_hlo.cc#L696-L802 +def _erf_inv_64_lowering_helper(x): + w_lt_625_constants = [ + -3.6444120640178196996e-21, -1.685059138182016589e-19, + 1.2858480715256400167e-18, 1.115787767802518096e-17, + -1.333171662854620906e-16, 2.0972767875968561637e-17, + 6.6376381343583238325e-15, -4.0545662729752068639e-14, + -8.1519341976054721522e-14, 2.6335093153082322977e-12, + -1.2975133253453532498e-11, -5.4154120542946279317e-11, + 1.051212273321532285e-09, -4.1126339803469836976e-09, + -2.9070369957882005086e-08, 4.2347877827932403518e-07, + -1.3654692000834678645e-06, -1.3882523362786468719e-05, + 0.0001867342080340571352, -0.00074070253416626697512, + -0.0060336708714301490533, 0.24015818242558961693, + 1.6536545626831027356 + ] + + w_lt_16_constants = [ + 2.2137376921775787049e-09, 9.0756561938885390979e-08, + -2.7517406297064545428e-07, 1.8239629214389227755e-08, + 1.5027403968909827627e-06, -4.013867526981545969e-06, + 2.9234449089955446044e-06, 1.2475304481671778723e-05, + -4.7318229009055733981e-05, 6.8284851459573175448e-05, + 2.4031110387097893999e-05, -0.0003550375203628474796, + 0.00095328937973738049703, -0.0016882755560235047313, + 0.0024914420961078508066, -0.0037512085075692412107, + 0.005370914553590063617, 1.0052589676941592334, + 3.0838856104922207635, + ] + + w_gt_16_constants = [ + -2.7109920616438573243e-11, -2.5556418169965252055e-10, + 1.5076572693500548083e-09, -3.7894654401267369937e-09, + 7.6157012080783393804e-09, -1.4960026627149240478e-08, + 2.9147953450901080826e-08, -6.7711997758452339498e-08, + 2.2900482228026654717e-07, -9.9298272942317002539e-07, + 4.5260625972231537039e-06, -1.9681778105531670567e-05, + 7.5995277030017761139e-05, -0.00021503011930044477347, + -0.00013871931833623122026, 
1.0103004648645343977, + 4.8499064014085844221, + ] # should add "as jnp.float64 array"? + + w = -jnp.log1p(x * -x) + w_lt_625 = w < 6.25 + w_lt_16 = w < 16.0 + + def get_coefficient(i): + c = w_lt_625_constants[i] + if i < 19: + c = jnp.where(w_lt_625, c, w_lt_16_constants[i]) + if i < 17: + c = jnp.where(w_lt_16, c, w_gt_16_constants[i]) + return c + + select2 = jnp.where(w_lt_16, 3.25, 5.0) + select2_result = jnp.sqrt(w) - select2 + w = jnp.where(w_lt_625, w - 3.125, select2_result) + + p = get_coefficient(0) + for i in range(1, 17): + p = get_coefficient(i) + p * w + for i in range(17, 19): + p = jnp.where(w_lt_16, get_coefficient(i) + p * w, p) + for i in range(19, 23): + p = jnp.where(w_lt_625, get_coefficient(i) + p * w, p) + + return jnp.where(jnp.abs(x) == 1.0, np.inf * x, p * x) + + +def erf_inv_lowering_helper(x): + if x.dtype == jnp.float32: + return _erf_inv_32_lowering_helper(x) + if x.dtype == jnp.float64: + return _erf_inv_64_lowering_helper(x) + raise NotImplementedError(f"erf_inv_lowering_helper not implemented for {x.dtype}") + + +def sign_lowering_helper(x): + if jnp.issubdtype(x.dtype, jnp.unsignedinteger): + return (x != 0).astype(x.dtype) + + if jnp.issubdtype(x.dtype, jnp.integer): + return (x > 0).astype(x.dtype) - (x < 0).astype(x.dtype) + + if jnp.issubdtype(x.dtype, jnp.floating): + out = (x > 0.).astype(x.dtype) - (x < 0.).astype(x.dtype) + return jnp.where(jnp.isnan(x), jnp.nan, out) + + raise NotImplementedError(f"sign_lowering_helper not implemented for {x.dtype}") diff --git a/jax/_src/path.py b/jax/_src/path.py index 1dd523249692..8c46c5560b3c 100644 --- a/jax/_src/path.py +++ b/jax/_src/path.py @@ -14,22 +14,33 @@ import logging import pathlib +import importlib.util -logger = logging.getLogger(__name__) +__all__ = ["Path"] -try: - import etils.epath as epath - epath_installed = True -except: - epath = None - epath_installed = False +logger = logging.getLogger(__name__) # If etils.epath (aka etils[epath] to pip) is present, we prefer it because it # can read and write to, e.g., GCS buckets. Otherwise we use the builtin # pathlib and can only read/write to the local filesystem. -if epath: +epath_installed = bool( + importlib.util.find_spec("etils") and + importlib.util.find_spec("etils.epath") +) +if epath_installed: logger.debug("etils.epath found. Using etils.epath for file I/O.") - Path = epath.Path + + def __dir__(): + return ["Path"] + + def __getattr__(name): + if name != "Path": + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + global Path + from etils import epath + Path = epath.Path + return Path else: logger.debug("etils.epath was not found. 
Using pathlib for file I/O.") Path = pathlib.Path diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index ee7af5183ad9..0a75128477ce 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -19,7 +19,6 @@ import dataclasses from functools import partial import inspect -import itertools as it import logging import operator as op import weakref @@ -63,7 +62,9 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src import sharding +from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( NamedSharding, GSPMDSharding, SingleDeviceSharding, PmapSharding, AUTO, UNSPECIFIED, UnspecifiedValue, @@ -164,7 +165,6 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool abstracted_axes: Any | None - has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit # Hash and compare PjitInfo by identity when used as a cache key. @@ -281,8 +281,7 @@ def _get_fastpath_data( fastpath_data = pxla.MeshExecutableFastpathData( executable.xla_executable, out_tree, in_shardings, executable._out_shardings, out_avals, out_committed, kept_var_bitvec, - executable.unsafe_call.in_handler.local_devices, - executable.unsafe_call.in_handler.input_indices) + executable._dispatch_in_layouts) else: fastpath_data = None return fastpath_data @@ -301,9 +300,7 @@ def _read_most_recent_pjit_call_executable(jaxpr): def _read_pgle_profiler(jaxpr): - return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get( - jaxpr, None - ) + return _most_recent_pjit_call_executable.weak_pgle_profiler_dict.get(jaxpr, None) def _cpp_pjit_evict_fn(self): self._clear_cache() @@ -314,20 +311,48 @@ def _cpp_pjit_evict_fn(self): # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. -_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is only used for jit's with only fun. For example: jax.jit(f) +_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is used for jit where extra arguments are defined other than the +# fun. For example: jax.jit(f, donate_argnums=...) OR +# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the +# capacity might get full very fast because of all the jitted function in JAX +# which might evict train_step for example. 
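A small sketch of what the cache split above means in practice (editor's illustration; which C++ cache a particular `jit` lands in is an internal detail of this change):

    import jax

    def train_step(x):
      return x + 1

    plain = jax.jit(lambda x: x * 2)                  # no explicit attributes: served by the
                                                      # shared `_cpp_pjit_cache_fun_only`
    donating = jax.jit(train_step, donate_argnums=0)  # explicit attribute: keyed into
                                                      # `_cpp_pjit_cache_explicit_attributes`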
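The `sign_lowering_helper` and `_erf_inv_*_lowering_helper` functions added to `jax/_src/pallas/utils.py` above are plain `jnp` code, so they can be sanity-checked outside any kernel against the corresponding public APIs. A minimal sketch (editor's illustration; the internal import path is the module touched by this patch):

    import jax.numpy as jnp
    from jax.scipy.special import erfinv
    from jax._src.pallas import utils as pallas_utils  # internal module, as in the patch

    x = jnp.array([-1.0, -0.5, 0.0, 0.5, 1.0], dtype=jnp.float32)
    print(pallas_utils.sign_lowering_helper(x))            # [-1. -1.  0.  1.  1.], same as jnp.sign(x)
    print(jnp.allclose(pallas_utils._erf_inv_32_lowering_helper(x[1:-1]),
                       erfinv(x[1:-1]), atol=1e-5))        # True: the polynomial tracks erfinv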
+_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache + +if xla_extension_version < 286: + def _get_cpp_global_cache(pjit_has_explicit_sharding): + if pjit_has_explicit_sharding: + return xc._xla.PjitFunctionCache() + else: + return _cpp_pjit_cache_fun_only + + def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) +else: + def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @api_boundary def cache_miss(*args, **kwargs): + if config.no_tracing.value: + raise RuntimeError(f"re-tracing function {jit_info.fun_sourceinfo} for " + "`jit`, but 'no_tracing' is set") outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper( fun, jit_info, *args, **kwargs) executable = _read_most_recent_pjit_call_executable(jaxpr) @@ -339,12 +364,35 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), - fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, - lambda x, sharding: pxla.shard_args([sharding], [x])[0], - _get_cpp_global_cache(jit_info.has_explicit_sharding)) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, + jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, + jit_info.device, jit_info.backend) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, jit_info.donate_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -352,16 +400,6 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def _pjit_explicit_sharding(in_shardings, out_shardings, device, - backend) -> bool: - in_shardings_flat, _ = 
tree_flatten(in_shardings) - out_shardings_flat, _ = tree_flatten(out_shardings) - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(i) for i in out_shardings_flat)) - - def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] @@ -421,7 +459,7 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, # list: if in_axes is not a leaf, it must be a tuple of trees. However, # in cases like these users expect tuples and lists to be treated # essentially interchangeably, so we canonicalize lists to tuples here - # rather than raising an error. https://github.com/google/jax/issues/2367 + # rather than raising an error. https://github.com/jax-ml/jax/issues/2367 in_shardings = tuple(in_shardings) in_layouts, in_shardings = _split_layout_and_sharding(in_shardings) @@ -445,9 +483,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, device, backend) - return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -465,7 +500,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - has_explicit_sharding=has_explicit_sharding, use_resource_env=use_resource_env) @@ -473,23 +507,15 @@ def _make_jit_wrapper(fun: Callable, jit_info: PjitInfo): @api_boundary def lower(*args, **kwargs): - traced = trace(*args, **kwargs) - try: - return traced.lower() - except pxla.DeviceAssignmentMismatchError as e: - fails, = e.args - fun_name = getattr(fun, '__qualname__', - getattr(fun, '__name__', str(fun))) - msg = _device_assignment_mismatch_error( - fun_name, fails, traced._args_flat, 'jit', traced._arg_names) - raise ValueError(msg) from None + return trace(*args, **kwargs).lower() @api_boundary def eval_shape(*args, **kwargs): p, _ = _infer_params(fun, jit_info, args, kwargs) out_s = [None if is_unspecified(s) else s for s in p.params['out_shardings']] # TODO(yashkatariya): Add `Layout` to SDS. - out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s) + out = [api.ShapeDtypeStruct(x.shape, x.dtype, sharding=s, + weak_type=x.weak_type) for x, s in zip(p.params['jaxpr'].out_avals, out_s)] return tree_unflatten(p.out_tree, out) @@ -501,7 +527,7 @@ def trace(*args, **kwargs) -> stages.Traced: lower_callable = partial(_resolve_and_lower, args_flat, **p.params, pgle_profiler=None) return stages.Traced( - p.params['jaxpr'], args_info, p.params["name"],p.out_tree, + p.params['jaxpr'], args_info, p.params["name"], p.out_tree, lower_callable, args_flat, p.arg_names, p.num_consts) wrapped = _cpp_pjit(fun, jit_info) @@ -619,6 +645,18 @@ def _infer_params_impl( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from e + except TypeError as e: + arg_description = (f"path {dbg.arg_names[i]}" if dbg + else f"flattened argument number {i}") + raise TypeError( + f"Error interpreting argument to {fun} as an abstract array." 
+ f" The problematic value is of type {type(a)} and was passed to" + f" the function at {arg_description}.\n" + "This typically means that a jit-wrapped function was called with a non-array" + " argument, and this argument was not marked as static using the" + " static_argnums or static_argnames parameters of jax.jit." + ) from e + in_type = in_avals = tuple(avals) else: in_type = in_avals @@ -1015,8 +1053,8 @@ def _create_sharding_for_array(mesh, x, name, api_name): ' then the mesh context manager is not required.') # A nice user error is raised in prepare_axis_resources. assert x is None or isinstance(x, ParsedPartitionSpec), x - return (pxla.create_mesh_pspec_sharding(mesh, x) - if x is None else pxla.create_mesh_pspec_sharding(mesh, x.user_spec, x)) + return (pxla.create_mesh_pspec_sharding(mesh, x) if x is None else + pxla.create_mesh_pspec_sharding(mesh, x.get_partition_spec(), x)) def _create_sharding_with_device_backend(device, backend): @@ -1250,7 +1288,7 @@ def unpack(key): return done() # we think this is unreachable... - p("explanation unavailable! please open an issue at https://github.com/google/jax") + p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax") return done() @partial(lu.cache, explain=explain_tracing_cache_miss) @@ -1442,10 +1480,17 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): resolved_in_layouts = [] for arg, jit_in_l, rs, aval in safe_zip( args, jit_in_layouts, resolved_in_shardings, in_avals): - arg_layout, committed = ( - pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l, - rs, aval), - getattr(arg, '_committed', True)) + committed = getattr(arg, '_committed', True) + # `arg_layout` is only used for checking purposes in the `else` branch + # below. We cannot replace default layout with None to raise nicer errors. + # `dispatch_arg_layout` replaces default layouts with `None` to simplify + # dispatch and lowering logic downstream. + if hasattr(arg, 'layout'): + arg_layout = arg.layout.device_local_layout + dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval) + else arg_layout) + else: + arg_layout, dispatch_arg_layout = None, None # Sharding can be unspecified when array is committed if it's a PmapSharding. is_pmap_sharding = (is_unspecified(rs) or isinstance(getattr(arg, 'sharding', None), PmapSharding)) @@ -1454,7 +1499,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): if is_pmap_sharding: resolved_in_layouts.append(None) else: - resolved_in_layouts.append(arg_layout) + resolved_in_layouts.append(dispatch_arg_layout) else: resolved_in_layouts.append(None) else: @@ -1484,11 +1529,8 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals): return tuple(resolved_in_layouts) -def _resolve_in_shardings( - args, pjit_in_shardings: Sequence[PjitSharding], - out_shardings: Sequence[PjitSharding], - pjit_mesh: pxla.Mesh | None, - check_device_assignment: bool = True) -> Sequence[PjitSharding]: +def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] + ) -> Sequence[PjitSharding]: # If True, means that device or backend is set by the user on pjit and it # has the same semantics as device_put i.e. doesn't matter which device the # arg is on, reshard it to the device mentioned. 
So don't do any of the @@ -1511,18 +1553,6 @@ def _resolve_in_shardings( if getattr(a, '_committed', True): committed_arg_shardings.append((arg_s, pxla.MismatchType.ARG_SHARDING, None)) - # Check if the device_assignment across inputs, outputs and arguments is the - # same. - if check_device_assignment: - pxla._get_and_check_device_assignment( - it.chain( - util.stable_unique(committed_arg_shardings), - ((i, pxla.MismatchType.IN_SHARDING, None) - for i in util.stable_unique(pjit_in_shardings)), - ((o, pxla.MismatchType.OUT_SHARDING, None) - for o in util.stable_unique(out_shardings))), - (None if pjit_mesh is None or pjit_mesh.empty else list(pjit_mesh.devices.flat))) - resolved_in_shardings = [] for arg, pjit_in_s in zip(args, pjit_in_shardings): # arg sharding can be None in case of ShapeDtypeStruct. jax.Array does @@ -1592,9 +1622,7 @@ def _resolve_and_lower( args, jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline, lowering_platforms, lowering_parameters, pgle_profiler): - in_shardings = _resolve_in_shardings( - args, in_shardings, out_shardings, - resource_env.physical_mesh if resource_env is not None else None) + in_shardings = _resolve_in_shardings(args, in_shardings) in_layouts = _resolve_in_layouts(args, in_layouts, in_shardings, jaxpr.in_avals) lowered = _pjit_lower( @@ -1685,7 +1713,7 @@ def _pjit_call_impl_python( "`jit` decorator, at the cost of losing optimizations. " "\n\n" "If you see this error, consider opening a bug report at " - "https://github.com/google/jax.") + "https://github.com/jax-ml/jax.") raise FloatingPointError(msg) @@ -1723,14 +1751,27 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) - donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding( - in_shardings, out_shardings, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, - lambda x, sharding: pxla.shard_args([sharding], [x])[0], - _get_cpp_global_cache(has_explicit_sharding))(*args) + donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) @@ -1755,14 +1796,8 @@ def _pjit_lower_cached( lowering_platforms: tuple[str, ...] 
| None, lowering_parameters: mlir.LoweringParameters, pgle_profiler: profiler.PGLEProfiler | None): - if resource_env is not None: - mesh = resource_env.physical_mesh - api_name = 'pjit' - else: - # resource_env is `None` in the jit wrapper around pjit. - mesh = None - api_name = 'jit' - + mesh, api_name = ((resource_env.physical_mesh, 'pjit') + if resource_env is not None else (None, 'jit')) return pxla.lower_sharding_computation( jaxpr, api_name, name, in_shardings, out_shardings, in_layouts, out_layouts, tuple(donated_invars), @@ -1777,13 +1812,11 @@ def pjit_staging_rule(trace, *args, **params): params['jaxpr'], params['out_shardings'], params['out_layouts']) params = dict(params, jaxpr=jaxpr, out_shardings=out_shardings, out_layouts=out_layouts) - if (params["inline"] and all(is_unspecified(i) for i in params["in_shardings"]) and all(is_unspecified(o) for o in params["out_shardings"]) and all(i is None for i in params["in_layouts"]) and all(o is None for o in params["out_layouts"])): - if config.dynamic_shapes.value: # Inline jaxpr doesn't handle dynamic shapes when inlining. If dynamic # shapes are enabled, use eval_jaxpr, which uses the tracing machinery, @@ -1996,6 +2029,9 @@ def _pjit_batcher_for_sharding( if spmd_axis_name is None: if sharding_impls.is_op_sharding_replicated(hlo_s): return s + if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, None) + return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec) new_op = hlo_s.to_proto().clone() tad = list(new_op.tile_assignment_dimensions) tad.insert(dim, 1) @@ -2005,6 +2041,9 @@ def _pjit_batcher_for_sharding( _device_list=getattr(s, '_internal_device_list', None)) return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0] else: + if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh): + parsed_pspec = s._parsed_pspec.insert_axis_partitions(dim, spmd_axis_name) + return NamedSharding._from_parsed_pspec(s.mesh, parsed_pspec) if isinstance(s, NamedSharding): mesh = s.mesh if mesh is None or mesh.empty: @@ -2470,6 +2509,27 @@ def _identity_fn(x): return x def _sharding_constraint_impl(x, sharding, layout, resource_env, unconstrained_dims): + if (isinstance(sharding, NamedSharding) and + isinstance(sharding.mesh, AbstractMesh)): + aval = shaped_abstractify(x) + if not hasattr(x, 'sharding'): + raise ValueError( + 'Target sharding contains a `jax.sharding.AbstractMesh` which' + ' requires the input passed should be a `jax.Array`. Got' + f' {type(x)} with shape {aval.str_short()}') + if not isinstance(x.sharding, NamedSharding): + raise TypeError( + 'The sharding on the input must be a `NamedSharding` since the target' + ' sharding has an `AbstractMesh` in it. 
Got sharding type' + f' {type(x.sharding)} for shape {aval.str_short()}') + if x.sharding.mesh.shape_tuple != sharding.mesh.shape_tuple: + raise ValueError( + f'Mesh shape of the input {x.sharding.mesh.shape_tuple} does not' + ' match the mesh shape of the target sharding' + f' {sharding.mesh.shape_tuple} for shape {aval.str_short()}') + sharding = NamedSharding._from_parsed_pspec( + x.sharding.mesh, sharding._parsed_pspec) + if layout is None: if hasattr(x, 'sharding') and x.sharding.is_equivalent_to(sharding, x.ndim): return x diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 7091305824ce..a9dca0b4bffe 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -321,6 +321,7 @@ def base_arr_shape_to_keys_shape(impl, base_arr_shape): class KeyTyRules: + allow_conversion: bool = False @staticmethod def full(shape, fill_value, dtype): @@ -425,14 +426,6 @@ def tangent_dtype(_): def zero(_): return np.zeros((), dtypes.float0) - @staticmethod - def convert_from(key_dtype, other_dtype) -> bool: - return False - - @staticmethod - def convert_to(other_dtype, key_dtype) -> bool: - return False - class KeyTy(dtypes.ExtendedDType): _impl: PRNGImpl # TODO(mattjj,frostig): protocol really @@ -466,11 +459,12 @@ def __hash__(self) -> int: xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x -def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings): +def key_array_shard_arg_handler(xs: Sequence[PRNGKeyArray], shardings, layouts): arrs = [x._base_array for x in xs] phys_shardings = [physical_sharding(x.aval, sharding) for x, sharding in zip(xs, shardings)] - return pxla.shard_args(phys_shardings, arrs) + # TODO(yashkatariya): `layouts` should be converted to physical layouts. + return pxla.shard_args(phys_shardings, layouts, arrs) pxla.shard_arg_handlers[PRNGKeyArray] = key_array_shard_arg_handler diff --git a/jax/_src/profiler.py b/jax/_src/profiler.py index 1330a21486b4..cdf739944f4b 100644 --- a/jax/_src/profiler.py +++ b/jax/_src/profiler.py @@ -17,12 +17,12 @@ from collections.abc import Callable from contextlib import contextmanager from functools import wraps -import glob import gzip import http.server import json import logging import os +import pathlib import socketserver import threading from typing import Any @@ -88,7 +88,7 @@ def reset(self): _profile_state = _ProfileState() -def start_trace(log_dir, create_perfetto_link: bool = False, +def start_trace(log_dir: os.PathLike | str, create_perfetto_link: bool = False, create_perfetto_trace: bool = False) -> None: """Starts a profiler trace. 
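Returning to the `jit` argument handling above: the new `TypeError` branch in `_infer_params_impl` turns an opaque abstractification failure into a hint about `static_argnums`/`static_argnames`. A rough sketch of the user-facing effect (editor's illustration):

    import jax

    def f(x, mode):
      return x * 2 if mode == "double" else x

    try:
      jax.jit(f)(1.0, "double")    # the string argument cannot be abstracted as an array
    except TypeError as e:
      print(e)                     # message now points at static_argnums / static_argnames
    print(jax.jit(f, static_argnames="mode")(1.0, mode="double"))  # 2.0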
@@ -132,17 +132,15 @@ def start_trace(log_dir, create_perfetto_link: bool = False, _profile_state.log_dir = str(log_dir) -def _write_perfetto_trace_file(log_dir): +def _write_perfetto_trace_file(log_dir: os.PathLike | str): # Navigate to folder with the latest trace dump to find `trace.json.jz` - curr_path = os.path.abspath(log_dir) - root_trace_folder = os.path.join(curr_path, "plugins", "profile") - trace_folders = [os.path.join(root_trace_folder, trace_folder) for - trace_folder in os.listdir(root_trace_folder)] - latest_folder = max(trace_folders, key=os.path.getmtime) - trace_jsons = glob.glob(os.path.join(latest_folder, "*.trace.json.gz")) - if len(trace_jsons) != 1: - raise ValueError(f"Invalid trace folder: {latest_folder}") - trace_json, = trace_jsons + trace_folders = (pathlib.Path(log_dir).absolute() / "plugins" / "profile").iterdir() + latest_trace_folder = max(trace_folders, key=os.path.getmtime) + trace_jsons = latest_trace_folder.glob("*.trace.json.gz") + try: + trace_json, = trace_jsons + except ValueError as value_error: + raise ValueError(f"Invalid trace folder: {latest_trace_folder}") from value_error logger.info("Loading trace.json.gz and removing its metadata...") # Perfetto doesn't like the `metadata` field in `trace.json` so we remove @@ -152,8 +150,7 @@ def _write_perfetto_trace_file(log_dir): with gzip.open(trace_json, "rb") as fp: trace = json.load(fp) del trace["metadata"] - filename = "perfetto_trace.json.gz" - perfetto_trace = os.path.join(latest_folder, filename) + perfetto_trace = latest_trace_folder / "perfetto_trace.json.gz" logger.info("Writing perfetto_trace.json.gz...") with gzip.open(perfetto_trace, "w") as fp: fp.write(json.dumps(trace).encode("utf-8")) @@ -173,11 +170,11 @@ def do_GET(self): def do_POST(self): self.send_error(404, "File not found") -def _host_perfetto_trace_file(path): +def _host_perfetto_trace_file(path: os.PathLike | str): # ui.perfetto.dev looks for files hosted on `127.0.0.1:9001`. We set up a # TCP server that is hosting the `perfetto_trace.json.gz` file. port = 9001 - orig_directory = os.path.abspath(os.getcwd()) + orig_directory = pathlib.Path.cwd() directory, filename = os.path.split(path) try: os.chdir(directory) @@ -203,7 +200,7 @@ def stop_trace(): if _profile_state.profile_session is None: raise RuntimeError("No profile started") sess = _profile_state.profile_session - sess.export(sess.stop(), _profile_state.log_dir) + sess.export(sess.stop(), str(_profile_state.log_dir)) if _profile_state.create_perfetto_trace: abs_filename = _write_perfetto_trace_file(_profile_state.log_dir) if _profile_state.create_perfetto_link: @@ -227,7 +224,7 @@ def stop_and_get_fdo_profile() -> bytes | str: @contextmanager -def trace(log_dir, create_perfetto_link=False, create_perfetto_trace=False): +def trace(log_dir: os.PathLike | str, create_perfetto_link=False, create_perfetto_trace=False): """Context manager to take a profiler trace. 
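The profiler entry points above now accept `os.PathLike | str` and handle paths with `pathlib` internally. A minimal usage sketch; the trace directory below is an arbitrary example, not part of the change:

# Sketch: pathlib.Path objects now work directly as log_dir.
import pathlib
import jax
import jax.numpy as jnp

log_dir = pathlib.Path("/tmp/jax-trace")  # arbitrary example location
with jax.profiler.trace(log_dir):
    jnp.ones((1024, 1024)).sum().block_until_ready()
# The dump lands under log_dir / "plugins" / "profile" / <run timestamp>,
# which is the layout _write_perfetto_trace_file walks above.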
The trace will capture CPU, GPU, and/or TPU activity, including Python diff --git a/jax/_src/random.py b/jax/_src/random.py index 113bcc450100..203f72d406e5 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -14,7 +14,7 @@ from __future__ import annotations -from collections.abc import Hashable, Sequence +from collections.abc import Sequence from functools import partial import math from operator import index @@ -146,11 +146,17 @@ class PRNGSpec: def __init__(self, impl): self._impl = impl - def __str__(self) -> str: return str(self._impl) - def __hash__(self) -> int: return hash(self._impl) + def __repr__(self) -> str: + return f"PRNGSpec({self._impl.name!r})" + + def __str__(self) -> str: + return str(self._impl) + + def __hash__(self) -> int: + return hash(self._impl) def __eq__(self, other) -> bool: - return self._impl == other._impl + return isinstance(other, PRNGSpec) and self._impl == other._impl # TODO(frostig,vanderplas): remove PRNGImpl from this union when it's @@ -197,9 +203,10 @@ def key(seed: int | ArrayLike, *, impl: PRNGSpecDesc | None = None) -> KeyArray: """Create a pseudo-random number generator (PRNG) key given an integer seed. - The result is a scalar array with a key that indicates the default PRNG - implementation, as determined by the optional ``impl`` argument or, - otherwise, by the ``jax_default_prng_impl`` config flag. + The result is a scalar array containing a key, whose dtype indicates + the default PRNG implementation, as determined by the optional + ``impl`` argument or, otherwise, by the ``jax_default_prng_impl`` + config flag at the time when this function is called. Args: seed: a 64- or 32-bit integer used as the value of the key. @@ -214,11 +221,20 @@ def key(seed: int | ArrayLike, *, def PRNGKey(seed: int | ArrayLike, *, impl: PRNGSpecDesc | None = None) -> KeyArray: - """Create a pseudo-random number generator (PRNG) key given an integer seed. + """Create a legacy PRNG key given an integer seed. - The resulting key carries the default PRNG implementation, as - determined by the optional ``impl`` argument or, otherwise, by the - ``jax_default_prng_impl`` config flag. + This function produces old-style legacy PRNG keys, which are arrays + of dtype ``uint32``. For more, see the note in the `PRNG keys + `_ + section. When possible, :func:`jax.random.key` is recommended for + use instead. + + The resulting key does not carry a PRNG implementation. The returned + key matches the implementation given by the optional ``impl`` + argument or, otherwise, determined by the ``jax_default_prng_impl`` + config flag. Callers must ensure that same implementation is set as + the default when passing this key as an argument to other functions + (such as ``jax.random.split`` and ``jax.random.normal``). Args: seed: a 64- or 32-bit integer used as the value of the key. 
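A short sketch of the distinction the updated `key`/`PRNGKey` docstrings draw, plus the tightened `PRNGSpec.__eq__`; the output comments assume the default threefry2x32 implementation:

import jax

typed_key = jax.random.key(0)       # new-style typed key; the dtype carries the PRNG impl
legacy_key = jax.random.PRNGKey(0)  # legacy key: a plain uint32 array with no impl attached

print(typed_key.shape, typed_key.dtype)    # () key<fry>
print(legacy_key.shape, legacy_key.dtype)  # (2,) uint32

spec = jax.random.key_impl(typed_key)      # a PRNGSpec, e.g. PRNGSpec('threefry2x32')
print(spec == "not a spec")                # False rather than an AttributeError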
@@ -282,7 +298,7 @@ def _key_impl(keys: KeyArray) -> PRNGImpl: keys_dtype = typing.cast(prng.KeyTy, keys.dtype) return keys_dtype._impl -def key_impl(keys: KeyArrayLike) -> Hashable: +def key_impl(keys: KeyArrayLike) -> PRNGSpec: typed_keys, _ = _check_prng_key("key_impl", keys, allow_batched=True) return PRNGSpec(_key_impl(typed_keys)) @@ -492,7 +508,7 @@ def _randint(key, shape, minval, maxval, dtype) -> Array: span = lax.convert_element_type(maxval - minval, unsigned_dtype) # Ensure that span=1 when maxval <= minval, so minval is always returned; - # https://github.com/google/jax/issues/222 + # https://github.com/jax-ml/jax/issues/222 span = lax.select(maxval <= minval, lax.full_like(span, 1), span) # When maxval is out of range, the span has to be one larger. @@ -2031,6 +2047,11 @@ def orthogonal( Returns: A random array of shape `(*shape, n, n)` and specified dtype. + + References: + .. [1] Mezzadri, Francesco. (2007). "How to generate random matrices from + the classical compact groups". Notices of the American Mathematical + Society, 54(5), 592-604. https://arxiv.org/abs/math-ph/0609050. """ shape = core.canonicalize_shape(shape) key, _ = _check_prng_key("orthogonal", key) @@ -2519,7 +2540,7 @@ def _binomial(key, count, prob, shape, dtype) -> Array: _btrs(key, count_btrs, q_btrs, shape, dtype, max_iters), ) # ensure nan q always leads to nan output and nan or neg count leads to nan - # as discussed in https://github.com/google/jax/pull/16134#pullrequestreview-1446642709 + # as discussed in https://github.com/jax-ml/jax/pull/16134#pullrequestreview-1446642709 invalid = (q_l_0 | q_is_nan | count_nan_or_neg) samples = lax.select( invalid, diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index a826d4746b1e..a0050cc81055 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -21,7 +21,7 @@ from jax import lax import jax.numpy as jnp from jax._src.util import canonicalize_axis -from jax._src.numpy.util import promote_dtypes_complex +from jax._src.numpy.util import promote_dtypes_complex, promote_dtypes_inexact from jax._src.typing import Array def _W4(N: int, k: Array) -> Array: @@ -298,12 +298,12 @@ def idct(x: Array, type: int = 2, n: int | None = None, [(0, n - x.shape[axis] if a == axis else 0, 0) for a in range(x.ndim)]) N = x.shape[axis] - x = x.astype(jnp.float32) + x, = promote_dtypes_inexact(x) if norm is None or norm == 'backward': x = _dct_ortho_norm(x, axis) x = _dct_ortho_norm(x, axis) - k = lax.expand_dims(jnp.arange(N, dtype=jnp.float32), [a for a in range(x.ndim) if a != axis]) + k = lax.expand_dims(jnp.arange(N, dtype=x.dtype), [a for a in range(x.ndim) if a != axis]) # everything is complex from here... 
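The `idct` edit above replaces the hard-coded float32 cast with `promote_dtypes_inexact`, so the output dtype follows the input. A sketch assuming the x64 flag is enabled before any computation runs:

import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax.scipy.fft import idct

x = jnp.linspace(0.0, 1.0, 8, dtype=jnp.float64)
print(idct(x).dtype)  # float64 now, rather than a silent downcast to float32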
w4 = _W4(N,k) x = x.astype(w4.dtype) diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py index 5458d71dedf4..d014e5ceb24e 100644 --- a/jax/_src/scipy/linalg.py +++ b/jax/_src/scipy/linalg.py @@ -951,7 +951,7 @@ def _solve(a: ArrayLike, b: ArrayLike, assume_a: str, lower: bool) -> Array: factors = cho_factor(lax.stop_gradient(a), lower=lower) custom_solve = partial( lax.custom_linear_solve, - lambda x: lax_linalg._matvec_multiply(a, x), + lambda x: lax_linalg._broadcasted_matvec(a, x), solve=lambda _, x: cho_solve(factors, x), symmetric=True) if a.ndim == b.ndim + 1: diff --git a/jax/_src/scipy/ndimage.py b/jax/_src/scipy/ndimage.py index d81008308b94..ee144eaf990a 100644 --- a/jax/_src/scipy/ndimage.py +++ b/jax/_src/scipy/ndimage.py @@ -176,7 +176,7 @@ def map_coordinates( Note: Interpolation near boundaries differs from the scipy function, because JAX - fixed an outstanding bug; see https://github.com/google/jax/issues/11097. + fixed an outstanding bug; see https://github.com/jax-ml/jax/issues/11097. This function interprets the ``mode`` argument as documented by SciPy, but not as implemented by SciPy. """ diff --git a/jax/_src/scipy/signal.py b/jax/_src/scipy/signal.py index cb3719fafd8f..d950cd2ea395 100644 --- a/jax/_src/scipy/signal.py +++ b/jax/_src/scipy/signal.py @@ -27,8 +27,9 @@ import jax.numpy.fft import jax.numpy as jnp from jax import lax -from jax._src.api_util import _ensure_index_tuple +from jax._src import core from jax._src import dtypes +from jax._src.api_util import _ensure_index_tuple from jax._src.lax.lax import PrecisionLike from jax._src.numpy import linalg from jax._src.numpy.util import ( @@ -655,8 +656,7 @@ def pad(x, n, axis=-1): f"Unknown boundary option '{boundary}', " f"must be one of: {list(boundary_funcs.keys())}") - axis = jax.core.concrete_or_error(operator.index, axis, - "axis of windowed-FFT") + axis = core.concrete_or_error(operator.index, axis, "axis of windowed-FFT") axis = canonicalize_axis(axis, x.ndim) if y is None: @@ -686,8 +686,8 @@ def pad(x, n, axis=-1): noverlap_int: int = 0 if nperseg is not None: # if specified by user - nperseg_int = jax.core.concrete_or_error(int, nperseg, - "nperseg of windowed-FFT") + nperseg_int = core.concrete_or_error( + int, nperseg, "nperseg of windowed-FFT") if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') # parse window; if array like, then set nperseg = win.shape @@ -698,14 +698,13 @@ def pad(x, n, axis=-1): if noverlap is None: noverlap_int = nperseg_int // 2 else: - noverlap_int = jax.core.concrete_or_error(int, noverlap, - "noverlap of windowed-FFT") + noverlap_int = core.concrete_or_error( + int, noverlap, "noverlap of windowed-FFT") if nfft is None: nfft_int = nperseg_int else: - nfft_int = jax.core.concrete_or_error(int, nfft, - "nfft of windowed-FFT") + nfft_int = core.concrete_or_error(int, nfft, "nfft of windowed-FFT") # Special cases for size == 0 if y is None: @@ -1015,8 +1014,8 @@ def _overlap_and_add(x: Array, step_size: int) -> Array: An array with `(..., output_size)`-shape containing overlapped signal. 
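For the `jax.scipy.linalg.solve` change above (the `assume_a='pos'` path now builds its matvec from `_broadcasted_matvec`), a small usage sketch of that Cholesky-backed branch:

import jax.numpy as jnp
from jax.scipy.linalg import solve

a = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])      # symmetric positive definite
b = jnp.array([1.0, 2.0])
x = solve(a, b, assume_a="pos")  # cho_factor/cho_solve under lax.custom_linear_solve
print(jnp.allclose(a @ x, b))    # True within float32 tolerance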
""" check_arraylike("_overlap_and_add", x) - step_size = jax.core.concrete_or_error(int, step_size, - "step_size for overlap_and_add") + step_size = core.concrete_or_error( + int, step_size, "step_size for overlap_and_add") if x.ndim < 2: raise ValueError('Input must have (..., frames, frame_length) shape.') @@ -1114,7 +1113,7 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', n_default = (2 * (Zxx.shape[freq_axis] - 1) if input_onesided else Zxx.shape[freq_axis]) - nperseg_int = jax.core.concrete_or_error(int, nperseg or n_default, + nperseg_int = core.concrete_or_error(int, nperseg or n_default, "nperseg: segment length of STFT") if nperseg_int < 1: raise ValueError('nperseg must be a positive integer') @@ -1125,13 +1124,13 @@ def istft(Zxx: Array, fs: ArrayLike = 1.0, window: str = 'hann', if input_onesided and nperseg_int == n_default + 1: nfft_int += 1 # Odd nperseg, no FFT padding else: - nfft_int = jax.core.concrete_or_error(int, nfft, "nfft of STFT") + nfft_int = core.concrete_or_error(int, nfft, "nfft of STFT") if nfft_int < nperseg_int: raise ValueError( f'FFT length ({nfft_int}) must be longer than nperseg ({nperseg_int}).') - noverlap_int = jax.core.concrete_or_error(int, noverlap or nperseg_int // 2, - "noverlap of STFT") + noverlap_int = core.concrete_or_error( + int, noverlap or nperseg_int // 2, "noverlap of STFT") if noverlap_int >= nperseg_int: raise ValueError('noverlap must be less than nperseg.') nstep = nperseg_int - noverlap_int diff --git a/jax/_src/scipy/spatial/transform.py b/jax/_src/scipy/spatial/transform.py index 46bd873bd029..debd37dde64f 100644 --- a/jax/_src/scipy/spatial/transform.py +++ b/jax/_src/scipy/spatial/transform.py @@ -167,12 +167,12 @@ def as_rotvec(self, degrees: bool = False) -> jax.Array: """Represent as rotation vectors.""" return _as_rotvec(self.quat, degrees) - def as_quat(self, canonical: bool=False) -> jax.Array: + def as_quat(self, canonical: bool=False, scalar_first: bool=False) -> jax.Array: """Represent as quaternions.""" - if canonical: - return _make_canonical(self.quat) - else: - return self.quat + quat = _make_canonical(self.quat) if canonical else self.quat + if scalar_first: + return jnp.roll(quat, shift=1, axis=-1) + return quat def inv(self): """Invert this rotation.""" diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 3401edd9e112..837aa011f165 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -558,7 +558,7 @@ def entr(x: ArrayLike) -> Array: \mathrm{entr}(x) = \begin{cases} -x\log(x) & x > 0 \\ 0 & x = 0\\ - -\infty & x > 0 + -\infty & \mathrm{otherwise} \end{cases} Args: diff --git a/jax/_src/shard_alike.py b/jax/_src/shard_alike.py index 2361eaf6426d..574d725c4999 100644 --- a/jax/_src/shard_alike.py +++ b/jax/_src/shard_alike.py @@ -44,7 +44,7 @@ def shard_alike(x, y): raise ValueError( 'The leaves shapes of `x` and `y` should match. Got `x` leaf shape:' f' {x_aval.shape} and `y` leaf shape: {y_aval.shape}. 
File an issue at' - ' https://github.com/google/jax/issues if you want this feature.') + ' https://github.com/jax-ml/jax/issues if you want this feature.') outs = [shard_alike_p.bind(x_, y_) for x_, y_ in safe_zip(x_flat, y_flat)] x_out_flat, y_out_flat = zip(*outs) diff --git a/jax/_src/sharding.py b/jax/_src/sharding.py index fef61566c6ae..20fe3131dcba 100644 --- a/jax/_src/sharding.py +++ b/jax/_src/sharding.py @@ -144,6 +144,11 @@ def is_fully_addressable(self) -> bool: """ raise NotImplementedError('Subclasses should implement this method.') + @property + def num_devices(self) -> int: + """Number of devices that the sharding contains.""" + raise NotImplementedError('Subclasses should implement this method.') + @property def memory_kind(self) -> str | None: """Returns the memory kind of the sharding.""" diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 98fd8c7b02c2..b69e78fe9ddf 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -18,7 +18,6 @@ from collections import OrderedDict from collections.abc import Mapping, Sequence import dataclasses -import enum import functools import itertools import math @@ -31,6 +30,7 @@ from jax._src import tree_util from jax._src import util from jax._src import xla_bridge +from jax._src import mesh_utils from jax._src.lib import xla_client as xc from jax._src.op_shardings import ( are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated) @@ -53,18 +53,17 @@ class TransferToMemoryKind: @util.cache(max_size=128, trace_context_in_key=False) def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes): - try: - for p in parsed_pspec: - if p is not None: - for r in p: - mesh.shape[r] - if r in _manual_axes: - raise ValueError( - f"Axis: {r} of {parsed_pspec.get_partition_spec()} " - f"is also found in manual_axes: {_manual_axes}.") from None - except KeyError as e: - raise ValueError(f"Resource axis: {e.args[0]} of {parsed_pspec.user_spec} is " - "undefined.") from None + for p in parsed_pspec: + if p is not None: + for r in p: + if r not in mesh.shape: + raise ValueError( + f"Resource axis: {r} of {parsed_pspec.get_partition_spec()} " + f"is not found in mesh: {tuple(mesh.shape.keys())}.") + if r in _manual_axes: + raise ValueError( + f"Axis: {r} of {parsed_pspec.get_partition_spec()} " + f"is also found in manual_axes: {_manual_axes}.") from None def hashed_index(x) -> int: @@ -192,7 +191,7 @@ class NamedSharding(sharding.Sharding): >>> named_sharding = jax.sharding.NamedSharding(mesh, spec) """ - mesh: mesh_lib.Mesh + mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh spec: PartitionSpec _memory_kind: str | None _parsed_pspec: ParsedPartitionSpec @@ -200,7 +199,7 @@ class NamedSharding(sharding.Sharding): @use_cpp_method() def __init__( - self, mesh: mesh_lib.Mesh, spec: PartitionSpec, *, + self, mesh: mesh_lib.Mesh | mesh_lib.AbstractMesh, spec: PartitionSpec, *, memory_kind: str | None = None, _parsed_pspec=None, _manual_axes=frozenset()): self.mesh = mesh @@ -258,22 +257,38 @@ def _from_parsed_pspec( memory_kind=memory_kind, _parsed_pspec=parsed_pspec, _manual_axes=_manual_axes) + @property + def num_devices(self) -> int: + return self.mesh.size + @property def device_set(self) -> set[Device]: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError( + 'device_set is not implemented for `jax.sharding.AbstractMesh`.') return self.mesh._flat_devices_set @property def _device_assignment(self) -> XLADeviceAssignment: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise 
ValueError('_device_assignment is not implemented for' + ' `jax.sharding.AbstractMesh`.') return self.mesh._flat_devices_tuple @property def is_fully_addressable(self) -> bool: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError('is_fully_addressable is not implemented for ' + '`jax.sharding.AbstractMesh`.') # Speed up `is_fully_addressable` since there is a high chance that the # mesh across multiple NamedSharding objects will be the same. return not self.mesh.is_multi_process @property def addressable_devices(self) -> set[Device]: + if isinstance(self.mesh, mesh_lib.AbstractMesh): + raise ValueError('addressable_devices is not implemented for ' + '`jax.sharding.AbstractMesh`.') # Override addressable devices because there is a high chance that the mesh # across multiple NamedSharding objects will be the same. return self.mesh._local_devices_set @@ -355,6 +370,10 @@ def __eq__(self, other): return (self._device == other._device and self.memory_kind == other.memory_kind) + @property + def num_devices(self) -> int: + return len(self.device_set) + @property def device_set(self) -> set[Device]: return {self._device} @@ -490,6 +509,10 @@ def default(cls, shape: Shape, sharded_dim: int = 0, pmap_devices = np.array(devices) return cls(pmap_devices, sharding_spec) + @property + def num_devices(self) -> int: + return len(self.device_set) + @functools.cached_property def device_set(self) -> set[Device]: return set(self.devices.flat) @@ -667,13 +690,9 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: def _remake( cls, devices: tuple[xc.Device, ...], ids: np.ndarray, *, memory_kind: str | None = None) -> PositionalSharding: - self = cls.__new__(cls) - self._devices = devices - self._ids = ids - self._internal_device_list = xc.DeviceList(self._devices) - self._memory_kind = xc.check_and_canonicalize_memory_kind( - memory_kind, self._internal_device_list) - return self + sharding = cls(devices, memory_kind=memory_kind) + sharding._ids = ids + return sharding # Hashable @@ -696,6 +715,10 @@ def __eq__(self, other) -> bool: # Sharding interface + @property + def num_devices(self) -> int: + return len(self.device_set) + @functools.cached_property def device_set(self) -> set[xc.Device]: return set(self._devices) @@ -815,6 +838,10 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: f"{len(num_ways_dim_sharded)}, but was applied to a value of rank " f"{len(aval_shape)}") + @property + def num_devices(self) -> int: + return len(self.device_set) + @functools.cached_property def device_set(self) -> set[Device]: return set(self._devices) @@ -927,43 +954,20 @@ def get_array_mapping( cast(ArrayMapping, get_array_mapping(p))) -class SpecSync(enum.IntEnum): - """Encodes how much out of sync the real value of partitions is compared to the user specified one. - - We use this to make sure we don't show garbage modified values while claiming - that the users have specified them like that. - """ - OUT_OF_SYNC = 0 # Arbitrary changes, including new axes inserted - DIM_PERMUTE = 1 # Dimensions permuted, but no new sharding axes - IN_SYNC = 2 # Entirely in sync - class ParsedPartitionSpec: - __slots__ = ('unsafe_user_spec', 'partitions', 'sync') + __slots__ = ('_user_spec', 'partitions') - def __init__(self, user_spec, partitions, sync=SpecSync.IN_SYNC): - self.unsafe_user_spec = user_spec + def __init__(self, user_spec, partitions): + self._user_spec = user_spec # None in partitions represents unconstrained dim. # TODO(yashkatariya): May use a sentinel value. 
self.partitions = tuple(partitions) - self.sync = sync - - @property - def user_spec(self): - return self.unsynced_user_spec(SpecSync.IN_SYNC) def get_partition_spec(self) -> PartitionSpec: - if self.sync < SpecSync.IN_SYNC: - return get_single_pspec(self) + if isinstance(self._user_spec, PartitionSpec): + return self._user_spec else: - if isinstance(self.unsafe_user_spec, PartitionSpec): - return self.unsafe_user_spec - else: - return get_single_pspec(self) - - def unsynced_user_spec(self, min_sync): - if self.sync < min_sync: - raise AssertionError(f"Please open a bug report! ({self.sync} >= {min_sync})") - return self.unsafe_user_spec + return get_single_pspec(self) def insert_axis_partitions(self, dim, val): parts = self.partitions @@ -971,8 +975,7 @@ def insert_axis_partitions(self, dim, val): if too_short > 0: parts += ((),) * too_short new_partitions = util.tuple_insert(parts, dim, val) - new_sync = SpecSync.DIM_PERMUTE if (val == () or val is None) else SpecSync.OUT_OF_SYNC - return ParsedPartitionSpec(self.unsafe_user_spec, new_partitions, sync=new_sync) + return ParsedPartitionSpec(None, new_partitions) @classmethod def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): @@ -999,11 +1002,12 @@ def from_user_input(cls, entry, arg_name, allow_unconstrained_dims=False): return cls(new_entry, axis_specs) def __hash__(self): - return hash((self.partitions, self.sync)) + return hash(self.partitions) def __eq__(self, other): - return (self.partitions == other.partitions and - self.sync == other.sync) + if not isinstance(other, ParsedPartitionSpec): + return False + return self.partitions == other.partitions def __len__(self): return len(self.partitions) @@ -1015,58 +1019,19 @@ def __iter__(self): return iter(self.partitions) def __repr__(self): - return (f"ParsedPartitionSpec(partitions={self.partitions}, " - f"unsafe_user_spec={self.unsafe_user_spec}, " - f"sync={self.sync})") - -class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec): - """ParsedPartitionSpecs that are canonicalized. - - ParsedPartitionSpecs may contain trailing empty tuples, that make them - semantically different in general, and yet in some situations we prefer - to regard them as equivalent. For example, partitions of () and ((),) - cannot be always considered equivalent, since the first one is a valid - spec for a scalar value, while the second is not! However, when either of - those are applied to a 2D array, they both mean that the array is fully - replicated. - - So CanonicalizedParsedPartitionSpecs removes the trailing empty tuples from - partitions. - """ - - def __init__(self, parsed_pspec: ParsedPartitionSpec): - partitions = list(parsed_pspec.partitions) - while partitions and partitions[-1] == (): - partitions.pop() - - super().__init__(parsed_pspec.unsafe_user_spec, partitions, - parsed_pspec.sync) - - def __repr__(self): - return (f"CanonicalizedParsedPartitionSpec(partitions={self.partitions}, " - f"unsafe_user_spec={self.unsafe_user_spec}, " - f"sync={self.sync})") + return f"ParsedPartitionSpec(partitions={self.partitions})" def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()): - # This split exists because you can pass `_parsed_pspec` that has been - # modified from the original. For example: Adding extra dimension to - # axis_resources for vmap handlers. In such cases you need to preserve the - # `sync` attribute of parsed pspecs. - # PartitionSpec is inferred from the parsed pspec in this case. 
- # TODO(yaskatariya): Remove this and replace this with a normalized - # representation of Parsed Pspec if parsed_pspec is None: parsed_pspec = prepare_axis_resources( PartitionSpec() if spec is None else spec, "NamedSharding spec", allow_unconstrained_dims=True) - _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes) return parsed_pspec -def prepare_axis_resources(axis_resources, - arg_name, +def prepare_axis_resources(axis_resources, arg_name, allow_unconstrained_dims=False): # PyTrees don't treat None values as leaves, so we use an is_leaf function. entries, treedef = tree_util.tree_flatten( @@ -1103,9 +1068,11 @@ def _check_unique_resources(axis_resources, arg_name): if resource_counts.most_common(1)[0][1] > 1: multiple_uses = [r for r, c in resource_counts.items() if c > 1] if multiple_uses: - raise ValueError(f"A single {arg_name} specification can map every mesh axis " - f"to at most one positional dimension, but {arg_axis_resources.user_spec} " - f"has duplicate entries for {mesh_lib.show_axes(multiple_uses)}") + raise ValueError( + f'A single {arg_name} specification can map every mesh axis to at' + ' most one positional dimension, but' + f' {arg_axis_resources.get_partition_spec()} has duplicate entries' + f' for {mesh_lib.show_axes(multiple_uses)}') # Axis environments @@ -1284,8 +1251,7 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding, out.extend(parse_flatten_op_sharding(s, mesh)) return out elif hlo_sharding.is_replicated(): - return [CanonicalizedParsedPartitionSpec( - ParsedPartitionSpec(PartitionSpec(), ()))] + return [ParsedPartitionSpec(PartitionSpec(), ())] elif hlo_sharding.is_tiled(): mesh_shape = mesh.shape mesh_axis_order = unflatten_array( @@ -1309,8 +1275,9 @@ def parse_flatten_op_sharding(hlo_sharding: xc.OpSharding | xc.HloSharding, ) if hlo_sharding.replicate_on_last_tile_dim(): partitions = partitions[:-1] - return [CanonicalizedParsedPartitionSpec( - ParsedPartitionSpec('', partitions))] + while partitions and partitions[-1] == (): + partitions.pop() + return [ParsedPartitionSpec(None, partitions)] else: raise AssertionError("Unhandled OpSharding type. Please open a bug report!") @@ -1394,12 +1361,12 @@ def get_process_index_and_count( if (tensor_sharding.is_fully_addressable or tensor_sharding.is_fully_replicated): return (0, 1) - num_devices = len(tensor_sharding.device_set) # Get device to indices map, we don't care about the concrete # global shape here, only to get the distribution of shards across the tensor # using (num_devices, num_devices, ...) This is a universal shape that is # compatible with any mesh with num_devices. - device_map = tensor_sharding.devices_indices_map((num_devices,) * ndims) + device_map = tensor_sharding.devices_indices_map( + (tensor_sharding.num_devices,) * ndims) # Get the slices for 'dim' for all devices. 
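The reworked `_check_mesh_resource_axis` and `_check_unique_resources` above change the errors raised for a `PartitionSpec` that names a missing or duplicated mesh axis. A sketch of the missing-axis case on a one-device mesh; the axis names are arbitrary:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
try:
    NamedSharding(mesh, PartitionSpec("y"))  # "y" is not an axis of this mesh
except ValueError as e:
    print(e)  # now reports that axis "y" is not found in mesh ('x',)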
global_slice = {k: v[dim] for k, v in device_map.items()} @@ -1542,7 +1509,7 @@ def num_addressable_indices( def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) new_op_sharding = hlo_sharding.to_proto().clone() partitions, num_replicas = get_num_ways_dim_sharded(hlo_sharding) suffix = [] if num_replicas == 1 else [num_replicas] @@ -1553,13 +1520,13 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: def is_single_device_sharding(sharding: sharding.Sharding) -> bool: # Special case PmapSharding here because PmapSharding maps away an axis # and needs to be handled separately.test_pjit_single_device_sharding_add - return len(sharding.device_set) == 1 and not isinstance(sharding, PmapSharding) + return sharding.num_devices == 1 and not isinstance(sharding, PmapSharding) def make_key_array_phys_sharding(aval, sharding): if is_single_device_sharding(sharding): return sharding elif isinstance(sharding, PmapSharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) trailing_sharding = [sharding_specs.NoSharding()] * elt_aval.ndim phys_sharding_spec = sharding_specs.ShardingSpec( sharding=(*sharding.sharding_spec.sharding, *trailing_sharding), @@ -1567,7 +1534,7 @@ def make_key_array_phys_sharding(aval, sharding): return PmapSharding(devices=sharding.devices, sharding_spec=phys_sharding_spec) elif isinstance(sharding, NamedSharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) trailing_spec = [None] * elt_aval.ndim return NamedSharding( sharding.mesh, @@ -1584,7 +1551,7 @@ def physical_sharding( def get_logical_gspmd_sharding(aval, phys_sharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding( aval.ndim + elt_aval.ndim) partitions, num_replicas = get_num_ways_dim_sharded(phys_hlo_sharding) @@ -1616,7 +1583,7 @@ def logical_sharding(aval, phys_sharding) -> sharding.Sharding: if is_single_device_sharding(phys_sharding): return phys_sharding elif isinstance(phys_sharding, PmapSharding): - elt_aval = aval.dtype._rules.physical_element_aval(aval.dtype) + elt_aval = core.physical_element_aval(aval.dtype) logical_sharding_spec = sharding_specs.ShardingSpec( sharding=phys_sharding.sharding_spec.sharding[:-elt_aval.ndim], mesh_mapping=phys_sharding.sharding_spec.mesh_mapping) @@ -1624,6 +1591,7 @@ def logical_sharding(aval, phys_sharding) -> sharding.Sharding: sharding_spec=logical_sharding_spec) elif isinstance(phys_sharding, NamedSharding): logical_gs = get_logical_gspmd_sharding(aval, phys_sharding) + assert isinstance(phys_sharding.mesh, mesh_lib.Mesh) return _gspmd_to_named_sharding_via_mesh( logical_gs, phys_sharding.mesh) else: @@ -1647,3 +1615,61 @@ def _gspmd_to_named_sharding_via_mesh( return create_mesh_pspec_sharding( mesh, parsed_pspec.get_partition_spec(), parsed_pspec, out_s.memory_kind) + + +def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str], + *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh: + """Creates an efficient mesh with the shape and axis names specified. + + This function attempts to automatically compute a good mapping from a set of + logical axes to a physical mesh. 
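`Sharding.num_devices`, added above and used by `is_single_device_sharding` and `get_process_index_and_count`, gives the device count without requiring a concrete `device_set`. A quick sketch on a one-device mesh:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()[:1]), ("x",))
s = NamedSharding(mesh, PartitionSpec("x"))
print(s.num_devices)      # 1, taken from mesh.size
print(len(s.device_set))  # also 1, but device_set needs concrete devices
                          # (it raises for an AbstractMesh, per the guards above)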
For example, on a TPU v3 with 8 devices: + + >>> mesh = jax.make_mesh((8,), ('x')) # doctest: +SKIP + >>> [d.id for d in mesh.devices.flat] # doctest: +SKIP + [0, 1, 2, 3, 6, 7, 4, 5] + + The above ordering takes into account the physical topology of TPU v3. + It orders the devices into a ring, which yields efficient all-reduces on a + TPU v3. + + Now, let's see another example with 16 devices of TPU v3: + + >>> mesh = jax.make_mesh((2, 8), ('x', 'y')) # doctest: +SKIP + >>> [d.id for d in mesh.devices.flat] # doctest: +SKIP + [0, 1, 2, 3, 6, 7, 4, 5, 8, 9, 10, 11, 14, 15, 12, 13] + >>> mesh = jax.make_mesh((4, 4), ('x', 'y')) # doctest: +SKIP + >>> [d.id for d in mesh.devices.flat] # doctest: +SKIP + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15] + + As you can see, logical axes (`axis_shapes`) affect the ordering of the + devices. + + You can use `jax.experimental.mesh_utils.create_device_mesh` if you want to + use the extra arguments it provides like `contiguous_submeshes` and + `allow_split_physical_axes`. + + Args: + axis_shapes: Shape of the mesh. For example, axis_shape=(4, 2) + axis_names: Names of the mesh axes. For example, axis_names=('x', 'y') + devices: Optional keyword only argument, that allows you to specify the + devices you want to create a mesh with. + + Returns: + A `jax.sharding.Mesh` object. + """ + if devices is None: + devices = xla_bridge.devices() + axis_size = math.prod(axis_shapes) + if axis_size > len(devices): + raise ValueError( + f'Number of devices {len(devices)} must be >= the product ' + f'of mesh_shape {axis_shapes}') + elif axis_size < len(devices): + devices = devices[:axis_size] + if devices[0].device_kind == mesh_utils._TPU_V5_LITE: + allow_split_physical_axes = True + else: + allow_split_physical_axes = False + mesh_devices = mesh_utils.create_device_mesh( + axis_shapes, devices, allow_split_physical_axes=allow_split_physical_axes) + return mesh_lib.Mesh(mesh_devices, axis_names) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 549da2d39e77..3a2c375b64db 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -513,6 +513,7 @@ def output_shardings(self): # PyTree[sharding.Sharding] shardings_flat = self._executable.output_shardings() return tree_util.tree_unflatten(self.out_tree, shardings_flat) # pytype: disable=attribute-error + @property def input_layouts(self): layouts_flat = self._executable.input_layouts() assert all(isinstance(l, Layout) for l in layouts_flat) @@ -523,6 +524,7 @@ def input_layouts(self): else Layout() for i in range(self.in_tree.num_leaves)] return tree_util.tree_unflatten(self.in_tree, layouts_flat) # pytype: disable=attribute-error + @property def output_layouts(self): layouts_flat = self._executable.output_layouts() assert all(isinstance(l, Layout) for l in layouts_flat) @@ -732,12 +734,22 @@ def out_info(self): def lower(self, *, lowering_platforms: tuple[str, ...] 
| None = None, _private_parameters: mlir.LoweringParameters | None = None): + from jax._src.interpreters import pxla + from jax._src import pjit + if _private_parameters is None: _private_parameters = mlir.LoweringParameters() new_callable = functools.partial( self._lower_callable, lowering_platforms=lowering_platforms, lowering_parameters=_private_parameters) - return Lowered(new_callable(), self.args_info, self._out_tree) + try: + lowering = new_callable() + except pxla.DeviceAssignmentMismatchError as e: + fails, = e.args + msg = pjit._device_assignment_mismatch_error( + self.fun_name, fails, self._args_flat, 'jit', self._arg_names) + raise ValueError(msg) from None + return Lowered(lowering, self.args_info, self._out_tree) @runtime_checkable diff --git a/jax/_src/state/__init__.py b/jax/_src/state/__init__.py index 0041b2506061..2f1c88be495b 100644 --- a/jax/_src/state/__init__.py +++ b/jax/_src/state/__init__.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Module for state.""" -from jax._src.state.types import (AbstractRef, ReadEffect, WriteEffect, - AccumEffect, StateEffect, RefEffect, - get_ref_state_effects, shaped_array_ref, - RefView) +from jax._src.state.types import ( + AbstractRef, + AccumEffect, + ReadEffect, + RefEffect, + StateEffect, + Transform, + TransformedRef, + WriteEffect, + get_ref_state_effects, + shaped_array_ref, +) diff --git a/jax/_src/state/discharge.py b/jax/_src/state/discharge.py index e6ac0db98d08..7970440d29a6 100644 --- a/jax/_src/state/discharge.py +++ b/jax/_src/state/discharge.py @@ -20,10 +20,8 @@ import operator from typing import Any, Protocol, TypeVar -import numpy as np - -from jax._src import api_util from jax._src import ad_util +from jax._src import api_util from jax._src import config from jax._src import core from jax._src import linear_util as lu @@ -35,12 +33,20 @@ from jax._src.lax import lax from jax._src.lax import slicing as lax_slicing from jax._src.state import indexing -from jax._src.state.types import AbstractRef, RefEffect -from jax._src.state.primitives import get_p, swap_p, addupdate_p -from jax._src.state.utils import hoist_consts_to_refs +from jax._src.state.primitives import addupdate_p, get_p, swap_p +from jax._src.state.types import AbstractRef, RefBitcaster, RefEffect +from jax._src.state.utils import bitcast, hoist_consts_to_refs from jax._src.typing import Array -from jax._src.util import (safe_map, safe_zip, split_list, weakref_lru_cache, - partition_list, merge_lists, split_dict) +from jax._src.util import ( + merge_lists, + partition_list, + safe_map, + safe_zip, + split_dict, + split_list, + weakref_lru_cache, +) +import numpy as np ## JAX utilities @@ -158,33 +164,27 @@ def _is_trivial_indexer(indexer: indexing.NDIndexer): return False return True -def _convert_to_array_indexer(indexer: indexing.NDIndexer - ) -> tuple[int | Array, ...]: - # This is the general gather case. We need to create the gather arrays. 
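With `input_layouts` and `output_layouts` turned into properties above, they read the same way as `output_shardings`. A hedged sketch; the exact `Layout` contents depend on the backend, and the input shape is arbitrary:

import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x * 2).lower(jnp.ones((8, 8))).compile()
print(compiled.input_layouts)   # property access, no trailing ()
print(compiled.output_layouts)  # pytree of Layout objects matching the outputs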
- is_integer_indexer, _, integer_indexer = ( - indexing.unpack_ndindexer(indexer) - ) - total_shape = indexer.get_indexer_shape() - int_indexer_shape = indexer.int_indexer_shape - slice_shape = total_shape[len(int_indexer_shape):] - slice_dims = tuple( - i + len(int_indexer_shape) for i in range(len(slice_shape)) - ) - slice_dim_iter = iter(slice_dims) - slice_indexer: list[Array] = [] - for idx, is_int_index in zip(indexer.indices, is_integer_indexer): - if not is_int_index: - assert isinstance(idx, indexing.Slice) - slice_indices = lax.broadcasted_iota( - np.dtype("int32"), total_shape, next(slice_dim_iter) - ) + idx.start - slice_indexer.append(slice_indices) - integer_indexer = tuple( - lax.expand_dims(idx, (-1,)) for idx in integer_indexer - ) - continue - assert next(slice_dim_iter, None) is None - return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + +def _maybe_convert_to_slice( + indexer: indexing.NDIndexer +) -> list[tuple[int, int, int]] | None: + args = [] + + for i in indexer.indices: + if not isinstance(i, indexing.Slice): + return None + + start = i.start + end = i.start + (i.size - 1) * i.stride + 1 + stride = i.stride + + # cannot convert to static `slice` if `start` or `end` is dynamic + if not isinstance(start, int) or not isinstance(end, int): + return None + + args.append((start, end, stride)) + + return args def _maybe_convert_to_dynamic_slice( @@ -198,10 +198,12 @@ def _maybe_convert_to_dynamic_slice( if not all(isinstance(i, indexing.Slice) or not np.shape(i) for i in indexer.indices): return None - # TODO(b/329733289): support strided load/store in interpret mode. + + # `lax.dynamic_slice` does not handle striding for i in indexer.indices: if isinstance(i, indexing.Slice) and i.stride > 1: - raise NotImplementedError("Unimplemented stride support.") + return None + _convert_i32 = lambda x: lax.convert_element_type(x, np.dtype("int32")) starts = tuple( _convert_i32(i.start) if isinstance(i, indexing.Slice) @@ -218,6 +220,35 @@ def _maybe_convert_to_dynamic_slice( return starts, sizes, squeeze_dims +def _convert_to_array_indexer(indexer: indexing.NDIndexer + ) -> tuple[int | Array, ...]: + # This is the general gather case. We need to create the gather arrays. 
+ is_integer_indexer, _, integer_indexer = ( + indexing.unpack_ndindexer(indexer) + ) + total_shape = indexer.get_indexer_shape() + int_indexer_shape = indexer.int_indexer_shape + slice_shape = total_shape[len(int_indexer_shape):] + slice_dims = tuple( + i + len(int_indexer_shape) for i in range(len(slice_shape)) + ) + slice_dim_iter = iter(slice_dims) + slice_indexer: list[Array] = [] + for idx, is_int_index in zip(indexer.indices, is_integer_indexer): + if not is_int_index: + assert isinstance(idx, indexing.Slice) + slice_indices = lax.broadcasted_iota( + np.dtype("int32"), total_shape, next(slice_dim_iter) + ) * idx.stride + idx.start + slice_indexer.append(slice_indices) + integer_indexer = tuple( + lax.expand_dims(idx, (-1,)) for idx in integer_indexer + ) + continue + assert next(slice_dim_iter, None) is None + return tuple(merge_lists(is_integer_indexer, slice_indexer, integer_indexer)) + + @register_discharge_rule(get_p) def _get_discharge_rule( in_avals: Sequence[core.AbstractValue], @@ -239,52 +270,95 @@ def _prepend_scatter(x, indexer, val, *, add=False): return x[None].at[(0, *indexer)].add(val)[0] return x[None].at[(0, *indexer)].set(val)[0] +def _bitcast_array(x, bitcaster: RefBitcaster): + return bitcast(x, bitcaster.dtype) -def index_array(x, indexers): +def _index_array(x, indexer): + if _is_trivial_indexer(indexer): + return x + # Try the three APIs in the following order: `lax.slice`, + # `lax.dynamic_slice` and gather + if maybe_slice := _maybe_convert_to_slice(indexer): + x = lax_slicing.slice(x, *zip(*maybe_slice)) + # If everything in the indexer is a slice or ()-shaped, we can also + # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. + # We need to squeeze out the 1-sized slices at the end. + elif maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_slice + y = lax_slicing.dynamic_slice(x, starts, sizes) + x = lax.squeeze(y, squeeze_dims) + else: + indexer = _convert_to_array_indexer(indexer) + x = x[None][(np.array(0, "int32"), *indexer)] + return x + + +def transform_array(x, transforms): + if transforms is None: + transforms = [] result = x - for indexer in indexers: - if _is_trivial_indexer(indexer): - continue - if indexer is None: + for transform in transforms: + if transform is None: continue - # If everything in the indexer is a slice or ()-shaped, we can also - # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the 1-sized slices at the end. - if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice - y = lax_slicing.dynamic_slice(result, starts, sizes) - result = lax.squeeze(y, squeeze_dims) + if isinstance(transform, indexing.NDIndexer): + result = _index_array(result, transform) + elif isinstance(transform, RefBitcaster): + result = _bitcast_array(result, transform) else: - indexer = _convert_to_array_indexer(indexer) - result = result[None][(np.array(0, "int32"), *indexer)] + raise NotImplementedError(f"Unsupported transform: {transform}") return result -def index_swap_array(x, indexers, val): +def transform_swap_array(x, transforms, val): + if transforms is None: + transforms = [] result = x result_val = val - for indexer in indexers: - if _is_trivial_indexer(indexer): - continue - # If everything in the indexer is a slice or ()-shaped, we can also - # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. - # We need to squeeze out the 1-sized slices at the end. 
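The `_index_array` helper above tries `lax.slice`, then `lax.dynamic_slice`, then the general gather path, and `_maybe_convert_to_dynamic_slice` now returns `None` for strided indexers instead of raising. The underlying distinction, sketched with public `lax` ops:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(10)
# lax.slice takes static strides...
print(lax.slice(x, start_indices=(1,), limit_indices=(9,), strides=(2,)))  # [1 3 5 7]
# ...while lax.dynamic_slice has no stride parameter, so strided Ref indexers
# fall through to lax.slice or to the gather fallback.
print(lax.dynamic_slice(x, start_indices=(1,), slice_sizes=(4,)))  # [1 2 3 4]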
- if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): - starts, sizes, squeeze_dims = maybe_slice - result_old = lax_slicing.dynamic_slice(result, starts, sizes) - result_val = lax.expand_dims(result_val, squeeze_dims) - y = lax_slicing.dynamic_update_slice(result, result_val, starts) - result = lax.squeeze(result_old, squeeze_dims) - result_val = y + # Compute updated "val" (result). + _results = [x] + for transform in transforms: + if isinstance(transform, indexing.NDIndexer): + indexer = transform + if _is_trivial_indexer(indexer): + _results.append(None) + continue + # If everything in the indexer is a slice or ()-shaped, we can also + # use `lax.dynamic_slice` with 1-sized slices for ()-shaped indices. + # We need to squeeze out the 1-sized slices at the end. + if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, sizes, squeeze_dims = maybe_slice + result_old = lax_slicing.dynamic_slice(result, starts, sizes) + result = lax.squeeze(result_old, squeeze_dims) + else: + indexer = _convert_to_array_indexer(indexer) + result = _prepend_gather(result, indexer) + _results.append(result) + elif isinstance(transform, RefBitcaster): + _results.append(_bitcast_array(result, transform)) + else: + raise NotImplementedError(f"Unsupported transform: {transform}") + + # Compute updated "x" (result_val) + for i, transform in reversed(list(enumerate(transforms))): + if isinstance(transform, indexing.NDIndexer): + indexer = transform + if _is_trivial_indexer(indexer): + continue + if maybe_slice := _maybe_convert_to_dynamic_slice(indexer): + starts, _, squeeze_dims = maybe_slice + result_val = lax.expand_dims(result_val, squeeze_dims) + result_val = lax_slicing.dynamic_update_slice( + _results[i], result_val, starts + ) + else: + indexer = _convert_to_array_indexer(indexer) + result_val = _prepend_scatter(_results[i], indexer, result_val) else: - indexer = _convert_to_array_indexer(indexer) - result_old = _prepend_gather(result, indexer) - result_val = _prepend_scatter(result, indexer, result_val) - result = result_old + raise NotImplementedError(f"Unsupported transform: {transform}") return result, result_val def _get_discharge(x, idx, tree): - indexers = tree_util.tree_unflatten(tree, idx) - return index_array(x, indexers) + transforms = tree_util.tree_unflatten(tree, idx) + return transform_array(x, transforms) @register_discharge_rule(swap_p) def _swap_discharge_rule( @@ -296,8 +370,8 @@ def _swap_discharge_rule( return (x_new, None) + (None,) * len(idx), z def _swap_discharge(x, val, idx, tree): - indexers = tree_util.tree_unflatten(tree, idx) - return index_swap_array(x, indexers, val) + transforms = tree_util.tree_unflatten(tree, idx) + return transform_swap_array(x, transforms, val) @register_discharge_rule(addupdate_p) def _addupdate_discharge_rule( @@ -309,10 +383,10 @@ def _addupdate_discharge_rule( return (ans, None) + (None,) * len(idx), [] def _addupdate_discharge(x, val, idx, tree): - indexers = tree_util.tree_unflatten(tree, idx) - if len(indexers) > 1: + transforms = tree_util.tree_unflatten(tree, idx) + if len(transforms) > 1: raise NotImplementedError("Only single indexer is supported.") - indexer = indexers[0] + indexer = transforms[0] if _is_trivial_indexer(indexer): return x + val # If everything in the indexer is a slice or ()-shaped, we can also @@ -416,7 +490,7 @@ def _run_state_jvp(primals: Sequence[Any], tangents: Sequence[Any], *, len(primals)]) del out_consts out_tangents_iter = iter(out_tangents) - out_tangents = [next(out_tangents_iter) if nz else 
ad_util.Zero.from_value(p) + out_tangents = [next(out_tangents_iter) if nz else ad_util.Zero.from_primal_value(p) for p, nz in zip(out_primals, nonzero_tangents)] return out_primals, out_tangents ad.primitive_jvps[run_state_p] = _run_state_jvp @@ -442,8 +516,7 @@ def eval_jaxpr(*refs): jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic( eval_jaxpr, [*in_avals, *res_ref_avals]) assert not consts - return jaxpr, [core.ShapedArray(a.inner_aval.shape, a.inner_aval.dtype) # pytype: disable=attribute-error - for a in res_ref_avals] + return jaxpr, [core.ShapedArray(a.shape, a.dtype) for a in res_ref_avals] def _convert_inputs_to_reads(num_res: int, jaxpr: core.Jaxpr) -> core.Jaxpr: assert not jaxpr.constvars, "Jaxpr should not have constvars" diff --git a/jax/_src/state/primitives.py b/jax/_src/state/primitives.py index 750d3239a019..773302c9f637 100644 --- a/jax/_src/state/primitives.py +++ b/jax/_src/state/primitives.py @@ -18,9 +18,6 @@ import types from typing import Any, Union -import numpy as np - - from jax._src import ad_util from jax._src import core from jax._src import dispatch @@ -28,14 +25,22 @@ from jax._src import tree_util from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.interpreters import partial_eval as pe from jax._src.interpreters import mlir +from jax._src.interpreters import partial_eval as pe from jax._src.lax import lax -from jax._src.typing import Array from jax._src.state import indexing -from jax._src.state.types import (AbstractRef, RefView, ReadEffect, WriteEffect, - AccumEffect) +from jax._src.state.types import ( + AbstractRef, + AccumEffect, + ReadEffect, + RefBitcaster, + Transform, + TransformedRef, + WriteEffect, +) +from jax._src.typing import Array from jax._src.util import safe_map, safe_zip +import numpy as np ## General utilities @@ -57,31 +62,43 @@ get_p = core.Primitive("get") get_p.def_impl(partial(dispatch.apply_primitive, get_p)) -Indexer = tuple[Union[int, slice, Array, types.EllipsisType], ...] +Indexer = Union[int, slice, Array, types.EllipsisType] + -def get_ref_and_indexers( - ref_or_view: Any, idx: Indexer | None, function_name: str -) -> tuple[Any, tuple[indexing.NDIndexer, ...]]: - if isinstance(ref_or_view, RefView): - ref, indexers = ref_or_view.ref, ref_or_view.indexers +def get_ref_and_transforms( + ref_or_view: Any, + idx: Indexer | tuple[Indexer, ...] | None, + function_name: str, +) -> tuple[Any, tuple[Transform, ...]]: + if isinstance(ref_or_view, TransformedRef): + ref, transforms = ref_or_view.ref, ref_or_view.transforms else: - ref, indexers = ref_or_view, () + ref, transforms = ref_or_view, () ref_aval = core.get_aval(ref) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"Can only call `{function_name}` on a `Ref`: {ref}.") if not isinstance(ref_aval.inner_aval, core.ShapedArray): return ref, () - if idx is None: - return ref, indexers + + if idx is None or idx is Ellipsis: + idx = () + elif not isinstance(idx, tuple): + idx = (idx,) + + if not idx and transforms and isinstance(transforms[-1], indexing.NDIndexer): + return ref, transforms nd_indexer = indexing.NDIndexer.from_indices_shape(idx, ref_or_view.shape) - return ref, (*indexers, nd_indexer) + return ref, (*transforms, nd_indexer) -def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: +def ref_get( + ref_or_view: Any, idx: Indexer | tuple[Indexer, ...] | None = None +) -> Array: """Reads a value from a `Ref`, a.k.a. 
value <- ref[idx].""" - ref, indexers = get_ref_and_indexers(ref_or_view, idx, "ref_get") - flat_indexers, tree = tree_util.tree_flatten(indexers) - return get_p.bind(ref, *flat_indexers, tree=tree) + ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_get") + flat_transforms, tree = tree_util.tree_flatten(transforms) + return get_p.bind(ref, *flat_transforms, tree=tree) + # `swap` mutates a `Ref`, setting its value and returns its previous value. # b = swap_p.bind(x, a) @@ -102,17 +119,28 @@ def ref_get(ref_or_view: Any, idx: Indexer | None = None) -> Array: swap_p = core.Primitive("swap") swap_p.def_impl(partial(dispatch.apply_primitive, swap_p)) -def ref_swap(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array, - _function_name: str = "ref_swap") -> Array: + +def ref_swap( + ref_or_view: AbstractRef | TransformedRef, + idx: Indexer | tuple[Indexer, ...] | None, + value: Array, + _function_name: str = "ref_swap", +) -> Array: """Sets a `Ref`'s value and returns the original value.""" - ref, indexers = get_ref_and_indexers(ref_or_view, idx, _function_name) - flat_indexers, tree = tree_util.tree_flatten(indexers) - return swap_p.bind(ref, value, *flat_indexers, tree=tree) + ref, transforms = get_ref_and_transforms(ref_or_view, idx, _function_name) + flat_transforms, tree = tree_util.tree_flatten(transforms) + return swap_p.bind(ref, value, *flat_transforms, tree=tree) + -def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Array) -> None: +def ref_set( + ref_or_view: AbstractRef | TransformedRef, + idx: Indexer | tuple[Indexer, ...] | None, + value: Array, +) -> None: """Sets a `Ref`'s value, a.k.a. ref[idx] <- value.""" ref_swap(ref_or_view, idx, value, _function_name="ref_set") + # `addupdate_p` mutates a `Ref`, adding a value to its existing value. # Semantically, # ``` @@ -128,36 +156,58 @@ def ref_set(ref_or_view: AbstractRef | RefView, idx: Indexer | None, value: Arra addupdate_p.multiple_results = True addupdate_p.def_impl(partial(dispatch.apply_primitive, addupdate_p)) -def ref_addupdate(ref_or_view: AbstractRef, idx: Indexer | None, x: Array) -> None: + +def ref_addupdate( + ref_or_view: AbstractRef, + idx: Indexer | tuple[Indexer, ...] | None, + x: Array, +) -> None: """Mutates a ref with an additive update i.e. `ref[idx] += x`.""" - ref, indexers = get_ref_and_indexers(ref_or_view, idx, "ref_addupdate") - flat_indexers, tree = tree_util.tree_flatten(indexers) - return addupdate_p.bind(ref, x, *flat_indexers, tree=tree) + ref, transforms = get_ref_and_transforms(ref_or_view, idx, "ref_addupdate") + flat_transforms, tree = tree_util.tree_flatten(transforms) + return addupdate_p.bind(ref, x, *flat_transforms, tree=tree) + ## get/set/addupdate abstract evaluation rules -def _shape_after_indexing( - shape: tuple[int | Array, ...], indexers: tuple[indexing.NDIndexer, ...] +def _shape_after_transforming( + shape: tuple[int | Array, ...], transforms: tuple[Transform, ...] 
) -> tuple[int | Array, ...]: - for indexer in indexers: - # Run some simple checks that all the indexers have consistent shapes - if not indexer.is_dynamic_size: - assert indexer.shape == shape, (indexer.shape, shape) - shape = indexer.get_indexer_shape() + for transform in transforms: + match transform: + case indexing.NDIndexer(): + # Run some simple checks that all the indexers have consistent shapes + if not transform.is_dynamic_size: + assert transform.shape == shape, (transform.shape, shape) + shape = transform.get_indexer_shape() + case RefBitcaster(): + shape = transform.shape + case _: + raise ValueError(f"Unsupported transform: {transform}") return shape +def _dtype_after_transforming( + dtype: Any, transforms: tuple[Transform, ...] +) -> Any: + for transform in reversed(transforms): + if isinstance(transform, RefBitcaster): + return transform.dtype + return dtype + + def _get_abstract_eval(ref_aval: AbstractRef, *args, tree): - indexers = tree_util.tree_unflatten(tree, args) + transforms = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`get` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): - out_shape = _shape_after_indexing(ref_aval.shape, indexers) - out_aval = ref_aval.inner_aval.update(shape=out_shape) + out_shape = _shape_after_transforming(ref_aval.shape, transforms) + out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) + out_aval = ref_aval.inner_aval.update(shape=out_shape, dtype=out_dtype) else: - if indexers: + if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {ReadEffect(0)}) @@ -166,27 +216,30 @@ def _get_abstract_eval(ref_aval: AbstractRef, *args, def _swap_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): - indexers = tree_util.tree_unflatten(tree, args) + transforms = tree_util.tree_unflatten(tree, args) out_aval: core.AbstractValue if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`swap` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): val_aval = core.raise_to_shaped(val_aval) assert isinstance(val_aval, core.ShapedArray) - expected_out_shape = _shape_after_indexing(ref_aval.shape, indexers) + expected_out_shape = _shape_after_transforming(ref_aval.shape, transforms) + expected_out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) if expected_out_shape != val_aval.shape: raise ValueError("Invalid shape for `swap`. " f"Ref shape: {ref_aval.shape}. " f"Expected shape: {expected_out_shape}. " f"Value shape: {val_aval.shape}. " - f"Indices: {indexers}. ") - if ref_aval.dtype != val_aval.dtype and not val_aval.weak_type: - raise ValueError("Invalid dtype for `swap`. " - f"Ref dtype: {ref_aval.dtype}. " - f"Value dtype: {val_aval.dtype}. ") - out_aval = core.ShapedArray(expected_out_shape, ref_aval.dtype) + f"Transforms: {transforms}. ") + if expected_out_dtype != val_aval.dtype and not val_aval.weak_type: + raise ValueError( + "Invalid dtype for `swap`. " + f"Ref dtype: {expected_out_dtype}. " + f"Value dtype: {val_aval.dtype}. 
" + ) + out_aval = core.ShapedArray(expected_out_shape, expected_out_dtype) else: - if indexers: + if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") out_aval = ref_aval.inner_aval return (out_aval, {WriteEffect(0)}) @@ -196,26 +249,29 @@ def _swap_abstract_eval(ref_aval: AbstractRef, def _addupdate_abstract_eval(ref_aval: AbstractRef, val_aval: core.AbstractValue, *args: Any, tree): - indexers = tree_util.tree_unflatten(tree, args) + transforms = tree_util.tree_unflatten(tree, args) if not isinstance(ref_aval, AbstractRef): raise ValueError(f"`addupdate` must be called on `Ref` types: {ref_aval}.") if isinstance(ref_aval.inner_aval, core.ShapedArray): val_aval = core.raise_to_shaped(val_aval) - slice_shape = _shape_after_indexing(ref_aval.shape, indexers) + out_shape = _shape_after_transforming(ref_aval.shape, transforms) + out_dtype = _dtype_after_transforming(ref_aval.dtype, transforms) assert isinstance(val_aval, core.ShapedArray) - if slice_shape != val_aval.shape: - raise ValueError("Invalid shape for `addupdate`. " - f"Ref shape: {ref_aval.shape}. " - f"Slice shape: {slice_shape}. " - f"Value shape: {val_aval.shape}. " - f"Indices: {indexers}. ") - if ref_aval.dtype != val_aval.dtype: + if out_shape != val_aval.shape: + raise ValueError( + "Invalid shape for `addupdate`. " + f"Ref shape: {ref_aval.shape}. " + f"Expected shape: {out_shape}. " + f"Value shape: {val_aval.shape}. " + f"Transforms: {transforms}. " + ) + if out_dtype != val_aval.dtype: raise ValueError("Invalid dtype for `addupdate`. " f"Ref dtype: {ref_aval.dtype}. " f"Value shape: {val_aval.dtype}. ") else: - # Check that the indexers are valid - if indexers: + # Check that the transforms are valid + if transforms: raise ValueError("Cannot index non-shaped array with nontrivial indices.") return [], {AccumEffect(0)} addupdate_p.def_effectful_abstract_eval(_addupdate_abstract_eval) @@ -261,52 +317,73 @@ def pp_indexer(context: core.JaxprPpContext,indexer: indexing.NDIndexer indices.append(core.pp_var(idx, context)) # type: ignore return pp.concat([pp.text("["), pp.text(','.join(indices)), pp.text("]")]) -def _pp_indexers( - context: core.JaxprPpContext, indexers: tuple[indexing.NDIndexer, ...], + +def pp_bitcaster( + context: core.JaxprPpContext, bitcaster: RefBitcaster +) -> pp.Doc: + del context + return pp.text( + f"[bitcast({bitcaster.dtype}[{','.join(str(d) for d in bitcaster.shape)}])]" + ) + + +def pp_transform(context: core.JaxprPpContext, transform: Transform) -> pp.Doc: + match transform: + case indexing.NDIndexer(): + return pp_indexer(context, transform) + case RefBitcaster(): + return pp_bitcaster(context, transform) + case _: + raise ValueError(f"Unsupported transform: {transform}") + + +def _pp_transforms( + context: core.JaxprPpContext, + transforms: tuple[Transform, ...], ): - if not indexers: + if not transforms: return pp.text("[...]") return pp.concat( - [pp_indexer(context, indexer) for indexer in indexers] + [pp_transform(context, transform) for transform in transforms] ) -def pp_ref_indexers(context: core.JaxprPpContext, ref, indexers): + +def pp_ref_transforms(context: core.JaxprPpContext, ref, transforms): return pp_ref_var( pp.concat([ pp.text(core.pp_var(ref, context)), - _pp_indexers(context, indexers), + _pp_transforms(context, transforms), ]) ) + def _get_pp_rule(eqn, context, settings) -> pp.Doc: # Pretty prints `a = get x i` as `x[i] <- a` y, = eqn.outvars x, *flat_idx = eqn.invars - indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) + 
transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) lhs = core.pp_vars([y], context, print_shapes=settings.print_shapes) - return pp.concat([ - lhs, - pp.text(' <- '), - pp_ref_indexers(context, x, indexers) - ]) + return pp.concat( + [lhs, pp.text(" <- "), pp_ref_transforms(context, x, transforms)] + ) core.pp_eqn_rules[get_p] = _get_pp_rule def _swap_pp_rule(eqn, context, settings) -> pp.Doc: y, = eqn.outvars x, v, *flat_idx = eqn.invars - indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) + transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) if type(y) is core.DropVar: # In the case of a set (ignored return value), # pretty print `_ = swap x v i` as `x[i] <- v` del y return pp.concat([ - pp_ref_indexers(context, x, indexers), - pp.text(' <- '), - pp.text(core.pp_var(v, context)) - ]) + pp_ref_transforms(context, x, transforms), + pp.text(" <- "), + pp.text(core.pp_var(v, context)), + ]) else: # pretty-print `y:T = swap x v i` as `y:T, x[i] <- x[i], v` - x_i = pp_ref_indexers(context, x, indexers) + x_i = pp_ref_transforms(context, x, transforms) y = core.pp_vars([y], context, print_shapes=settings.print_shapes) return pp.concat([y, pp.text(', '), x_i, pp.text(' <- '), x_i, pp.text(', '), @@ -318,11 +395,12 @@ def _addupdate_pp_rule(eqn, context, settings) -> pp.Doc: # pretty-print ` = addupdate x i v` as `x[i] += v` () = eqn.outvars x, v, *flat_idx = eqn.invars - indexers = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) + transforms = tree_util.tree_unflatten(eqn.params["tree"], flat_idx) return pp.concat([ - pp_ref_indexers(context, x, indexers), - pp.text(' += '), - pp.text(core.pp_var(v, context))]) + pp_ref_transforms(context, x, transforms), + pp.text(" += "), + pp.text(core.pp_var(v, context)), + ]) core.pp_eqn_rules[addupdate_p] = _addupdate_pp_rule ## get/swap/addupdate JVP rules diff --git a/jax/_src/state/types.py b/jax/_src/state/types.py index a71d671c5345..e64d6258a808 100644 --- a/jax/_src/state/types.py +++ b/jax/_src/state/types.py @@ -21,11 +21,13 @@ from typing import Any, Union from jax._src import core +from jax._src import dtypes from jax._src import effects from jax._src import pretty_printer as pp +from jax._src import tree_util from jax._src.state import indexing -from jax._src.util import safe_map, safe_zip from jax._src.typing import Array +from jax._src.util import safe_map, safe_zip ## JAX utilities @@ -72,7 +74,39 @@ class AccumEffect(RefEffect): StateEffect = Union[ReadEffect, WriteEffect, AccumEffect] + # ## `Ref`s +@tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class RefBitcaster: + dtype: dtypes.DType + shape: tuple[int, ...] + + @classmethod + def from_ref_new_dtype(cls, ref_or_view: Any, dtype) -> RefBitcaster: + if isinstance(ref_or_view, TransformedRef): + if ref_or_view.is_dynamic_size: + raise NotImplementedError( + "Bitcast ref with dynamic size is not supported." 
+ ) + from jax._src.state.utils import eval_bitcast_shape # pytype: disable=import-error + dtype = dtypes.dtype(dtype) + return cls(dtype, eval_bitcast_shape(ref_or_view, dtype)) + + @property + def is_dynamic_size(self): + return False + + def tree_flatten(self): + return (), (self.dtype, self.shape) + + @classmethod + def tree_unflatten(cls, metadata, arrays): + assert not arrays + return cls(*metadata) + + +Transform = indexing.NDIndexer | RefBitcaster @dataclasses.dataclass class RefIndexer: @@ -82,37 +116,47 @@ def __getitem__(self, slc): if not isinstance(slc, tuple): slc = (slc,) indexer = indexing.NDIndexer.from_indices_shape(slc, self.ref_or_view.shape) - if isinstance(self.ref_or_view, RefView): + if isinstance(self.ref_or_view, TransformedRef): view = self.ref_or_view - return RefView(view.ref, (*view.indexers, indexer)) - return RefView(self.ref_or_view, (indexer,)) + return TransformedRef(view.ref, (*view.transforms, indexer)) + return TransformedRef(self.ref_or_view, (indexer,)) -Indexer = Any @dataclasses.dataclass -class RefView: +class TransformedRef: ref: Any - indexers: tuple[indexing.NDIndexer, ...] + transforms: tuple[Transform, ...] @property def is_dynamic_size(self): - return self.indexers[-1].is_dynamic_size + return self.transforms[-1].is_dynamic_size @property def shape(self) -> tuple[int | Array, ...]: assert ( - len(self.indexers) > 0 - ), "Should not be able to create a trivial RefView" - return self.indexers[-1].get_indexer_shape() + len(self.transforms) > 0 + ), "Should not be able to create a trivial TransformedRef" + if isinstance(self.transforms[-1], indexing.NDIndexer): + return self.transforms[-1].get_indexer_shape() + return self.transforms[-1].shape @property def dtype(self): + for transform in reversed(self.transforms): + if isinstance(transform, RefBitcaster): + return transform.dtype return self.ref.dtype @property def at(self) -> RefIndexer: return RefIndexer(self) + def bitcast(self, dtype): + return TransformedRef( + self.ref, + (*self.transforms, RefBitcaster.from_ref_new_dtype(self, dtype)), + ) + def __getattr__(self, name): return getattr(self.ref, name) @@ -152,20 +196,30 @@ def join(self, other): @property def shape(self): - if not isinstance(self.inner_aval, core.ShapedArray): - raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`.") - return self.inner_aval.shape + try: + return self.inner_aval.shape # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `shape`." + ) from None @property def dtype(self): - if not isinstance(self.inner_aval, core.UnshapedArray): - raise AttributeError(f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`.") - return self.inner_aval.dtype + try: + return self.inner_aval.dtype # pytype: disable=attribute-error + except AttributeError: + raise AttributeError( + f"`Ref{{{self.inner_aval.str_short()}}} has no `dtype`." 
+ ) from None @core.aval_property def at(self): return RefIndexer(self) + @core.aval_method + def bitcast(self, dtype): + return TransformedRef(self, (RefBitcaster.from_ref_new_dtype(self, dtype),)) + @core.aval_method @staticmethod def get(tracer, idx=()): @@ -189,8 +243,8 @@ def _setitem(self, tracer, idx, value) -> None: def __repr__(self) -> str: return f'Ref{{{self.inner_aval.str_short()}}}' - def at_least_vspace(self): - return AbstractRef(self.inner_aval.at_least_vspace()) + def to_tangent_aval(self): + return AbstractRef(self.inner_aval.to_tangent_aval()) def __eq__(self, other): return (type(self) is type(other) and self.inner_aval == other.inner_aval) diff --git a/jax/_src/state/utils.py b/jax/_src/state/utils.py index 33fced775fad..909e84c3a6e3 100644 --- a/jax/_src/state/utils.py +++ b/jax/_src/state/utils.py @@ -13,14 +13,18 @@ # limitations under the License. """Utilities for tracing stateful functions.""" +from functools import partial from typing import Callable -from jax._src.interpreters import partial_eval as pe +import jax from jax._src import core +from jax._src import dtypes from jax._src import linear_util as lu +from jax._src.interpreters import partial_eval as pe from jax._src.state import AbstractRef -from jax._src.util import split_list, safe_map, safe_zip from jax._src.state.primitives import ref_get +from jax._src.typing import DTypeLike +from jax._src.util import safe_map, safe_zip, split_list map, unsafe_map = safe_map, map zip, unsafe_zip = safe_zip, zip @@ -79,3 +83,41 @@ def val_to_ref_aval(x) -> AbstractRef: if type(aval) is not core.ShapedArray: raise TypeError(f"can't make ref from {x}") return AbstractRef(aval) + + +def dtype_bitwidth(dtype: DTypeLike) -> int: + if dtypes.isdtype(dtype, "integral"): + return dtypes.iinfo(dtype).bits + return dtypes.dtype(dtype).itemsize * 8 + + +def bitcast(x, dtype: DTypeLike): + x_bitwidth = dtype_bitwidth(x.dtype) + y_bitwidth = dtype_bitwidth(dtype) + shape = list(x.shape) + if x_bitwidth != y_bitwidth: + if len(shape) < 2: + raise NotImplementedError( + "Bitcast 1D ref with bitwidth change is not supported." + ) + # Note: this is only valid on TPU. + if shape[-2] * x_bitwidth % y_bitwidth != 0: + raise ValueError( + "Expected input and output shapes are the same after multiplying" + " the second-minor dimension by the bitwidths." 
+ ) + shape[-2] = shape[-2] * x_bitwidth // y_bitwidth + if x_bitwidth < y_bitwidth: + ratio = y_bitwidth // x_bitwidth + x = x.reshape(*x.shape[:-2], x.shape[-2] // ratio, ratio, -1).swapaxes( + -1, -2 + ) + y = jax.lax.bitcast_convert_type(x, dtype) + if x_bitwidth > y_bitwidth: + y = y.swapaxes(-1, -2).reshape(shape) + return y + + +def eval_bitcast_shape(x, dtype: DTypeLike): + f = partial(bitcast, dtype=dtype) + return jax.eval_shape(f, jax.ShapeDtypeStruct(x.shape, x.dtype)).shape diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py index e4de7e7b787b..81737f27540b 100644 --- a/jax/_src/test_util.py +++ b/jax/_src/test_util.py @@ -39,6 +39,7 @@ import jax from jax import lax from jax._src import api +from jax._src import array from jax._src import config from jax._src import core from jax._src import dispatch @@ -172,7 +173,7 @@ def _normalize_tolerance(tol): if isinstance(tol, dict): return {np.dtype(k): v for k, v in tol.items()} else: - return {k: tol for k in _default_tolerance} + return dict.fromkeys(_default_tolerance, tol) def join_tolerance(tol1, tol2): tol1 = _normalize_tolerance(tol1) @@ -365,7 +366,7 @@ def compiled_call_count(*args, **kwargs): @contextmanager -def count_jit_and_pmap_compiles(): +def count_jit_and_pmap_lowerings(): # No need to clear any caches since we generally jit and pmap fresh callables # in tests. @@ -383,6 +384,44 @@ def mlir_lower_and_count(*args, **kwargs): mlir.lower_jaxpr_to_module = mlir_lower +@contextmanager +def count_jax_array_shard_arg_calls(): + # No need to clear any caches since we generally jit and pmap fresh callables + # in tests. + + array_shard_arg = array._array_shard_arg + count = [0] + + def array_shard_arg_and_count(*args, **kwargs): + count[0] += 1 + return array_shard_arg(*args, **kwargs) + + pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg_and_count + try: + yield count + finally: + pxla.shard_arg_handlers[array.ArrayImpl] = array_shard_arg + + +@contextmanager +def count_jit_compilation_cache_miss(): + # No need to clear any caches since we generally jit and pmap fresh callables + # in tests. 
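For reference, the `bitcast`/`eval_bitcast_shape` helpers added to jax/_src/state/utils.py above only let the second-minor dimension absorb a change in element bitwidth. A minimal sketch of that shape rule, assuming the TPU-style (..., second-minor, minor) layout the in-code note mentions (illustrative, not part of the patch):

def bitcast_shape(shape, from_bits, to_bits):
    # Mirrors the rule in `bitcast` above: the minor dimension is unchanged,
    # the second-minor dimension is rescaled by the bitwidth ratio.
    if from_bits == to_bits:
        return tuple(shape)
    if len(shape) < 2:
        raise NotImplementedError("1D bitcast with a bitwidth change")
    bits = shape[-2] * from_bits
    if bits % to_bits:
        raise ValueError("second-minor dimension does not divide evenly")
    return (*shape[:-2], bits // to_bits, shape[-1])

assert bitcast_shape((8, 128), 32, 16) == (16, 128)  # e.g. int32 -> int16
assert bitcast_shape((8, 128), 16, 32) == (4, 128)   # e.g. int16 -> int32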
+ + jit_compilation = pxla._cached_compilation + count = [0] + + def compile_and_count(*args, **kwargs): + count[0] += 1 + return jit_compilation(*args, **kwargs) + + pxla._cached_compilation = compile_and_count + try: + yield count + finally: + pxla._cached_compilation = jit_compilation + + @contextmanager def count_subjaxpr_to_hlo_conversion(fun_name: str): # No need to clear any caches since we generally jit and pmap fresh callables @@ -405,7 +444,7 @@ def mlir_lower_and_count(ctx, name, *args, **kwargs): @contextmanager def assert_num_jit_and_pmap_compilations(times): - with count_jit_and_pmap_compiles() as count: + with count_jit_and_pmap_lowerings() as count: yield if count[0] != times: raise AssertionError(f"Expected exactly {times} XLA compilations, " @@ -487,6 +526,8 @@ def is_device_tpu(version: int | None = None, variant: str = "") -> bool: # Special case v5e until the name is updated in device_kind if expected_version == "v5e": return "v5 lite" in device_kind + elif expected_version == "v6e": + return "v6 lite" in device_kind return expected_version in device_kind def is_cuda_compute_capability_at_least(capability: str) -> bool: @@ -1167,7 +1208,7 @@ def assertArraysEqual(self, x, y, *, check_dtypes=True, err_msg='', y = np.asarray(y) if (not allow_object_dtype) and (x.dtype == object or y.dtype == object): - # See https://github.com/google/jax/issues/17867 + # See https://github.com/jax-ml/jax/issues/17867 raise TypeError( "assertArraysEqual may be poorly behaved when np.asarray casts to dtype=object. " "If comparing PRNG keys, consider random_test.KeyArrayTest.assertKeysEqual. " @@ -1356,15 +1397,16 @@ def with_and_without_mesh(f): ('Mesh', (('x', 2),), (('i', 'x'),)) ))(with_mesh_from_kwargs(f)) -def create_global_mesh(mesh_shape, axis_names): +def create_mesh(mesh_shape, axis_names, iota_order=False): size = math.prod(mesh_shape) if len(jax.devices()) < size: raise unittest.SkipTest(f"Test requires {size} global devices.") - devices = sorted(jax.devices(), key=lambda d: d.id) - mesh_devices = np.array(devices[:size]).reshape(mesh_shape) - global_mesh = jax.sharding.Mesh(mesh_devices, axis_names) - return global_mesh - + if iota_order: + devices = sorted(jax.devices(), key=lambda d: d.id) + mesh_devices = np.array(devices[:size]).reshape(mesh_shape) + return jax.sharding.Mesh(mesh_devices, axis_names) + else: + return jax.make_mesh(mesh_shape, axis_names) class _cached_property: null = object() @@ -1943,7 +1985,7 @@ def arcsin(self, x): # On branch cut, mpmath.mp.asin returns different value compared # to mpmath.fp.asin and numpy.arcsin (see # mpmath/mpmath#786). The following if-block ensures - # compatibiliy with numpy.arcsin. + # compatibility with numpy.arcsin. if x.real > 1 and x.imag == 0: return ctx.asin(x).conjugate() @@ -1975,7 +2017,7 @@ def arccos(self, x): return ctx.make_mpc((real._mpf_, (-sign_imag * inf)._mpf_)) # On branch cut, mpmath.mp.acos returns different value # compared to mpmath.fp.acos and numpy.arccos. The - # following if-block ensures compatibiliy with + # following if-block ensures compatibility with # numpy.arccos. if x.imag == 0 and x.real > 1: return -ctx.acos(x) @@ -2004,7 +2046,7 @@ def arcsinh(self, x): # On branch cut, mpmath.mp.asinh returns different value # compared to mpmath.fp.asinh and numpy.arcsinh (see # mpmath/mpmath#786). The following if-block ensures - # compatibiliy with numpy.arcsinh. + # compatibility with numpy.arcsinh. 
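The new test_util counters above (`count_jax_array_shard_arg_calls`, `count_jit_compilation_cache_miss`) follow the same temporary-monkeypatch pattern. A generic sketch of that pattern, with a hypothetical `count_calls` helper not present in the patch:

from contextlib import contextmanager

@contextmanager
def count_calls(obj, name):
    # Temporarily wrap `obj.<name>`, count invocations, restore on exit.
    original = getattr(obj, name)
    count = [0]
    def wrapper(*args, **kwargs):
        count[0] += 1
        return original(*args, **kwargs)
    setattr(obj, name, wrapper)
    try:
        yield count
    finally:
        setattr(obj, name, original)

Tests then use `with count_calls(module, "fn") as n: ...; assert n[0] == expected`, which is how the helpers above are consumed.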
if x.real == 0 and x.imag < -1: return (-ctx.asinh(x)).conjugate() return ctx.asinh(x) @@ -2032,6 +2074,51 @@ def arccosh(self, x): return ctx.make_mpc((inf._mpf_, imag._mpf_)) return ctx.acosh(x) + def arctan(self, x): + ctx = x.context + + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in atan(+-inf+-infj) evaluation + # (see mpmath/mpmath#775 with the fix). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + zero = ctx.zero + if ctx.isinf(x.real) or ctx.isinf(x.imag): + if x.real < 0: + return ctx.make_mpc(((-pi / 2)._mpf_, zero._mpf_)) + return ctx.make_mpc(((pi / 2)._mpf_, zero._mpf_)) + + # On branch cut, mpmath.mp.atan returns different value compared + # to mpmath.fp.atan and numpy.arctan (see mpmath/mpmath#865). + # The following if-block ensures compatibility with + # numpy.arctan. + if x.real == 0 and x.imag < -1: + return (-ctx.atan(x)).conjugate() + return ctx.atan(x) + + def arctanh(self, x): + ctx = x.context + + if isinstance(x, ctx.mpc): + # Workaround mpmath 1.3 bug in atanh(+-inf+-infj) evaluation + # (see mpmath/mpmath#775 with the fix). + # TODO(pearu): remove the if-block below when mpmath 1.4 or + # newer will be the required test dependency. + pi = ctx.pi + zero = ctx.zero + if ctx.isinf(x.real) or ctx.isinf(x.imag): + if x.imag < 0: + return ctx.make_mpc((zero._mpf_, (-pi / 2)._mpf_)) + return ctx.make_mpc((zero._mpf_, (pi / 2)._mpf_)) + + # On branch cut, mpmath.mp.atanh returns different value + # compared to mpmath.fp.atanh and numpy.arctanh. The following + # if-block ensures compatibility with numpy.arctanh. + if x.imag == 0 and x.real > 1: + return ctx.atanh(x).conjugate() + return ctx.atanh(x) + def normalize(self, exact, reference, value): """Normalize reference and value using precision defined by the difference of exact and reference. diff --git a/jax/_src/third_party/scipy/special.py b/jax/_src/third_party/scipy/special.py new file mode 100644 index 000000000000..67ef09f6de37 --- /dev/null +++ b/jax/_src/third_party/scipy/special.py @@ -0,0 +1,322 @@ +from __future__ import annotations + +import jax.numpy as jnp +from jax import jit + +from jax._src import custom_derivatives, dtypes +from jax._src.numpy.lax_numpy import complexfloating +from jax._src.numpy.util import promote_args_inexact +from jax._src.typing import Array, ArrayLike + + +@jit +def sincospisquaredhalf( + x: Array, +) -> tuple[Array, Array]: + """ + Accurate evaluation of sin(pi * x**2 / 2) and cos(pi * x**2 / 2). + + As based on the sinpi and cospi functions from SciPy, see: + - https://github.com/scipy/scipy/blob/v1.14.0/scipy/special/special/cephes/trig.h + """ + x = jnp.abs(x) + # define s = x % 2, y = x - s, then + # r = (x * x / 2) % 2 + # = [(y + s)*(y + s)/2] % 2 + # = [y*y/2 + s*y + s*s/2] % 2 + # = [(y*y/2)%2 + (s*y + s*s/2)%2]%2 + # = [0 + (s*(y+s/2))%2]%2 + # = [s*(x-s/2)]%2 + s = jnp.fmod(x, 2.0) + r = jnp.fmod(s * (x - s / 2), 2.0) + + sinpi = jnp.where( + r < 0.5, + jnp.sin(jnp.pi * r), + jnp.where( + r > 1.5, + jnp.sin(jnp.pi * (r - 2.0)), + -jnp.sin(jnp.pi * (r - 1.0)), + ), + ) + cospi = jnp.where( + r == 0.5, + 0.0, + jnp.where(r < 1.0, -jnp.sin(jnp.pi * (r - 0.5)), jnp.sin(jnp.pi * (r - 1.5))), + ) + + return sinpi, cospi + + +@custom_derivatives.custom_jvp +def fresnel(x: ArrayLike) -> tuple[Array, Array]: + r"""The Fresnel integrals + + JAX implementation of :obj:`scipy.special.fresnel`. + + The Fresnel integrals are defined as + .. 
math:: + S(x) &= \int_0^x \sin(\pi t^2 /2) dt \\ + C(x) &= \int_0^x \cos(\pi t^2 /2) dt. + + Args: + x: arraylike, real-valued. + + Returns: + Arrays containing the values of the Fresnel integrals. + + Notes: + The JAX version only supports real-valued inputs, and + is based on the SciPy C++ implementation, see + `here + `_. + For ``float32`` dtypes, the implementation is directly based + on the Cephes implementation ``fresnlf``. + + As for the original Cephes implementation, the accuracy + is only guaranteed in the domain [-10, 10]. Outside of + that domain, one could observe divergence between the + theoretical derivatives and the custom JVP implementation, + especially for large input values. + + Finally, for half-precision data types, ``float16`` + and ``bfloat16``, the array elements are upcasted to + ``float32`` as the Cephes coefficients used in + series expansions would otherwise lead to poor results. + Other data types, like ``float8``, are not supported. + """ + + xxa, = promote_args_inexact("fresnel", x) + original_dtype = xxa.dtype + + # This part is mostly a direct translation of SciPy's C++ code, + # and the original Cephes implementation for single precision. + + if dtypes.issubdtype(xxa.dtype, complexfloating): + raise NotImplementedError( + 'Support for complex-valued inputs is not implemented yet.') + elif xxa.dtype in (jnp.float32, jnp.float16, jnp.bfloat16): + # Single-precision Cephes coefficients + + # For half-precision, series expansions have either + # produce overflow or poor accuracy. + # Upcasting to single-precision is hence needed. + xxa = xxa.astype(jnp.float32) # No-op for float32 + + fresnl_sn = jnp.array([ + +1.647629463788700e-9, + -1.522754752581096e-7, + +8.424748808502400e-6, + -3.120693124703272e-4, + +7.244727626597022e-3, + -9.228055941124598e-2, + +5.235987735681432e-1, + ], dtype=jnp.float32) + + fresnl_cn = jnp.array([ + +1.416802502367354e-8, + -1.157231412229871e-6, + +5.387223446683264e-5, + -1.604381798862293e-3, + +2.818489036795073e-2, + -2.467398198317899e-1, + +9.999999760004487e-1, + ], dtype=jnp.float32) + + fresnl_fn = jnp.array([ + -1.903009855649792e12, + +1.355942388050252e11, + -4.158143148511033e9, + +7.343848463587323e7, + -8.732356681548485e5, + +8.560515466275470e3, + -1.032877601091159e2, + +2.999401847870011e0, + ], dtype=jnp.float32) + + fresnl_gn = jnp.array([ + -1.860843997624650e11, + +1.278350673393208e10, + -3.779387713202229e8, + +6.492611570598858e6, + -7.787789623358162e4, + +8.602931494734327e2, + -1.493439396592284e1, + +9.999841934744914e-1, + ], dtype=jnp.float32) + elif xxa.dtype == jnp.float64: + # Double-precision Cephes coefficients + + fresnl_sn = jnp.array([ + -2.99181919401019853726e3, + +7.08840045257738576863e5, + -6.29741486205862506537e7, + +2.54890880573376359104e9, + -4.42979518059697779103e10, + +3.18016297876567817986e11, + ], dtype=jnp.float64) + + fresnl_sd = jnp.array([ + +1.00000000000000000000e0, + +2.81376268889994315696e2, + +4.55847810806532581675e4, + +5.17343888770096400730e6, + +4.19320245898111231129e8, + +2.24411795645340920940e10, + +6.07366389490084639049e11, + ], dtype=jnp.float64) + + fresnl_cn = jnp.array([ + -4.98843114573573548651e-8, + +9.50428062829859605134e-6, + -6.45191435683965050962e-4, + +1.88843319396703850064e-2, + -2.05525900955013891793e-1, + +9.99999999999999998822e-1, + ], dtype=jnp.float64) + + fresnl_cd = jnp.array([ + +3.99982968972495980367e-12, + +9.15439215774657478799e-10, + +1.25001862479598821474e-7, + +1.22262789024179030997e-5, + 
+8.68029542941784300606e-4, + +4.12142090722199792936e-2, + +1.00000000000000000118e0, + ], dtype=jnp.float64) + + fresnl_fn = jnp.array([ + +4.21543555043677546506e-1, + +1.43407919780758885261e-1, + +1.15220955073585758835e-2, + +3.45017939782574027900e-4, + +4.63613749287867322088e-6, + +3.05568983790257605827e-8, + +1.02304514164907233465e-10, + +1.72010743268161828879e-13, + +1.34283276233062758925e-16, + +3.76329711269987889006e-20, + ], dtype=jnp.float64) + + fresnl_fd = jnp.array([ + +1.00000000000000000000e0, + +7.51586398353378947175e-1, + +1.16888925859191382142e-1, + +6.44051526508858611005e-3, + +1.55934409164153020873e-4, + +1.84627567348930545870e-6, + +1.12699224763999035261e-8, + +3.60140029589371370404e-11, + +5.88754533621578410010e-14, + +4.52001434074129701496e-17, + +1.25443237090011264384e-20, + ], dtype=jnp.float64) + + fresnl_gn = jnp.array([ + +5.04442073643383265887e-1, + +1.97102833525523411709e-1, + +1.87648584092575249293e-2, + +6.84079380915393090172e-4, + +1.15138826111884280931e-5, + +9.82852443688422223854e-8, + +4.45344415861750144738e-10, + +1.08268041139020870318e-12, + +1.37555460633261799868e-15, + +8.36354435630677421531e-19, + +1.86958710162783235106e-22, + ], dtype=jnp.float64) + + fresnl_gd = jnp.array([ + +1.00000000000000000000e0, + +1.47495759925128324529e0, + +3.37748989120019970451e-1, + +2.53603741420338795122e-2, + +8.14679107184306179049e-4, + +1.27545075667729118702e-5, + +1.04314589657571990585e-7, + +4.60680728146520428211e-10, + +1.10273215066240270757e-12, + +1.38796531259578871258e-15, + +8.39158816283118707363e-19, + +1.86958710162783236342e-22, + ], dtype=jnp.float64) + else: + raise NotImplementedError( + f'Support for {xxa.dtype} dtype is not implemented yet.') + + assert xxa.dtype in (jnp.float32, jnp.float64) + single_precision = (xxa.dtype == jnp.float32) + + x = jnp.abs(xxa) + + x2 = x * x + + # Infinite x values + s_inf = c_inf = 0.5 + + # Small x values + t = x2 * x2 + + if single_precision: + s_small = x * x2 * jnp.polyval(fresnl_sn, t) + c_small = x * jnp.polyval(fresnl_cn, t) + else: + s_small = x * x2 * jnp.polyval(fresnl_sn[:6], t) / jnp.polyval(fresnl_sd[:7], t) + c_small = x * jnp.polyval(fresnl_cn[:6], t) / jnp.polyval(fresnl_cd[:7], t) + + # Large x values + + sinpi, cospi = sincospisquaredhalf(x) + + if single_precision: + c_large = c_inf + s_large = s_inf + else: + c_large = 0.5 + 1 / (jnp.pi * x) * sinpi + s_large = 0.5 - 1 / (jnp.pi * x) * cospi + + # Other x values + t = jnp.pi * x2 + u = 1.0 / (t * t) + t = 1.0 / t + + if single_precision: + f = 1.0 - u * jnp.polyval(fresnl_fn, u) + g = t * jnp.polyval(fresnl_gn, u) + else: + f = 1.0 - u * jnp.polyval(fresnl_fn, u) / jnp.polyval(fresnl_fd, u) + g = t * jnp.polyval(fresnl_gn, u) / jnp.polyval(fresnl_gd, u) + + t = jnp.pi * x + c_other = 0.5 + (f * sinpi - g * cospi) / t + s_other = 0.5 - (f * cospi + g * sinpi) / t + + isinf = jnp.isinf(xxa) + small = x2 < 2.5625 + large = x > 36974.0 + s = jnp.where( + isinf, s_inf, jnp.where(small, s_small, jnp.where(large, s_large, s_other)) + ) + c = jnp.where( + isinf, c_inf, jnp.where(small, c_small, jnp.where(large, c_large, c_other)) + ) + + neg = xxa < 0.0 + s = jnp.where(neg, -s, s) + c = jnp.where(neg, -c, c) + + if original_dtype != xxa.dtype: + s = s.astype(original_dtype) + c = c.astype(original_dtype) + + return s, c + +def _fresnel_jvp(primals, tangents): + x, = primals + x_dot, = tangents + result = fresnel(x) + sinpi, cospi = sincospisquaredhalf(x) + dSdx = sinpi * x_dot + dCdx = cospi * x_dot + return 
result, (dSdx, dCdx) +fresnel.defjvp(_fresnel_jvp) diff --git a/jax/_src/tpu_custom_call.py b/jax/_src/tpu_custom_call.py index f77ed0666705..97b6a2cfd32a 100644 --- a/jax/_src/tpu_custom_call.py +++ b/jax/_src/tpu_custom_call.py @@ -21,6 +21,7 @@ import collections.abc from collections.abc import Callable, Sequence import dataclasses +import enum import functools import io import os @@ -67,6 +68,23 @@ tpu_custom_call_p.multiple_results = True +class MemorySpace(enum.Enum): + HBM = enum.auto() + VMEM = enum.auto() + SEMAPHORE_MEM = enum.auto() + + @property + def color(self) -> int: + if self == MemorySpace.HBM: + return 0 + elif self == MemorySpace.VMEM: + return 1 + elif self == MemorySpace.SEMAPHORE_MEM: + return 2 + else: + raise ValueError("invalid memory space: " + str(self)) + + @dataclasses.dataclass(frozen=True) class CostEstimate: flops: int @@ -95,6 +113,7 @@ class CustomCallBackendConfig: allow_input_fusion: list[bool] | None serialization_format: int | None internal_scratch_in_bytes: int | None + output_memory_spaces: tuple[MemorySpace | None, ...] | None # We omit the body while printing, because primitive params get embedded # in HLO metadata, and the body blows up its size. @@ -137,6 +156,14 @@ def to_json(self) -> bytes: if self.internal_scratch_in_bytes is not None: config.write(b', "internal_scratch_in_bytes": ') config.write(str(self.internal_scratch_in_bytes).encode("ascii")) + if self.output_memory_spaces is not None: + config.write(b', "output_memory_colors": [') + for i, memory_space in enumerate(self.output_memory_spaces): + if i: + config.write(b",") + color = memory_space.color if memory_space is not None else -1 + config.write(str(color).encode("ascii")) + config.write(b"]") config.write(b"}") # End of custom_call_config. if self.device_type is not None: config.write(b', "device_type": ') @@ -420,6 +447,7 @@ def _lower_to_custom_call_config( internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ) -> CustomCallBackendConfig: lowered_module_asm, ( has_communication, @@ -445,6 +473,7 @@ def _lower_to_custom_call_config( has_communication=has_communication, needs_hlo_passes=needs_hlo_passes, needs_layout_passes=needs_layout_passes, + output_memory_spaces=output_memory_spaces, ) @@ -463,6 +492,7 @@ def _lowered_to_custom_call_config( needs_hlo_passes: bool, needs_layout_passes: bool, device_type: str | None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None = None, ): if has_custom_barrier: if collective_id is None: @@ -492,6 +522,7 @@ def _lowered_to_custom_call_config( allow_input_fusion, serialization_format, internal_scratch_in_bytes, + output_memory_spaces, ) return config @@ -511,6 +542,7 @@ def lower_module_to_custom_call( internal_scratch_in_bytes: int | None, collective_id: int | None, serialization_format: int | None, + output_memory_spaces: tuple[MemorySpace | None, ...] | None, device_type: str | None, ) -> Sequence[ir.Value]: config = _lower_to_custom_call_config( @@ -524,6 +556,7 @@ def lower_module_to_custom_call( collective_id=collective_id, device_type=device_type, serialization_format=serialization_format, + output_memory_spaces=output_memory_spaces, ) return _tpu_custom_call_lowering( ctx, @@ -550,6 +583,7 @@ def as_tpu_kernel( internal_scratch_in_bytes: int | None = None, collective_id: int | None = None, serialization_format: int | None = 1, + output_memory_spaces: tuple[MemorySpace | None, ...] 
| None = None, ) -> Callable[..., Any]: """Turns an MLIR Mosaic kernel into a JAX-compatible function.""" config = _lower_to_custom_call_config( @@ -563,6 +597,7 @@ def as_tpu_kernel( internal_scratch_in_bytes=internal_scratch_in_bytes, collective_id=collective_id, serialization_format=serialization_format, + output_memory_spaces=output_memory_spaces, ) return _as_jax_callable( config, diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 2b69c80edad6..b1c18a48263f 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -15,6 +15,7 @@ import collections from collections.abc import Callable, Hashable, Iterable, Sequence +import dataclasses from dataclasses import dataclass import difflib import functools @@ -925,7 +926,10 @@ class that defines how it could be flattened with keys. @export def register_dataclass( - nodetype: Typ, data_fields: Sequence[str], meta_fields: Sequence[str] + nodetype: Typ, + data_fields: Sequence[str], + meta_fields: Sequence[str], + drop_fields: Sequence[str] = (), ) -> Typ: """Extends the set of types that are considered internal nodes in pytrees. @@ -1001,6 +1005,23 @@ def register_dataclass( meta_fields = tuple(meta_fields) data_fields = tuple(data_fields) + if dataclasses.is_dataclass(nodetype): + init_fields = {f.name for f in dataclasses.fields(nodetype) if f.init} + init_fields.difference_update(*drop_fields) + if {*meta_fields, *data_fields} != init_fields: + msg = ( + "data_fields and meta_fields must include all dataclass fields with" + " ``init=True`` and only them." + ) + if missing := init_fields - {*meta_fields, *data_fields}: + msg += ( + f" Missing fields: {missing}. Add them to drop_fields to suppress" + " this error." + ) + if unexpected := {*meta_fields, *data_fields} - init_fields: + msg += f" Unexpected fields: {unexpected}." + raise ValueError(msg) + def flatten_with_keys(x): meta = tuple(getattr(x, name) for name in meta_fields) data = tuple((GetAttrKey(name), getattr(x, name)) for name in data_fields) diff --git a/jax/_src/typing.py b/jax/_src/typing.py index 0caa6e7c643b..010841b45dd2 100644 --- a/jax/_src/typing.py +++ b/jax/_src/typing.py @@ -21,7 +21,7 @@ and may change without notice. To see the proposal that led to the development of these tools, see -https://github.com/google/jax/pull/11859/. +https://github.com/jax-ml/jax/pull/11859/. 
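The stricter `register_dataclass` check above requires `data_fields` plus `meta_fields` to cover exactly the dataclass fields with `init=True`. A minimal usage sketch (class and field names are illustrative):

import dataclasses
import jax
import jax.numpy as jnp

@dataclasses.dataclass
class Params:
    weights: jax.Array   # traced leaf
    bias: jax.Array      # traced leaf
    name: str            # static metadata, stored in the treedef

jax.tree_util.register_dataclass(
    Params, data_fields=["weights", "bias"], meta_fields=["name"])

p = Params(jnp.ones((2,)), jnp.zeros(()), name="layer0")
leaves, treedef = jax.tree_util.tree_flatten(p)  # leaves: weights, bias
p2 = jax.tree_util.tree_unflatten(treedef, leaves)

Leaving a field out of both lists (without naming it in `drop_fields`) now raises the `ValueError` added in this hunk instead of silently dropping it.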
""" from __future__ import annotations diff --git a/jax/_src/util.py b/jax/_src/util.py index 5174b21c2323..fce342c493ed 100644 --- a/jax/_src/util.py +++ b/jax/_src/util.py @@ -611,55 +611,40 @@ def wrapper(func: T) -> T: return wrapper -if TYPE_CHECKING: - def use_cpp_class(cpp_cls: Any) -> Callable[[T], T]: - def wrapper(cls: T) -> T: - return cls - return wrapper +def use_cpp_class(cpp_cls: type[Any]) -> Callable[[type[T]], type[T]]: + """A decorator replacing a Python class with its C++ version at runtime.""" - def use_cpp_method(is_enabled: bool = True) -> Callable[[T], T]: - def wrapper(cls: T) -> T: + def wrapper(cls): + if cpp_cls is None: return cls - return wrapper -else: - def use_cpp_class(cpp_cls): - """A helper decorator to replace a python class with its C++ version""" + exclude_methods = {'__module__', '__dict__', '__doc__'} - def wrapper(cls): - if cpp_cls is None: - return cls + originals = {} + for attr_name, attr in cls.__dict__.items(): + if attr_name not in exclude_methods: + if hasattr(_original_func(attr), "_use_cpp"): + originals[attr_name] = attr + else: + setattr(cpp_cls, attr_name, attr) - exclude_methods = {'__module__', '__dict__', '__doc__'} + cpp_cls.__doc__ = cls.__doc__ + # TODO(pschuh): Remove once fastpath is gone. + cpp_cls._original_py_fns = originals + return cpp_cls - originals = {} - for attr_name, attr in cls.__dict__.items(): - if attr_name not in exclude_methods: - if hasattr(_original_func(attr), "_use_cpp"): - originals[attr_name] = attr - else: - setattr(cpp_cls, attr_name, attr) - - cpp_cls.__doc__ = cls.__doc__ - # TODO(pschuh): Remove once fastpath is gone. - cpp_cls._original_py_fns = originals - return cpp_cls - - return wrapper + return wrapper - def use_cpp_method(is_enabled=True): - """A helper decorator to exclude methods from the set that are forwarded to C++ class""" - def decorator(f): - if is_enabled: - original_func = _original_func(f) - original_func._use_cpp = True - return f - - if not isinstance(is_enabled, bool): - raise TypeError( - "Decorator got wrong type: @use_cpp_method(is_enabled: bool=True)" - ) - return decorator +def use_cpp_method(is_enabled: bool = True) -> Callable[[T], T]: + """A decorator excluding methods from the set that are forwarded to C++ class.""" + if not isinstance(is_enabled, bool): + raise TypeError("``is_enabled`` must be a bool") + def decorator(f): + if is_enabled: + original_func = _original_func(f) + original_func._use_cpp = True + return f + return decorator try: diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 91d761fec5d4..796093b6225f 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -1187,15 +1187,30 @@ def host_count(backend: str | xla_client.Client | None = None) -> int: return process_count(backend) +def process_indices( + backend: str | xla_client.Client | None = None +) -> list[int]: + """Returns the list of all JAX process indices associated with the backend. + + Args: + backend: This is an experimental feature and the API is likely to change. + Optional, a string representing the xla backend: ``'cpu'``, ``'gpu'``, or + ``'tpu'``. + + Returns: + List of integer process indices. + """ + return list(range(process_count(backend))) + + # TODO: remove this sometime after jax 0.2.13 is released def host_ids( backend: str | xla_client.Client | None = None ) -> list[int]: warnings.warn( - "jax.host_ids has been deprecated; please use range(jax.process_count()) " - "instead. 
jax.host_ids will eventually be removed; please update your " - "code.") - return list(range(process_count(backend))) + "jax.host_ids has been renamed to jax.process_indices. This alias " + "will eventually be removed; please update your code.") + return process_indices(backend) def using_pjrt_c_api(backend=None): @@ -1217,7 +1232,7 @@ def make_pjrt_tpu_topology(topology_name='', **kwargs): if library_path is None: raise RuntimeError( "JAX TPU support not installed; cannot generate TPU topology. See" - " https://github.com/google/jax#installation") + " https://github.com/jax-ml/jax#installation") c_api = xla_client.load_pjrt_plugin_dynamically("tpu", library_path) xla_client.profiler.register_plugin_profiler(c_api) assert xla_client.pjrt_plugin_loaded("tpu") diff --git a/jax/_src/xla_metadata.py b/jax/_src/xla_metadata.py new file mode 100644 index 000000000000..94b482e2dea4 --- /dev/null +++ b/jax/_src/xla_metadata.py @@ -0,0 +1,55 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any +import threading +from contextlib import contextmanager + +from jax._src import config + + +class _XlaMetadata(threading.local): + val: dict[Any, Any] + + def __init__(self): + self.val = {} + +thread_local_metadata = _XlaMetadata() + +def current_xla_metadata(): + return thread_local_metadata.val + +@contextmanager +def set_xla_metadata(*args, **kwargs): + new_metadata = thread_local_metadata.val.copy() + if args: + new_metadata.update(args[0] if args[0] else {}) + else: + new_metadata.update(**kwargs) + prev_metadata, thread_local_metadata.val = ( + thread_local_metadata.val, + new_metadata, + ) + config.update_thread_local_jit_state( + xla_metadata_context_manager=tuple( + (v, k) for k, v in sorted(new_metadata.items()))) + try: + yield + finally: + thread_local_metadata.val = prev_metadata + config.update_thread_local_jit_state( + xla_metadata_context_manager=tuple( + (v, k) for k, v in sorted(prev_metadata.items()) + ) + ) diff --git a/jax/collect_profile.py b/jax/collect_profile.py index a7777085ce90..d1309e0c5bca 100644 --- a/jax/collect_profile.py +++ b/jax/collect_profile.py @@ -66,7 +66,7 @@ help="Profiler Python tracer level", type=int) def collect_profile(port: int, duration_in_ms: int, host: str, - log_dir: str | None, host_tracer_level: int, + log_dir: os.PathLike | str | None, host_tracer_level: int, device_tracer_level: int, python_tracer_level: int, no_perfetto_link: bool): options = profiler.ProfilerOptions( @@ -97,7 +97,7 @@ def collect_profile(port: int, duration_in_ms: int, host: str, fp.write(result.encode("utf-8")) if not no_perfetto_link: - path = jax_profiler._write_perfetto_trace_file(str(log_dir_)) + path = jax_profiler._write_perfetto_trace_file(log_dir_) jax_profiler._host_perfetto_trace_file(path) def main(args): diff --git a/jax/core.py b/jax/core.py index 80025e8619f3..90ef668b2493 100644 --- a/jax/core.py +++ b/jax/core.py @@ -13,7 +13,7 @@ # limitations under the License. 
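A minimal sketch of how the new thread-local metadata context manager in jax/_src/xla_metadata.py above is used (internal, experimental API; keys and values are illustrative):

from jax._src.xla_metadata import current_xla_metadata, set_xla_metadata

assert current_xla_metadata() == {}
with set_xla_metadata(stage="fwd", layer=3):
    # The metadata is thread-local and is also pushed into the thread-local
    # jit state so it is visible during lowering.
    assert current_xla_metadata() == {"stage": "fwd", "layer": 3}
assert current_xla_metadata() == {}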
# Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.core import ( AbstractToken as AbstractToken, @@ -29,6 +29,7 @@ Effect as Effect, Effects as Effects, EvalTrace as EvalTrace, + get_opaque_trace_state as get_opaque_trace_state, InDBIdx as InDBIdx, InconclusiveDimensionOperation as InconclusiveDimensionOperation, InputType as InputType, @@ -41,6 +42,8 @@ Literal as Literal, MainTrace as MainTrace, MapPrimitive as MapPrimitive, + nonempty_axis_env as nonempty_axis_env_DO_NOT_USE, + OpaqueTraceState as OpaqueTraceState, NameGatheringSubst as NameGatheringSubst, OutDBIdx as OutDBIdx, OutputType as OutputType, @@ -55,6 +58,9 @@ TraceStack as TraceStack, TraceState as TraceState, Tracer as Tracer, + unsafe_am_i_under_a_jit as unsafe_am_i_under_a_jit_DO_NOT_USE, + unsafe_am_i_under_a_vmap as unsafe_am_i_under_a_vmap_DO_NOT_USE, + unsafe_get_axis_names as unsafe_get_axis_names_DO_NOT_USE, UnshapedArray as UnshapedArray, Value as Value, Var as Var, @@ -66,10 +72,7 @@ call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, - check_eqn as check_eqn, check_jaxpr as check_jaxpr, - check_type as check_type, - check_valid_jaxtype as check_valid_jaxtype, closed_call_p as closed_call_p, concrete_aval as concrete_aval, concrete_or_error as concrete_or_error, @@ -88,6 +91,7 @@ full_lower as full_lower, gensym as gensym, get_aval as get_aval, + get_type as get_type, get_referent as get_referent, is_constant_dim as is_constant_dim, is_constant_shape as is_constant_shape, @@ -110,7 +114,6 @@ new_sublevel as new_sublevel, no_axis_name as no_axis_name, no_effects as no_effects, - non_negative_dim as _deprecated_non_negative_dim, outfeed_primitives as outfeed_primitives, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, @@ -144,19 +147,26 @@ from jax._src import core as _src_core _deprecations = { - # Added 2024-06-12 - "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), - "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), - "pp_eqn_rules": ("jax.core.pp_eqn_rules is deprecated.", _src_core.pp_eqn_rules), - "pp_eqns": ("jax.core.pp_eqns is deprecated.", _src_core.pp_eqns), - "pp_jaxpr": ("jax.core.pp_jaxpr is deprecated.", _src_core.pp_jaxpr), - "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range is deprecated.", _src_core.pp_jaxpr_eqn_range), - "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton is deprecated.", _src_core.pp_jaxpr_skeleton), - "pp_jaxprs": ("jax.core.pp_jaxprs is deprecated.", _src_core.pp_jaxprs), - "pp_kv_pair": ("jax.core.pp_kv_pair is deprecated.", _src_core.pp_kv_pair), - "pp_kv_pairs": ("jax.core.pp_kv_pairs is deprecated.", _src_core.pp_kv_pairs), - "pp_var": ("jax.core.pp_var is deprecated.", _src_core.pp_var), - "pp_vars": ("jax.core.pp_vars is deprecated.", _src_core.pp_vars), + # Added 2024-08-14 + "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn), + "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type), + "check_valid_jaxtype": ( + ("jax.core.check_valid_jaxtype is deprecated. 
Instead, you can manually" + " raise an error if core.valid_jaxtype() returns False."), + _src_core.check_valid_jaxtype), + # Finalized 2024-09-25; remove after 2024-12-25 + "pp_aval": ("jax.core.pp_aval was removed in JAX v0.4.34.", None), + "pp_eqn": ("jax.core.pp_eqn was removed in JAX v0.4.34.", None), + "pp_eqn_rules": ("jax.core.pp_eqn_rules was removed in JAX v0.4.34.", None), + "pp_eqns": ("jax.core.pp_eqns was removed in JAX v0.4.34.", None), + "pp_jaxpr": ("jax.core.pp_jaxpr was removed in JAX v0.4.34.", None), + "pp_jaxpr_eqn_range": ("jax.core.pp_jaxpr_eqn_range was removed in JAX v0.4.34.", None), + "pp_jaxpr_skeleton": ("jax.core.pp_jaxpr_skeleton was removed in JAX v0.4.34.", None), + "pp_jaxprs": ("jax.core.pp_jaxprs was removed in JAX v0.4.34.", None), + "pp_kv_pair": ("jax.core.pp_kv_pair was removed in JAX v0.4.34.", None), + "pp_kv_pairs": ("jax.core.pp_kv_pairs was removed in JAX v0.4.34.", None), + "pp_var": ("jax.core.pp_var was removed in JAX v0.4.34.", None), + "pp_vars": ("jax.core.pp_vars was removed in JAX v0.4.34.", None), # Finalized 2024-05-13; remove after 2024-08-13 "DimSize": ( "jax.core.DimSize is deprecated. Use DimSize = int | Any.", @@ -181,25 +191,16 @@ ), # Added Jan 8, 2024 "non_negative_dim": ( - "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _deprecated_non_negative_dim, + "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim, ), } import typing if typing.TYPE_CHECKING: - non_negative_dim = _deprecated_non_negative_dim - pp_aval = _src_core.pp_aval - pp_eqn = _src_core.pp_eqn - pp_eqn_rules = _src_core.pp_eqn_rules - pp_eqns = _src_core.pp_eqns - pp_jaxpr = _src_core.pp_jaxpr - pp_jaxpr_eqn_range = _src_core.pp_jaxpr_eqn_range - pp_jaxpr_skeleton = _src_core.pp_jaxpr_skeleton - pp_jaxprs = _src_core.pp_jaxprs - pp_kv_pair = _src_core.pp_kv_pair - pp_kv_pairs = _src_core.pp_kv_pairs - pp_var = _src_core.pp_var - pp_vars = _src_core.pp_vars + check_eqn = _src_core.check_eqn + check_type = _src_core.check_type + check_valid_jaxtype = _src_core.check_valid_jaxtype + non_negative_dim = _src_core.non_negative_dim else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/jax/custom_batching.py b/jax/custom_batching.py index a4850f04c2ec..9b8dc8f8709a 100644 --- a/jax/custom_batching.py +++ b/jax/custom_batching.py @@ -13,6 +13,6 @@ # limitations under the License. from jax._src.custom_batching import ( - custom_vmap, - sequential_vmap, + custom_vmap as custom_vmap, + sequential_vmap as sequential_vmap, ) diff --git a/jax/custom_derivatives.py b/jax/custom_derivatives.py index 96dc8898fd8e..ea1ef4f0274e 100644 --- a/jax/custom_derivatives.py +++ b/jax/custom_derivatives.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.custom_derivatives import ( _initial_style_jaxpr, @@ -34,5 +34,6 @@ ) from jax._src.ad_util import ( - SymbolicZero as SymbolicZero + SymbolicZero as SymbolicZero, + zero_from_primal as zero_from_primal ) diff --git a/jax/custom_transpose.py b/jax/custom_transpose.py index 311139da2567..314163c4684a 100644 --- a/jax/custom_transpose.py +++ b/jax/custom_transpose.py @@ -13,5 +13,5 @@ # limitations under the License. 
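The `_deprecations` table above feeds a module-level `__getattr__` installed via `deprecation_getattr`. A simplified sketch of that mechanism (the real helper lives in `jax._src.deprecations`; entries with `None` are finalized removals):

import warnings

def _deprecation_getattr(module_name, deprecations):
    def __getattr__(name):
        if name in deprecations:
            message, fn = deprecations[name]
            if fn is None:  # finalized: the attribute has been removed
                raise AttributeError(message)
            warnings.warn(message, DeprecationWarning, stacklevel=2)
            return fn
        raise AttributeError(f"module {module_name!r} has no attribute {name!r}")
    return __getattr__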
from jax._src.custom_transpose import ( - custom_transpose, + custom_transpose as custom_transpose, ) diff --git a/jax/distributed.py b/jax/distributed.py index 284ae6f95f48..cf39b81f423a 100644 --- a/jax/distributed.py +++ b/jax/distributed.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.distributed import (initialize, shutdown) +from jax._src.distributed import ( + initialize as initialize, + shutdown as shutdown, +) diff --git a/jax/dlpack.py b/jax/dlpack.py index 707e966ee243..a65496ec0cbf 100644 --- a/jax/dlpack.py +++ b/jax/dlpack.py @@ -12,4 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from jax._src.dlpack import (to_dlpack, from_dlpack, SUPPORTED_DTYPES) +from jax._src.dlpack import ( + to_dlpack as to_dlpack, + from_dlpack as from_dlpack, + SUPPORTED_DTYPES as SUPPORTED_DTYPES, +) diff --git a/jax/dtypes.py b/jax/dtypes.py index f2071fd4fe56..a6f1b764510b 100644 --- a/jax/dtypes.py +++ b/jax/dtypes.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.dtypes import ( bfloat16 as bfloat16, diff --git a/jax/errors.py b/jax/errors.py index 15a6654fa32d..6da7b717cb5f 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.errors import ( JAXTypeError as JAXTypeError, @@ -26,4 +26,9 @@ UnexpectedTracerError as UnexpectedTracerError, KeyReuseError as KeyReuseError, ) + +from jax._src.lib import xla_client as _xc +JaxRuntimeError = _xc.XlaRuntimeError +del _xc + from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/example_libraries/optimizers.py b/jax/example_libraries/optimizers.py index 71680ca61b96..3ad717ce358a 100644 --- a/jax/example_libraries/optimizers.py +++ b/jax/example_libraries/optimizers.py @@ -16,7 +16,7 @@ You likely do not mean to import this module! The optimizers in this library are intended as examples only. If you are looking for a fully featured optimizer -library, two good options are JAXopt_ and Optax_. +library, consider Optax_. This module contains some convenient optimizer definitions, specifically initialization and update functions, which can be used with ndarrays or @@ -85,8 +85,7 @@ def step(step, opt_state): value, opt_state = step(i, opt_state) -.. _JAXopt: https://github.com/google/jaxopt -.. _Optax: https://github.com/deepmind/optax +.. 
_Optax: https://github.com/google-deepmind/optax """ from __future__ import annotations @@ -98,11 +97,9 @@ def step(step, opt_state): import functools from functools import partial +import jax import jax.numpy as jnp from jax._src.util import safe_zip, safe_map, unzip2 -from jax import tree_util -from jax.tree_util import (tree_map, tree_flatten, tree_unflatten, - register_pytree_node) map = safe_map zip = safe_zip @@ -117,7 +114,7 @@ def step(step, opt_state): OptimizerState = namedtuple("OptimizerState", ["packed_state", "tree_def", "subtree_defs"]) -register_pytree_node( +jax.tree_util.register_pytree_node( OptimizerState, lambda xs: ((xs.packed_state,), (xs.tree_def, xs.subtree_defs)), lambda data, xs: OptimizerState(xs[0], data[0], data[1])) @@ -182,23 +179,23 @@ def tree_opt_maker(*args, **kwargs): @functools.wraps(init) def tree_init(x0_tree): - x0_flat, tree = tree_flatten(x0_tree) + x0_flat, tree = jax.tree.flatten(x0_tree) initial_states = [init(x0) for x0 in x0_flat] - states_flat, subtrees = unzip2(map(tree_flatten, initial_states)) + states_flat, subtrees = unzip2(map(jax.tree.flatten, initial_states)) return OptimizerState(states_flat, tree, subtrees) @functools.wraps(update) def tree_update(i, grad_tree, opt_state): states_flat, tree, subtrees = opt_state - grad_flat, tree2 = tree_flatten(grad_tree) + grad_flat, tree2 = jax.tree.flatten(grad_tree) if tree2 != tree: msg = ("optimizer update function was passed a gradient tree that did " "not match the parameter tree structure with which it was " "initialized: parameter tree {} and grad tree {}.") raise TypeError(msg.format(tree, tree2)) - states = map(tree_unflatten, subtrees, states_flat) + states = map(jax.tree.unflatten, subtrees, states_flat) new_states = map(partial(update, i), grad_flat, states) - new_states_flat, subtrees2 = unzip2(map(tree_flatten, new_states)) + new_states_flat, subtrees2 = unzip2(map(jax.tree.flatten, new_states)) for subtree, subtree2 in zip(subtrees, subtrees2): if subtree2 != subtree: msg = ("optimizer update function produced an output structure that " @@ -209,9 +206,9 @@ def tree_update(i, grad_tree, opt_state): @functools.wraps(get_params) def tree_get_params(opt_state): states_flat, tree, subtrees = opt_state - states = map(tree_unflatten, subtrees, states_flat) + states = map(jax.tree.unflatten, subtrees, states_flat) params = map(get_params, states) - return tree_unflatten(tree, params) + return jax.tree.unflatten(tree, params) return Optimizer(tree_init, tree_update, tree_get_params) return tree_opt_maker @@ -566,14 +563,14 @@ def make_schedule(scalar_or_schedule: float | Schedule) -> Schedule: def l2_norm(tree): """Compute the l2 norm of a pytree of arrays. Useful for weight decay.""" - leaves, _ = tree_flatten(tree) + leaves, _ = jax.tree.flatten(tree) return jnp.sqrt(sum(jnp.vdot(x, x) for x in leaves)) def clip_grads(grad_tree, max_norm): """Clip gradients stored as a pytree of arrays to maximum norm `max_norm`.""" norm = l2_norm(grad_tree) normalize = lambda g: jnp.where(norm < max_norm, g, g * (max_norm / norm)) - return tree_map(normalize, grad_tree) + return jax.tree.map(normalize, grad_tree) ### serialization utilities @@ -600,9 +597,9 @@ def unpack_optimizer_state(opt_state): A pytree with JoinPoint leaves that contain a second level of pytrees. 
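`clip_grads` above rescales the entire gradient pytree whenever its global l2 norm exceeds `max_norm`. A small numeric sketch:

import jax.numpy as jnp
from jax.example_libraries import optimizers

grads = {"w": jnp.full((4,), 3.0)}                  # global l2 norm = 6.0
clipped = optimizers.clip_grads(grads, max_norm=3.0)
assert jnp.allclose(optimizers.l2_norm(clipped), 3.0)  # rescaled to the cap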
""" states_flat, tree_def, subtree_defs = opt_state - subtrees = map(tree_unflatten, subtree_defs, states_flat) + subtrees = map(jax.tree.unflatten, subtree_defs, states_flat) sentinels = [JoinPoint(subtree) for subtree in subtrees] - return tree_util.tree_unflatten(tree_def, sentinels) + return jax.tree.unflatten(tree_def, sentinels) def pack_optimizer_state(marked_pytree): """Converts a marked pytree to an OptimizerState. @@ -617,8 +614,8 @@ def pack_optimizer_state(marked_pytree): Returns: An equivalent OptimizerState to the input argument. """ - sentinels, tree_def = tree_flatten(marked_pytree) + sentinels, tree_def = jax.tree.flatten(marked_pytree) assert all(isinstance(s, JoinPoint) for s in sentinels) subtrees = [s.subtree for s in sentinels] - states_flat, subtree_defs = unzip2(map(tree_flatten, subtrees)) + states_flat, subtree_defs = unzip2(map(jax.tree.flatten, subtrees)) return OptimizerState(states_flat, tree_def, subtree_defs) diff --git a/jax/experimental/__init__.py b/jax/experimental/__init__.py index caf27ec7a8ca..375d058d0edc 100644 --- a/jax/experimental/__init__.py +++ b/jax/experimental/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.experimental.x64_context import ( enable_x64 as enable_x64, @@ -22,6 +22,9 @@ from jax._src.callback import ( io_callback as io_callback ) +from jax._src.dtypes import ( + primal_tangent_dtype as primal_tangent_dtype, +) from jax._src.earray import ( EArray as EArray ) diff --git a/jax/experimental/array_api/__init__.py b/jax/experimental/array_api/__init__.py index 69b25d0b6ad9..3ac1d4246f6a 100644 --- a/jax/experimental/array_api/__init__.py +++ b/jax/experimental/array_api/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 import sys as _sys import warnings as _warnings diff --git a/jax/experimental/array_serialization/serialization.py b/jax/experimental/array_serialization/serialization.py index c7992dc629f1..2620f5cc760c 100644 --- a/jax/experimental/array_serialization/serialization.py +++ b/jax/experimental/array_serialization/serialization.py @@ -189,7 +189,7 @@ async def transfer_shard_to_host(shard: array.Shard) -> np.ndarray: data = shard.data has_pinned_host = any( m.kind == "pinned_host" for m in shard.device.addressable_memories()) - if config.enable_memories.value and has_pinned_host: + if has_pinned_host: # If available, transfer to pinned host memory sharding = jax.sharding.SingleDeviceSharding(shard.device, memory_kind="pinned_host") diff --git a/jax/experimental/array_serialization/serialization_test.py b/jax/experimental/array_serialization/serialization_test.py index 04a64fe55e25..61993637912f 100644 --- a/jax/experimental/array_serialization/serialization_test.py +++ b/jax/experimental/array_serialization/serialization_test.py @@ -52,7 +52,7 @@ def _on_commit_callback(self, temp_ckpt_dir, final_ckpt_dir): @jtu.skip_on_devices('cpu') def test_memory_consumption(self): - global_mesh = jtu.create_global_mesh((2, 4), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 4), ('x', 'y')) inp_shape = (2_048, 4_096) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -97,7 +97,7 @@ async def deserialize_with_byte_limit(): tm.stop() def test_memory_consumption_for_save(self): - global_mesh = jtu.create_global_mesh((1, 1), ('x', 'y')) + global_mesh = jtu.create_mesh((1, 1), ('x', 'y')) inp_shape = (16 * 1024, 16 * 1024) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -132,7 +132,7 @@ def test_memory_consumption_for_save(self): tm.stop() def test_checkpointing_with_path_variant(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -164,7 +164,7 @@ def test_checkpointing_with_path_variant(self): self.assertEqual(m1.dtype, np.int32) def test_checkpointing_jax_array(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -188,7 +188,7 @@ def test_checkpointing_jax_array(self): # Third Array def cb3(_): return np.array([], dtype=np.float32) - global_mesh1d = jtu.create_global_mesh((8,), ('x',)) + global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3) ckpt_path3 = pathlib.Path(self.create_tempdir(f'{ckpt_dir}/third').full_path) @@ -232,7 +232,7 @@ def cb3(_): self.assertEqual(m3.dtype, np.float32) def test_checkpointing_ocdbt_transaction(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) pspec = P('x', 'y') num = math.prod(inp_shape) @@ -262,7 +262,7 @@ def test_checkpointing_ocdbt_transaction(self): def cb3(_): return np.array([], dtype=np.float32) - global_mesh1d = jtu.create_global_mesh((8,), ('x',)) + global_mesh1d = jtu.create_mesh((8,), ('x',)) a3 = array.make_array_from_callback( (0,), NamedSharding(global_mesh1d, P(None)), cb3 ) @@ -327,7 +327,7 @@ def cb3(_): @parameterized.product(input_dtype=[np.int32, jnp.bfloat16]) def 
test_checkpointing_with_bigger_shape_jax_array(self, input_dtype): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -349,7 +349,8 @@ def cb1(index): on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y')) + ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), + P('x', 'y')) m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], [np.float32]) @@ -375,7 +376,7 @@ def cb1(index): @parameterized.product(input_dtype=[jnp.int4, jnp.int8]) def test_checkpointing_with_int4(self, input_dtype): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) num = math.prod(global_input_shape) @@ -397,7 +398,8 @@ def cb(index): on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - ds = NamedSharding(jtu.create_global_mesh((4, 2), ('x', 'y')), P('x', 'y')) + ds = NamedSharding(jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True), + P('x', 'y')) target_dtype = jnp.dtype('int4') m1, = serialization.run_deserialization([ds], tspecs, [(12, 2)], @@ -424,7 +426,7 @@ def cb(index): self.assertArraysEqual(l.data, global_input_data.astype(target_dtype)) def test_checkpointing_scalar_jax_array(self): - global_mesh = jtu.create_global_mesh((2,), ('x')) + global_mesh = jtu.create_mesh((2,), ('x')) global_input_shape = () data = np.array(4) s = NamedSharding(global_mesh, P(None)) @@ -441,7 +443,7 @@ def test_checkpointing_scalar_jax_array(self): on_commit_callback=partial(self._on_commit_callback, ckpt_dir, ckpt_dir)) manager.wait_until_finished() - ds = NamedSharding(jtu.create_global_mesh((2,), ('x')), P(None)) + ds = NamedSharding(jtu.create_mesh((2,), ('x')), P(None)) m1, = serialization.run_deserialization( [ds], @@ -454,7 +456,7 @@ def test_checkpointing_scalar_jax_array(self): self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32)) def test_deserialize_tensorstore_array_jax_array(self): - global_mesh = jtu.create_global_mesh((2,), ('x')) + global_mesh = jtu.create_mesh((2,), ('x')) data = np.arange(1024) tspec = ts.array(data).spec() m1, = serialization.run_deserialization( @@ -550,13 +552,13 @@ def test_load_with_layout(self): if not jtu.test_device_matches(['tpu']): self.skipTest('Layouts are only supported on TPUs') - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(32).reshape(8, 4) s = NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) out_layout = jax.jit(lambda x: x.T, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts() + arr).compile().output_layouts self.assertEqual(arr.layout.device_local_layout.major_to_minor, out_layout.device_local_layout.major_to_minor[::-1]) @@ -579,6 +581,8 @@ def test_load_with_layout(self): self.assertArraysEqual(s.data, np_inp[s.index]) def test_deserialization_with_int4(self): + if jtu.test_device_matches(['gpu']): + self.skipTest("Fails on GPU. 
Enable after it's fixed") dtype = jnp.int4 shape = (8, 2) arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype) @@ -611,7 +615,6 @@ def test_deserialization_with_int4(self): self.assertArraysEqual(out + out, out * 2) -@jtu.with_config(jax_enable_memories=True) class TransferShardTest(jtu.JaxTestCase): @jtu.skip_on_devices('cpu') diff --git a/jax/experimental/attrs.py b/jax/experimental/attrs.py index 8176465c1470..62da0f231d50 100644 --- a/jax/experimental/attrs.py +++ b/jax/experimental/attrs.py @@ -169,7 +169,7 @@ def linearize(f, *primals, attrs: list[tuple[Any, str]] = []): def _linearize(traceable: lu.WrappedFun, *primals): jvpfun, attrs = _split_attrs(_jvp(traceable)) in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) - + tuple(pe.PartialVal.unknown(core.get_aval(p).at_least_vspace()) + + tuple(pe.PartialVal.unknown(core.get_aval(p).to_tangent_aval()) for p in primals)) _, in_tree = tree_flatten((primals, primals)) jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree) @@ -211,7 +211,7 @@ def vjp(f, *primals, attrs: list[tuple[Any, str]] = []): f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree) primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize( f_, *attr_primals, *primals_flat) - attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).at_least_vspace() + attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).to_tangent_aval() for o, a in attrs_out] f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()), attrs, attrs_out) diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 0b6b51f71a3c..8e11d4173afe 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.checkify import ( Error as Error, diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 3c7bfac40061..6da3ad7c5d4b 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.custom_partitioning import ( custom_partitioning as custom_partitioning, diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 43e9813d7fac..49162809a325 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -17,7 +17,7 @@ The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. This module introduces the host callback functions :func:`call`, :func:`id_tap`, and :func:`id_print`, that send their arguments from the device @@ -363,11 +363,11 @@ def power3_with_cotangents(x): This is relatively easy to do, once one understands both the JAX custom VJP and the TensorFlow autodiff mechanisms. The code for how this can be done is shown in the ``call_tf_full_ad`` -function in `host_callback_to_tf_test.py `_. +function in `host_callback_to_tf_test.py `_. This example supports arbitrary higher-order differentiation as well. 
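The host_callback hunks above repeatedly point to the new external-callback mechanism (https://github.com/jax-ml/jax/issues/20385) as the replacement for these deprecated APIs. For orientation only, a minimal sketch of an `id_print`-style tap built on `jax.experimental.io_callback`; the `host_print` helper and the shapes are illustrative, not part of this patch:

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def host_print(x):
  # Runs on the host, like the tap function of the deprecated id_print.
  print("tap:", x)
  return x

@jax.jit
def f(x):
  y = jnp.sin(x)
  # ordered=True keeps taps in program order, similar to the ordering that
  # host_callback users typically relied on.
  y = io_callback(host_print, jax.ShapeDtypeStruct(y.shape, y.dtype), y,
                  ordered=True)
  return y * 2.0

f(jnp.arange(3.0))
```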
Note that if you just want to call TensorFlow functions from JAX, you can also -use the `jax2tf.call_tf function `_. +use the `jax2tf.call_tf function `_. Using :func:`call` to call a JAX function on another device, with reverse-mode autodiff support ------------------------------------------------------------------------------------------------ @@ -378,7 +378,7 @@ def power3_with_cotangents(x): computation will run, and then the results are sent back to the original accelerator. The code for how this can be done is shown in the ``call_jax_other_device function`` -in `host_callback_test.py `_. +in `host_callback_test.py `_. Low-level details and debugging ------------------------------- @@ -536,6 +536,8 @@ def power3_with_cotangents(x): from jax._src import xla_bridge as xb from jax._src.lib import xla_client from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version +from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo import numpy as np @@ -570,7 +572,7 @@ def power3_with_cotangents(x): help=( 'Use old implementation of host_callback, documented in the module docstring.' 'If False, use the jax.experimental.io_callback implementation. ' - 'See https://github.com/google/jax/issues/20385.' + 'See https://github.com/jax-ml/jax/issues/20385.' ) ) @@ -590,7 +592,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): "See https://jax.readthedocs.io/en/latest/debugging/index.html and " "https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html" " for alternatives. Please file a feature request at " - "https://github.com/google/jax/issues if none of the alternatives are " + "https://github.com/jax-ml/jax/issues if none of the alternatives are " "sufficient.") @@ -606,7 +608,7 @@ def _raise_if_using_outfeed_with_pjrt_c_api(backend: xb.XlaBackend): class CallbackFlavor(enum.Enum): """Specifies which flavor of callback to use under JAX_HOST_CALLBACK_LEGACY=False. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. """ IO_CALLBACK = 1 # uses jax.experimental.io_callback PURE = 2 # uses jax.pure_callback @@ -627,7 +629,7 @@ def _deprecated_id_tap(tap_func, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. ``id_tap`` behaves semantically like the identity function but has the side-effect that a user-defined Python function is called with the runtime @@ -653,7 +655,7 @@ def _deprecated_id_tap(tap_func, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Returns: ``arg``, or ``result`` if given. @@ -710,7 +712,7 @@ def _deprecated_id_print(arg, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. On each invocation of the printing tap, the ``kwargs`` if present will be printed first (sorted by keys). Then arg will be printed, @@ -728,7 +730,7 @@ def _deprecated_id_print(arg, * ``threshold`` is passed to ``numpy.array2string``. 
* ``callback_flavor``: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. For more details see the :mod:`jax.experimental.host_callback` module documentation. """ @@ -755,7 +757,7 @@ def _deprecated_call(callback_func: Callable, arg, *, The host_callback APIs are deprecated as of March 20, 2024. The functionality is subsumed by the `new JAX external callbacks `_ - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Args: callback_func: The Python function to invoke on the host as @@ -785,7 +787,7 @@ def _deprecated_call(callback_func: Callable, arg, *, i.e., does not work on CPU unless --jax_host_callback_outfeed=True. callback_flavor: if running with `JAX_HOST_CALLBACK_LEGACY=False` specifies the flavor of callback to use. - See https://github.com/google/jax/issues/20385. + See https://github.com/jax-ml/jax/issues/20385. Returns: the result of the ``callback_func`` invocation. @@ -798,7 +800,7 @@ def _deprecated_call(callback_func: Callable, arg, *, raise NotImplementedError( "When using JAX_HOST_CALLBACK_LEGACY=False you can use the `DEBUG` " "flavor of callback only when the `result_shape` is None. " - "See https://github.com/google/jax/issues/20385." + "See https://github.com/jax-ml/jax/issues/20385." ) return _call(callback_func, arg, result_shape=result_shape, call_with_device=call_with_device, identity=False, @@ -817,7 +819,7 @@ def __init__(self, callback_func, identity, call_with_device): raise NotImplementedError( "When using JAX_HOST_CALLBACK_LEGACY=False, the host_callback APIs" " do not support `tap_with_device` and `call_with_device`. " - "See https://github.com/google/jax/issues/20385.") + "See https://github.com/jax-ml/jax/issues/20385.") def __hash__(self): return hash((self.callback_func, self.identity, self.call_with_device)) @@ -1085,7 +1087,6 @@ def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): finally: builder.clear_sharding() - def _outside_call_translation_rule(ctx, avals_in, avals_out, @@ -1185,8 +1186,123 @@ def _outside_call_translation_rule(ctx, f"identity = {identity}") return results + [next_token, next_itoken] +if xla_extension_version < 287: + xla.register_translation(outside_call_p, _outside_call_translation_rule) + + +def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext, + *args_op, + identity, + device_index, + flat_results_aval=(), + **params): + # We expect the current tokens at the end, inserted by _rewrite_jaxpr. + current_token = args_op[-2] + current_itoken = args_op[-1] + + args_to_outfeed = args_op[:-2] + # Some platforms refuse to infeed empty arrays. We generate constants + # instead. 
+ non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), + flat_results_aval)) + need_callback_results_on_device = (not identity and + len(non_empty_flat_results_aval) > 0) + send_infeed = need_callback_results_on_device + generated_infeed = False # Keep track if we emitted an infeed op + for platform in ctx.module_context.platforms: + _raise_if_using_outfeed_with_pjrt_c_api( + xb.get_backend(platform) + ) + callback_id = _register_callback( + functools.partial( + _outside_call_run_callback, + send_infeed=send_infeed, + identity=identity, + flat_results_aval=flat_results_aval, + **params)) -xla.register_translation(outside_call_p, _outside_call_translation_rule) + outfeed_sharding = xla_client.OpSharding() + outfeed_sharding.type = xla_client.OpSharding.Type.MAXIMAL + outfeed_sharding.tile_assignment_dimensions = [1] + outfeed_sharding.tile_assignment_devices = [device_index] + + # next_token = _callback_handler_data.receiver.add_outfeed( + # comp, current_token, callback_id, args_to_outfeed, device_index) + + xla_shapes = util.flatten( + xla.aval_to_xla_shapes(aval) for aval in ctx.avals_in[:-2]) + _callback_handler_data.receiver.register_outfeed(callback_id, xla_shapes) + outfeed_header_start = 271828 # Must match kOutfeedHeaderStart in C++ + header = mlir.ir_constant(np.array([outfeed_header_start, callback_id], + dtype=np.uint32)) + header_outfeed = hlo.OutfeedOp([header], current_token, + outfeed_config=ir.StringAttr.get('')) + mlir.set_sharding(header_outfeed, outfeed_sharding) + next_token, = header_outfeed.results + data_outfeed = hlo.OutfeedOp(args_to_outfeed, next_token, + outfeed_config=ir.StringAttr.get('')) + mlir.set_sharding(data_outfeed, outfeed_sharding) + next_token, = data_outfeed.results + + + if identity: + results = list(args_to_outfeed) + next_itoken = current_itoken + else: + empty_results = [ + mlir.ir_constant(np.zeros(aval.shape, aval.dtype)) + for aval in flat_results_aval + if _aval_is_empty(aval) + ] + if non_empty_flat_results_aval: + assert need_callback_results_on_device + after_outfeed_itoken = hlo.AfterAllOp([current_itoken, next_token]) + # We shard the infeed as AssignedDevice(device_index). This must match the + # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support + # this kind of sharding, we use a custom translation for infeed. 
+ array_sharding_proto = xla_client.OpSharding() + array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL + array_sharding_proto.tile_assignment_dimensions = [1] + array_sharding_proto.tile_assignment_devices = [device_index] + + token_sharding_proto = xla_client.OpSharding() + token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED + infeed_sharding_proto = xla.tuple_sharding_proto( + [array_sharding_proto] * len(non_empty_flat_results_aval) + + [token_sharding_proto]) + + output_types = map(mlir.aval_to_ir_types, non_empty_flat_results_aval) + flat_output_types = util.flatten(output_types) + + layouts = ir.ArrayAttr.get([ + ir.ArrayAttr.get( + [mlir.i64_attr(i) + for i in range(len(aval.shape) - 1, -1, -1)]) + for aval in non_empty_flat_results_aval + ]) + infeed = hlo.InfeedOp(flat_output_types + [hlo.TokenType.get()], + after_outfeed_itoken, + infeed_config=ir.StringAttr.get(''), + layout=layouts) + mlir.set_sharding(infeed, infeed_sharding_proto) + non_empty_results = list(infeed.results[:-1]) + next_itoken = infeed.results[-1] + generated_infeed = True + results = [ + empty_results.pop(0) + if _aval_is_empty(result_aval) else non_empty_results.pop(0) + for result_aval in flat_results_aval + ] + else: + results = empty_results + next_itoken = current_itoken + + assert generated_infeed == send_infeed, ( + f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})") + assert identity or len(results) == len(flat_results_aval), ( + f"got {len(results)} but expected {len(flat_results_aval)}. " + f"identity = {identity}") + return results + [next_token, next_itoken] def _outside_call_lowering(ctx: mlir.LoweringRuleContext, @@ -1202,23 +1318,32 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, platform = ctx.module_context.platforms[0] use_outfeed = _use_outfeed(platform) if use_outfeed: - # Fall back to XLA path if we are using the outfeed - # TODO(sharadmv): update to use MLIR for this path as well and delete - # XLA lowering - return mlir.xla_fallback_lowering(outside_call_p)( - ctx, - *args, - has_token=has_token, - identity=identity, - flat_results_aval=flat_results_aval, - device_index=device_index, - **params) + if xla_extension_version < 287: + return mlir.xla_fallback_lowering(outside_call_p)( + ctx, + *args, + has_token=has_token, + identity=identity, + device_index=device_index, + flat_results_aval=flat_results_aval, + **params, + ) + else: + return _outside_call_outfeed_lowering( + ctx, *args, + has_token=has_token, + identity=identity, + flat_results_aval=flat_results_aval, + device_index=device_index, + **params, + ) else: # TODO(necula): It seems that on CPU, with custom call, the device_index # does not work, and the callback is always run on device_index=0 if (device_index != 0 and "cpu" in ctx.module_context.platforms): raise ValueError( "The device_index feature on CPU works only when using outfeed.") + # We expect the current tokens at the end, inserted by _rewrite_jaxpr. 
assert has_token current_token = args[-2] @@ -1280,7 +1405,10 @@ def wrapped_callback(*args): f"identity = {identity}") return list(results) + [next_token, next_itoken] -mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") +if xla_extension_version < 287: + mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") +else: + mlir.register_lowering(outside_call_p, _outside_call_lowering) def _outside_call_run_callback( arrays, device, *, @@ -1766,7 +1894,7 @@ def _rewrite_while_outfeed_cond(eqn: core.JaxprEqn, eqns: list[core.JaxprEqn], id_p.multiple_results = True id_p.def_impl(lambda *args: args) id_p.def_abstract_eval(lambda *args: args) -xla.register_translation(id_p, lambda ctx, avals_in, avals_out, *args: args) +mlir.register_lowering(id_p, lambda ctx, *args: args) dispatch.outfeed_rewriter = lambda j: _rewrite_jaxpr(j, False, False) @@ -1993,7 +2121,7 @@ def _deprecated_stop_outfeed_receiver(): _deprecation_msg = ( "The host_callback APIs are deprecated as of March 20, 2024. The functionality " "is subsumed by the new JAX external callbacks. " - "See https://github.com/google/jax/issues/20385.") + "See https://github.com/jax-ml/jax/issues/20385.") _deprecations = { # Added March 20, 2024 diff --git a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb index 4f23d88e036e..3613dba0ef06 100644 --- a/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb +++ b/jax/experimental/jax2tf/JAX2TF_getting_started.ipynb @@ -26,7 +26,7 @@ "Link: go/jax2tf-colab\n", "\n", "The JAX2TF colab has been deprecated, and the example code has\n", - "been moved to [jax2tf/examples](https://github.com/google/jax/tree/main/jax/experimental/jax2tf/examples). \n" + "been moved to [jax2tf/examples](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf/examples). \n" ] } ] diff --git a/jax/experimental/jax2tf/README.md b/jax/experimental/jax2tf/README.md index b190829fe7d0..b77474c03728 100644 --- a/jax/experimental/jax2tf/README.md +++ b/jax/experimental/jax2tf/README.md @@ -103,10 +103,10 @@ For more involved examples, please see examples involving: * SavedModel for archival ([examples below](#usage-saved-model)), including saving [batch-polymorphic functions](#shape-polymorphic-conversion), - * TensorFlow Lite ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)), - * TensorFlow.js ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)), + * TensorFlow Lite ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md)), + * TensorFlow.js ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md)), * TFX ([examples](https://github.com/tensorflow/tfx/blob/master/tfx/examples/penguin/README.md#instructions-for-using-flax)), - * TensorFlow Hub and Keras ([examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md)). + * TensorFlow Hub and Keras ([examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md)). [TOC] @@ -249,7 +249,7 @@ graph (they will be saved in a `variables` area of the model, which is not subject to the 2GB limitation). For examples of how to save a Flax model as a SavedModel see the -[examples directory](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md). 
+[examples directory](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md). ### Saved model and differentiation @@ -619,7 +619,7 @@ Cannot solve for values of dimension variables {'a', 'b'}. " We can only solve linear uni-variate constraints. " Using the following polymorphic shapes specifications: args[0].shape = (a + b,). Unprocessed specifications: 'a + b' for dimension size args[0].shape[0]. " -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details. +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#dimension-variables-must-be-solvable-from-the-input-shapes for more details. ``` ### Shape assertion errors @@ -645,7 +645,7 @@ Input shapes do not match the polymorphic shapes specification. Division had remainder 1 when computing the value of 'd'. Using the following polymorphic shapes specifications: args[0].shape = (b, b, 2*d). Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details. +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-assertion-errors for more details. ``` When using native serialization these are checked by the `tf.XlaCallModule` @@ -869,7 +869,7 @@ leads to errors for the following expressions `b == a or b == b` or `b in [a, b] even though the error is avoided if we change the order of the comparisons. We attempted to retain soundness and hashability by creating both hashable and unhashable -kinds of symbolic dimensions [PR #14200](https://github.com/google/jax/pull/14200), +kinds of symbolic dimensions [PR #14200](https://github.com/jax-ml/jax/pull/14200), but it turned out to be very hard to diagnose hashing failures in user programs because often hashing is implicit when using sets or memo tables. @@ -880,7 +880,7 @@ is unsound. ### Division of symbolic dimensions is partially supported JAX will attempt to simplify division and modulo operations, -e.g., `(a * b + a) // (b + 1) == a` and `6*a + 4 % 3 == 1`. +e.g., `(a * b + a) // (b + 1) == a` and `(6 * a + 4) % 3 == 1`. In particular, JAX will handle the cases when either (a) there is no remainder, or (b) the divisor is a constant in which case there may be a constant remainder. @@ -989,7 +989,7 @@ We list here a history of the serialization version numbers: June 13th, 2023 (JAX 0.4.13). * Version 7 adds support for `stablehlo.shape_assertion` operations and for `shape_assertions` specified in `disabled_checks`. - See [Errors in presence of shape polymorphism](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule + See [Errors in presence of shape polymorphism](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#errors-in-presence-of-shape-polymorphism). Supported by XlaCallModule since July 12th, 2023 (cl/547482522), available in JAX serialization since July 20th, 2023 (JAX 0.4.14), and the default since August 12th, 2023 (JAX 0.4.15). @@ -1164,7 +1164,7 @@ self.assertAllClose(grad_jax.b, grad_tf[1]) Applies to both native and non-native serialization. 
When JAX differentiates functions with integer or boolean arguments, the gradients will -be zero-vectors with a special `float0` type (see PR 4039](https://github.com/google/jax/pull/4039)). +be zero-vectors with a special `float0` type (see PR 4039](https://github.com/jax-ml/jax/pull/4039)). This type is translated to `int32` when lowering to TF. For example, @@ -1441,7 +1441,7 @@ Operations like ``jax.numpy.cumsum`` are lowered by JAX differently based on the platform. For TPU, the lowering uses the [HLO ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow) operation, which has an efficient implementation for the cases when the reduction function is associative. For CPU and GPU, JAX uses an alternative -lowering using [associative scans](https://github.com/google/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). +lowering using [associative scans](https://github.com/jax-ml/jax/blob/f08bb50bfa9f6cf2de1f3f78f76e1aee4a78735d/jax/_src/lax/control_flow.py#L2801). jax2tf uses the TPU lowering (because it does not support backend-specific lowering) and hence it can be slow in some cases on CPU and GPU. @@ -1502,7 +1502,7 @@ before conversion. (This is a hypothesis, we have not yet verified it extensivel There is one know case when the performance of the lowered code will be different. JAX programs use a [stateless -deterministic PRNG](https://github.com/google/jax/blob/main/docs/design_notes/prng.md) +deterministic PRNG](https://github.com/jax-ml/jax/blob/main/docs/design_notes/prng.md) and it has an internal JAX primitive for it. This primitive is at the moment lowered to a soup of tf.bitwise operations, which has a clear performance penalty. We plan to look into using the @@ -1589,7 +1589,7 @@ Applies to non-native serialization only. There are a number of cases when the TensorFlow ops that are used by the `jax2tf` are not supported by TensorFlow for the same data types as in JAX. There is an -[up-to-date list of unimplemented cases](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). +[up-to-date list of unimplemented cases](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). If you try to lower and run in TensorFlow a program with partially supported primitives, you may see TensorFlow errors that @@ -1626,7 +1626,7 @@ the function to a SavedModel, knowing that upon restore the jax2tf-lowered code will be compiled. For a more elaborate example, see the test `test_tf_mix_jax_with_uncompilable` -in [savedmodel_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). +in [savedmodel_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/savedmodel_test.py). # Calling TensorFlow functions from JAX @@ -1704,7 +1704,7 @@ For a more elaborate example, including round-tripping from JAX to TensorFlow and back through a SavedModel, with support for custom gradients, see the test `test_round_trip_custom_grad_saved_model` -in [call_tf_test.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py). +in [call_tf_test.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/call_tf_test.py). All the metadata inserted by TF during tracing and compilation, e.g., source location information and op names, is carried through to the @@ -1901,7 +1901,7 @@ As of today, the tests are run using `tf_nightly==2.14.0.dev20230720`. 
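As a concrete illustration of the `call_tf` mechanism described in the "Calling TensorFlow functions from JAX" material earlier in this hunk, a minimal sketch assuming TensorFlow is installed; the `cos_tf` function is a stand-in, not taken from this patch:

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def cos_tf(x):
  return tf.math.cos(x)

# Wrap the TF callable so it composes with JAX transformations such as jit.
cos_tf_in_jax = jax2tf.call_tf(cos_tf)

@jax.jit
def f(x):
  return jnp.sin(cos_tf_in_jax(x))

print(f(jnp.float32(1.0)))
```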
To run jax2tf on GPU, both jaxlib and TensorFlow must be installed with support for CUDA. One must be mindful to install a version of CUDA that is compatible -with both [jaxlib](https://github.com/google/jax/blob/main/README.md#pip-installation) and +with both [jaxlib](https://github.com/jax-ml/jax/blob/main/README.md#pip-installation) and [TensorFlow](https://www.tensorflow.org/install/source#tested_build_configurations). ## Updating the limitations documentation @@ -1913,9 +1913,9 @@ JAX primitive, data type, device type, and TensorFlow execution mode (`eager`, `graph`, or `compiled`). These limitations are also used to generate tables of limitations, e.g., - * [List of primitives not supported in JAX](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md), + * [List of primitives not supported in JAX](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md), e.g., due to unimplemented cases in the XLA compiler, and - * [List of primitives not supported in jax2tf](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md), + * [List of primitives not supported in jax2tf](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md), e.g., due to unimplemented cases in TensorFlow. This list is incremental on top of the unsupported JAX primitives. diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 6cb1ec7e4cb2..baae52403053 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -19,7 +19,7 @@ TensorFlow functions. For examples and details, see -https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. +https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#calling-tensorflow-functions-from-jax. """ @@ -93,7 +93,7 @@ def call_tf( For an example and more details see the `README - `_. + `_. Args: callable_tf: a TensorFlow Callable that can take a pytree of TensorFlow @@ -224,9 +224,11 @@ def make_call_vjp_bwd(residual_jax, ct_res_jax): def tf_vjp_fun(args_tf, ct_res_tf): """Invoke TF gradient.""" - # TF does not like us to watch non-float vars - def replace_non_float(arg_tf): - if arg_tf.dtype.is_floating or arg_tf.dtype.is_complex: + # TF does not like us to watch non-float vars or Nones. + def replace_non_float_or_none(arg_tf): + if arg_tf is not None and ( + arg_tf.dtype.is_floating or arg_tf.dtype.is_complex + ): return arg_tf else: # When watched, this will be ignored. 
When used in results it will @@ -234,29 +236,38 @@ def replace_non_float(arg_tf): # replace it with a float0) return tf.zeros((), dtype=tf.float32) - watched_args_tf = tf.nest.map_structure(replace_non_float, args_tf) + watched_args_tf = tf.nest.map_structure( + replace_non_float_or_none, args_tf + ) with tf.GradientTape(persistent=True) as tape: tape.watch(watched_args_tf) res = callable_tf(*args_tf) tf.nest.assert_same_structure(res, ct_res_tf) dres_darg = tape.gradient( - tf.nest.map_structure(replace_non_float, res), + tf.nest.map_structure(replace_non_float_or_none, res), sources=watched_args_tf, output_gradients=ct_res_tf, - unconnected_gradients=tf.UnconnectedGradients.ZERO) + unconnected_gradients=tf.UnconnectedGradients.ZERO, + ) dres_darg = tree_util.tree_map( lambda x: x if x is None else tf.convert_to_tensor(x), dres_darg, ) - tf.nest.assert_same_structure(dres_darg, args_tf) + + # callable_tf may mutate (the structure of) args_tf, thus we check against + # watched_args_tf which should be structurally the same as the original + # args_tf. + tf.nest.assert_same_structure(dres_darg, watched_args_tf) return dres_darg # Use call_tf to call the VJP function ct_args_jax = call_tf(tf_vjp_fun)(args_jax, ct_res_jax) # We must make the float0s that JAX expects def fix_float0(arg_jax, ct_arg_jax): + if arg_jax is None: + return None arg_dtype = dtypes.result_type(arg_jax) # May be scalar ct_arg_dtype = core.primal_dtype_to_tangent_dtype(arg_dtype) if ct_arg_dtype != ct_arg_jax.dtype: @@ -264,7 +275,8 @@ def fix_float0(arg_jax, ct_arg_jax): ct_arg_dtype)) return ct_arg_jax - ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax) + ct_args_jax_fixed = tree_util.tree_map(fix_float0, args_jax, ct_args_jax, + is_leaf=lambda x: x is None) return ct_args_jax_fixed make_call.defvjp(make_call_vjp_fwd, make_call_vjp_bwd) @@ -448,7 +460,7 @@ def is_fully_known_shape(s): msg = ("call_tf cannot call functions whose output has dynamic shape. " f"Found output shapes: {concrete_function_flat_tf.output_shapes}. " "Consider using the `output_shape_dtype` argument to call_tf. " - "\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" + "\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf" " for a discussion.") raise ValueError(msg) @@ -487,7 +499,7 @@ def _call_tf_lowering( msg = ( "call_tf works best with a TensorFlow function that does not capture " "variables or tensors from the context. " - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. " + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion. " f"The following captures were found {concrete_function_flat_tf.captured_inputs}") logging.warning(msg) for inp in concrete_function_flat_tf.captured_inputs: @@ -532,7 +544,7 @@ def convert_to_spec(x): "\ncall_tf can used " + "in a staged context (under jax.jit, lax.scan, etc.) only with " + "compilable functions with static output shapes.\n" + - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." + + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion." 
+ "\n\nCaught TensorFlow exception: " + str(e)) raise ValueError(msg) from e @@ -545,7 +557,7 @@ def canonical_res_aval(res_shape: xla_client.Shape) -> core.ShapedArray: f"{res_shape}. call_tf can used " + "in a staged context (under jax.jit, lax.scan, etc.) only with " + "compilable functions with static output shapes. " + - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#limitations-of-call_tf for a discussion.") raise ValueError(msg) res_dtype = res_shape.numpy_dtype() diff --git a/jax/experimental/jax2tf/examples/README.md b/jax/experimental/jax2tf/examples/README.md index b049798e7e15..8869a226b675 100644 --- a/jax/experimental/jax2tf/examples/README.md +++ b/jax/experimental/jax2tf/examples/README.md @@ -4,7 +4,7 @@ jax2tf Examples Link: go/jax2tf-examples. This directory contains a number of examples of using the -[jax2tf converter](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) to: +[jax2tf converter](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) to: * save SavedModel from trained MNIST models, using both Flax and pure JAX. * reuse the feature-extractor part of the trained MNIST model @@ -19,12 +19,12 @@ You can also find usage examples in other projects: The functions generated by `jax2tf.convert` are standard TensorFlow functions and you can save them in a SavedModel using standard TensorFlow code, as shown -in the [jax2tf documentation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model). +in the [jax2tf documentation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#usage-saved-model). This decoupling of jax2tf from SavedModel is important, because it **allows the user to have full control over what metadata is saved in the SavedModel**. As an example, we provide the function `convert_and_save_model` -(see [saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).) +(see [saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py).) For serious uses, you will probably want to copy and expand this function as needed. @@ -65,7 +65,7 @@ If you are using Flax, then the recipe to obtain this pair is as follows: ``` You can see in -[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) +[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) how this can be done for two implementations of MNIST, one using pure JAX (`PureJaxMNIST`) and a CNN one using Flax (`FlaxMNIST`). Other Flax models can be arranged similarly, @@ -91,7 +91,7 @@ embed all parameters in the graph: ``` (The MNIST Flax examples from -[mnist_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) +[mnist_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/mnist_lib.py) normally has a GraphDef of 150k and a variables section of 3Mb. If we embed the parameters as constants in the GraphDef as shown above, the variables section becomes empty and the GraphDef becomes 13Mb. 
This embedding may allow @@ -112,7 +112,7 @@ If you are using Haiku, then the recipe is along these lines: Once you have the model in this form, you can use the `saved_model_lib.save_model` function from -[saved_model_lib.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py) +[saved_model_lib.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_lib.py) to generate the SavedModel. There is very little in that function that is specific to jax2tf. The goal of jax2tf is to convert JAX functions into functions that behave as if they had been written with TensorFlow. @@ -120,7 +120,7 @@ Therefore, if you are familiar with how to generate SavedModel, you can most likely just use your own code for this. The file -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) is an executable that shows how to perform the following sequence of steps: @@ -147,9 +147,9 @@ batch sizes: 1, 16, 128. You can see this in the dumped SavedModel. The SavedModel produced by the example in `saved_model_main.py` already implements the [reusable saved models interface](https://www.tensorflow.org/hub/reusable_saved_models). The executable -[keras_reuse_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py) +[keras_reuse_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/keras_reuse_main.py) extends -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) with code to include a jax2tf SavedModel into a larger TensorFlow Keras model. @@ -174,7 +174,7 @@ In particular, you can select the Flax MNIST model: `--model=mnist_flax`. It is also possible to use jax2tf-generated SavedModel with TensorFlow serving. At the moment, the open-source TensorFlow model server is missing XLA support, but the Google version can be used, as shown in the -[serving examples](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md). +[serving examples](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/README.md). # Using jax2tf with TensorFlow Lite and TensorFlow JavaScript @@ -186,6 +186,6 @@ can pass the `enable_xla=False` parameter to `jax2tf.convert` to direct `jax2tf` to avoid problematic ops. This will increase the coverage, and in fact most, but not all, Flax examples can be converted this way. -Check out the [MNIST TensorFlow Lite](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md) +Check out the [MNIST TensorFlow Lite](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tflite/mnist/README.md) and the -[Quickdraw TensorFlow.js example](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md). +[Quickdraw TensorFlow.js example](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/tf_js/quickdraw/README.md). 
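To make the (prediction function, parameters) recipe described in the examples README above concrete, a minimal sketch of saving a jax2tf-converted function with the parameters embedded as constants; the linear `predict_fn`, the `params` stand-ins, and the MNIST input shape are illustrative assumptions, not part of this patch:

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Stand-in for a trained model: a linear classifier over flattened MNIST images.
params = {"w": jnp.zeros((28 * 28, 10)), "b": jnp.zeros((10,))}

def predict_fn(params, images):
  logits = jnp.reshape(images, (images.shape[0], -1)) @ params["w"] + params["b"]
  return jax.nn.softmax(logits)

# Closing over `params` embeds them as constants in the TF graph (the larger
# GraphDef variant discussed above); saved_model_lib.convert_and_save_model
# keeps them in the variables section instead.
tf_predict = tf.function(
    lambda images: jax2tf.convert(predict_fn)(params, images),
    autograph=False,
    # Fixed batch size; the README also shows saving several batch sizes or
    # using polymorphic_shapes for a batch-polymorphic signature.
    input_signature=[tf.TensorSpec([16, 28, 28, 1], tf.float32)])

module = tf.Module()
module.predict = tf_predict
tf.saved_model.save(module, "/tmp/jax2tf_mnist_saved_model")
```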
diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main.py b/jax/experimental/jax2tf/examples/keras_reuse_main.py index 77f882af6850..1806e8c4545d 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main.py @@ -18,13 +18,16 @@ See README.md. """ import logging +import warnings from absl import app from absl import flags from jax.experimental.jax2tf.examples import mnist_lib from jax.experimental.jax2tf.examples import saved_model_main import tensorflow as tf import tensorflow_datasets as tfds # type: ignore -import tensorflow_hub as hub # type: ignore +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import tensorflow_hub as hub # type: ignore FLAGS = flags.FLAGS diff --git a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py index 2934842912f0..e34282a76ff4 100644 --- a/jax/experimental/jax2tf/examples/keras_reuse_main_test.py +++ b/jax/experimental/jax2tf/examples/keras_reuse_main_test.py @@ -41,6 +41,7 @@ def setUp(self): @parameterized.named_parameters( dict(testcase_name=f"_{model}", model=model) for model in ["mnist_pure_jax", "mnist_flax"]) + @jtu.ignore_warning(message="the imp module is deprecated") def test_keras_reuse(self, model="mnist_pure_jax"): FLAGS.model = model keras_reuse_main.main(None) diff --git a/jax/experimental/jax2tf/examples/mnist_lib.py b/jax/experimental/jax2tf/examples/mnist_lib.py index 41173c79a5b9..77432f9ebd92 100644 --- a/jax/experimental/jax2tf/examples/mnist_lib.py +++ b/jax/experimental/jax2tf/examples/mnist_lib.py @@ -27,6 +27,7 @@ import re import time from typing import Any +import warnings from absl import flags import flax @@ -70,7 +71,9 @@ def load_mnist(split: tfds.Split, batch_size: int): if _MOCK_DATA.value: with tfds.testing.mock_data(num_examples=batch_size): try: - ds = tfds.load("mnist", split=split) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + ds = tfds.load("mnist", split=split) except Exception as e: m = re.search(r'metadata files were not found in (.+/)mnist/', str(e)) if m: diff --git a/jax/experimental/jax2tf/examples/serving/README.md b/jax/experimental/jax2tf/examples/serving/README.md index 0d8f49e45d99..299923109226 100644 --- a/jax/experimental/jax2tf/examples/serving/README.md +++ b/jax/experimental/jax2tf/examples/serving/README.md @@ -2,7 +2,7 @@ Using jax2tf with TensorFlow serving ==================================== This is a supplement to the -[examples/README.md](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/README.md) +[examples/README.md](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/README.md) with example code and instructions for using `jax2tf` with the open source TensorFlow model server. Specific instructions for Google-internal versions of model server are in the `internal` subdirectory. @@ -15,16 +15,16 @@ SavedModel**. The only difference in the SavedModel produced with jax2tf is that the function graphs may contain -[XLA TF ops](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#caveats) +[XLA TF ops](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#caveats) that require enabling CPU/GPU XLA for execution in the model server. This is achieved using a command-line flag. There are no other differences compared to using SavedModel produced by TensorFlow. 
This serving example uses -[saved_model_main.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) +[saved_model_main.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/saved_model_main.py) for saving the SavedModel and adds code specific to interacting with the model server: -[model_server_request.py](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py). +[model_server_request.py](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/examples/serving/model_server_request.py). 0. *Set up JAX and TensorFlow serving*. @@ -36,7 +36,7 @@ We also need to install TensorFlow for the `jax2tf` feature and the rest of this We use the `tf_nightly` package to get an up-to-date version. ```shell - git clone https://github.com/google/jax + git clone https://github.com/jax-ml/jax JAX2TF_EXAMPLES=$(pwd)/jax/jax/experimental/jax2tf/examples pip install -e jax pip install flax jaxlib tensorflow_datasets tensorflow_serving_api tf_nightly diff --git a/jax/experimental/jax2tf/examples/tflite/mnist/README.md b/jax/experimental/jax2tf/examples/tflite/mnist/README.md index 9c889e647067..f39bd9c7ea9f 100644 --- a/jax/experimental/jax2tf/examples/tflite/mnist/README.md +++ b/jax/experimental/jax2tf/examples/tflite/mnist/README.md @@ -65,7 +65,7 @@ TensorFlow ops that are only available with the XLA compiler, and which are not understood (yet) by the TFLite converter to be used below. -Check out [more details about this limitation](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md), +Check out [more details about this limitation](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/no_xla_limitations.md), including to which JAX primitives it applies. ### Convert the trained model to the TF Lite format diff --git a/jax/experimental/jax2tf/g3doc/BUILD b/jax/experimental/jax2tf/g3doc/BUILD index 424d3b8b9e5d..6222b82b3550 100644 --- a/jax/experimental/jax2tf/g3doc/BUILD +++ b/jax/experimental/jax2tf/g3doc/BUILD @@ -15,7 +15,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) filegroup( diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md b/jax/experimental/jax2tf/g3doc/convert_models_results.md index 545f1faee266..24e2539a3626 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md @@ -48,13 +48,13 @@ details on the different converters. 
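Following the TF Lite notes quoted above, which rely on `enable_xla=False` to avoid XLA-only TF ops, a minimal conversion sketch; `predict_fn`, `params`, and the input shape are the same kind of illustrative stand-ins as in the SavedModel sketch earlier, not part of this patch:

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# Stand-in model (same shape conventions as the MNIST examples above).
params = {"w": jnp.zeros((28 * 28, 10)), "b": jnp.zeros((10,))}

def predict_fn(params, images):
  logits = jnp.reshape(images, (images.shape[0], -1)) @ params["w"] + params["b"]
  return jax.nn.softmax(logits)

# enable_xla=False keeps the lowering to TF ops that the TFLite converter
# understands.
tf_predict = tf.function(
    lambda images: jax2tf.convert(predict_fn, enable_xla=False)(params, images),
    autograph=False,
    input_signature=[tf.TensorSpec([1, 28, 28, 1], tf.float32)])

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,   # prefer TFLite builtin ops
    tf.lite.OpsSet.SELECT_TF_OPS,     # fall back to select TF ops if needed
]
tflite_model = converter.convert()
with open("/tmp/mnist.tflite", "wb") as f:
  f.write(tflite_model)
```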
## `flax/actor_critic_[(_, 4*b, 4*b, _)]` ### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/actor_critic_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -62,13 +62,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -78,13 +78,13 @@ for more details. 
``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -94,13 +94,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -122,13 +122,13 @@ RuntimeError('third_party/tensorflow/lite/kernels/concatenation.cc:159 t->dims-> ## `flax/bilstm_[(b, _), (_,)]` ### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/bilstm_[(b, _), (_,)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis 
error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -141,7 +141,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -156,7 +156,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -171,7 +171,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -180,13 +180,13 @@ for more details. 
## `flax/bilstm_[(_, _), (b,)]` ### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/bilstm_[(_, _), (b,)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -199,7 +199,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -214,7 +214,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -229,7 +229,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). 
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -238,13 +238,13 @@ for more details. ## `flax/cnn_[(_, b, b, _)]` ### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/cnn_[(_, b, b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see 
https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'.\nDetails: Cannot divide 'b + -2' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -253,13 +253,13 @@ InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_ Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -267,7 +267,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -278,13 +278,13 @@ for more details. Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). 
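A hedged sketch of the `flax/cnn` failure mode above (assuming Flax is installed, since the reported models are Flax models): a 2x2 pooling with stride 2 must divide `b - 2` by 2, which is inconclusive for a plain symbolic dimension `b`; declaring the spatial dimensions as an even multiple such as `2*b` makes the division exact, per the README section on division of shape polynomials.

```python
import flax.linen as nn
import tensorflow as tf
from jax.experimental import jax2tf

def pool(x):                   # x: f32[n, h, w, c]
  return nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

# Spatial dims declared as plain 'b': computing the output size needs
# (b - 2) / 2, which raises InconclusiveDimensionOperation when traced.
fails = jax2tf.convert(pool, polymorphic_shapes=["(_, b, b, _)"])

# Spatial dims declared as '2*b': the division is exact and tracing works.
works = jax2tf.convert(pool, polymorphic_shapes=["(_, 2*b, 2*b, _)"])
tf.function(works, autograph=False).get_concrete_function(
    tf.TensorSpec([1, None, None, 3], tf.float32))
```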
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -292,7 +292,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -303,13 +303,13 @@ for more details. Conversion error InconclusiveDimensionOperation("Cannot compute stride for dimension 'b', window_size '2', stride '2'. Details: Cannot divide 'b + -2' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. . @@ -317,7 +317,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -395,13 +395,13 @@ ValueError('Cannot set tensor: Dimension mismatch. 
Got 8 but expected 1 for dime ## `flax/resnet50_[(_, 4*b, 4*b, _)]` ### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/resnet50_[(_, 4*b, 4*b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported.\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -409,13 +409,13 @@ InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'.\nSee https://github ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. 
") ``` @@ -425,13 +425,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -441,13 +441,13 @@ for more details. ``` Conversion error InconclusiveDimensionOperation("Cannot divide '-1*b' by '2'. -See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. +See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#division-of-shape-polynomials-is-partially-supported. This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -613,13 +613,13 @@ IndexError('Cannot use NumPy slice indexing on an array dimension whose size is ## `flax/lm1b_[(b, _)]` ### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/lm1b_[(b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is 
inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -632,7 +632,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -647,7 +647,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -662,7 +662,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -684,13 +684,13 @@ ValueError('Cannot set tensor: Dimension mismatch. 
Got 2 but expected 1 for dime ## `flax/wmt_[(b, _), (b, _)]` ### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_xla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) ### Example: `flax/wmt_[(b, _), (b, _)]` | Converter: `jax2tf_noxla` ``` -InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") +InconclusiveDimensionOperation("Dimension polynomial comparison 'b' == '2' is inconclusive\n\nThis error arises for arithmetic or comparison operations with shapes that\nare non-constant, and the result of the operation cannot be represented as\na polynomial of dimension variables, or a boolean constant (for comparisons).\n\nPlease see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables\nfor more details.\n") ``` [Back to top](#summary-table) @@ -703,7 +703,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -718,7 +718,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). -Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -733,7 +733,7 @@ This error arises for arithmetic or comparison operations with shapes that are non-constant, and the result of the operation cannot be represented as a polynomial of dimension variables, or a boolean constant (for comparisons). 
-Please see https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables +Please see https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#computing-with-dimension-variables for more details. ") ``` @@ -798,14 +798,14 @@ This converter simply converts a the forward function of a JAX model to a Tensorflow function with XLA support linked in. This is considered the baseline converter and has the largest coverage, because we expect nearly all ops to be convertible. However, please see -[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues) +[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues) for a list of known problems. ### `jax2tf_noxla` This converter converts a JAX model to a Tensorflow function without XLA support. This means the Tensorflow XLA ops aren't used. See -[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) +[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) for more details. ### `jax2tfjs` diff --git a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template index b54c5750334a..54e1d21356a7 100644 --- a/jax/experimental/jax2tf/g3doc/convert_models_results.md.template +++ b/jax/experimental/jax2tf/g3doc/convert_models_results.md.template @@ -29,14 +29,14 @@ This converter simply converts a the forward function of a JAX model to a Tensorflow function with XLA support linked in. This is considered the baseline converter and has the largest coverage, because we expect nearly all ops to be convertible. However, please see -[jax2tf Known Issue](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues) +[jax2tf Known Issue](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues) for a list of known problems. ### `jax2tf_noxla` This converter converts a JAX model to a Tensorflow function without XLA support. This means the Tensorflow XLA ops aren't used. See -[here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) +[here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#tensorflow-xla-ops) for more details. ### `jax2tfjs` diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md index 457dc998abca..24a1d62ee67e 100644 --- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md +++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md @@ -1,6 +1,6 @@ # jax2tf Limitations for `enable_xla=False` -*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/google/jax/tree/main/jax/experimental/jax2tf#known-issues)* +*Note: the list below is only for running jax2tf with `enable_xla=False`. For general jax2tf known issues please see [here](https://github.com/jax-ml/jax/tree/main/jax/experimental/jax2tf#known-issues)* For most JAX primitives there is a natural TF op that fits the needed semantics (e.g., `jax.lax.abs` is equivalent to `tf.abs`). 
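For reference, a minimal sketch of the two converters compared in this report, assuming the non-native-serialization jax2tf path that accepts `enable_xla`:

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = lambda x: jnp.sin(jnp.cos(x))

f_xla = jax2tf.convert(f_jax)                      # "jax2tf_xla" baseline
f_noxla = jax2tf.convert(f_jax, enable_xla=False)  # "jax2tf_noxla"

x = tf.constant([0.1, 0.2], tf.float32)
print(f_xla(x), f_noxla(x))
```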
However, there are a number of diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md index dabbcca4d430..b36b004a9d31 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md @@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops (`enable_xla=False`). In this mode, some primitives have additional limitations. This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) +list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) We do not yet have support for `pmap` (with its collective primitives), nor for `sharded_jit` (SPMD partitioning). @@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes: * `all` = `integer`, `inexact`, `bool` More detailed information can be found in the -[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). +[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). | Affected primitive | Description of limitation | Affected dtypes | Affected devices | Affected compilation modes | diff --git a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template index bf5dc41d8b8b..219802f5363a 100644 --- a/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template +++ b/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md.template @@ -40,7 +40,7 @@ The converter has a mode in which it attempts to avoid special XLA TF ops (`enable_xla=False`). In this mode, some primitives have additional limitations. This table only shows errors for cases that are working in JAX (see [separate -list of unsupported or partially-supported primitives](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) +list of unsupported or partially-supported primitives](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/jax_primitives_coverage.md) ) We do not yet have support for `pmap` (with its collective primitives), nor for `sharded_jit` (SPMD partitioning). @@ -56,7 +56,7 @@ We use the following abbreviations for sets of dtypes: * `all` = `integer`, `inexact`, `bool` More detailed information can be found in the -[source code for the limitation specification](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). +[source code for the limitation specification](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/tests/primitives_test.py). 
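The dtype-set abbreviations above roughly correspond to the following NumPy groupings; this is an illustrative stand-in, and the exact sets used by the test suite (for example, whether bfloat16 is included) may differ:

```python
import numpy as np

unsigned = [np.uint8, np.uint16, np.uint32, np.uint64]
signed = [np.int8, np.int16, np.int32, np.int64]
integer = unsigned + signed
inexact = [np.float16, np.float32, np.float64, np.complex64, np.complex128]
all_dtypes = integer + inexact + [np.bool_]   # the "all" abbreviation
```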
{{tf_error_table}} diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py index 5ecde602cdaa..3c51e5d63f25 100644 --- a/jax/experimental/jax2tf/impl_no_xla.py +++ b/jax/experimental/jax2tf/impl_no_xla.py @@ -364,12 +364,15 @@ def _conv_general_dilated( def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + algorithm: Any, transpose_algorithm: Any, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" # Unused arguments. del precision del preferred_element_type + del algorithm + del transpose_algorithm lhs, rhs, convert_result = jax2tf._dot_general_convert_to_common_dtype( lhs, _in_avals[0], rhs, _in_avals[1], _out_aval) @@ -591,7 +594,7 @@ def _padding_reduce_window(operand, operand_shape, computation_name, padding_type = pads_to_padtype(operand_shape, window_dimensions, window_strides, padding) - # https://github.com/google/jax/issues/11874. + # https://github.com/jax-ml/jax/issues/11874. needs_manual_padding = ( padding_type == "SAME" and computation_name == "add" and window_dimensions != [1] * len(operand_shape)) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 545945c91ffd..f01a3ab7a036 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -198,7 +198,7 @@ def __init__(self): # A cache for the tf.convert_to_tensor for constants. We try to preserve # sharing for constants, to enable tf.Graph to take advantage of it. - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. self.constant_cache = None # None means that we don't use a cache. We # may be outside a conversion scope. @@ -249,7 +249,7 @@ def convert(fun_jax: Callable, """Allows calling a JAX function from a TensorFlow program. See - [README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md) + [README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md) for more details about usage and common problems. Args: @@ -291,12 +291,12 @@ def convert(fun_jax: Callable, polymorphic_shapes are only supported for positional arguments; shape polymorphism is not supported for keyword arguments. - See [the README](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) + See [the README](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#shape-polymorphic-conversion) for more details. polymorphic_constraints: a sequence of contraints on symbolic dimension expressions, of the form `e1 >= e2` or `e1 <= e2`. - See more details at https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. + See more details at https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#user-specified-symbolic-constraints. with_gradient: if set (default), add a tf.custom_gradient to the lowered function, by converting the ``jax.vjp(fun)``. This means that reverse-mode TensorFlow AD is supported for the output TensorFlow function, and the @@ -1253,7 +1253,7 @@ def __init__(self, trace: TensorFlowTrace, val: TfVal, # We have a TF value with known shape, and the abstract shape is a shape variable. 
try: aval_int = int(_eval_shape([aval_dim])) # type: ignore - except (TypeError, KeyError): + except (TypeError, KeyError, shape_poly.UnexpectedDimVar): continue assert aval_int == val_dim, f"expected {phys_aval.shape} == {val_shape}. Found {aval_int} != {val_dim}." @@ -1548,6 +1548,9 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs): "consume", "ragged_dot", "cholesky_update", + "symmetric_update", + "from_edtype", + "to_edtype", # Pallas TPU primitives "bitcast", "repeat", @@ -2173,9 +2176,12 @@ def gen_conv(lhs, rhs, preferred_element_type: DType | None): def _dot_general(lhs, rhs, *, dimension_numbers, precision: tuple[PrecisionType, PrecisionType] | None, preferred_element_type: DType | None, + algorithm: Any, transpose_algorithm: Any, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): """Implementation of lax.dot_general_p in terms of tf.linalg.einsum.""" + del algorithm, transpose_algorithm # unused + # TODO(b/293247337): we ought to turn on this safety check, but this leads to # failures. Since we are going to turn on native serializaton soon, wait # until then to turn on this check. @@ -3535,7 +3541,7 @@ def _shard_value(val: TfVal, if tf_context.executing_eagerly(): raise ValueError( "A jit function with sharded arguments or results must be used under a `tf.function` context. " - "See https://github.com/google/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") + "See https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/README.md#support-for-partitioning for a discussion") return xla_sharding.Sharding(proto=xla_sharding_proto).apply_to_tensor( val, use_sharding_op=True) diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD index f584ab5d3191..3417c1abf6ac 100644 --- a/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/BUILD @@ -18,7 +18,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) py_library( diff --git a/jax/experimental/jax2tf/tests/call_tf_test.py b/jax/experimental/jax2tf/tests/call_tf_test.py index 5740b76038d8..492dfad4c855 100644 --- a/jax/experimental/jax2tf/tests/call_tf_test.py +++ b/jax/experimental/jax2tf/tests/call_tf_test.py @@ -88,6 +88,17 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() @_parameterized_jit def test_eval_scalar_arg(self, with_jit=True): @@ -304,7 +315,7 @@ def fun_tf(x): self.assertAllClose(x * outer_var_array + 1., res, check_dtypes=False) def test_with_var_different_shape(self): - # See https://github.com/google/jax/issues/6050 + # See https://github.com/jax-ml/jax/issues/6050 v = tf.Variable((4., 2.), dtype=tf.float32) def tf_func(x): @@ -428,7 +439,7 @@ def loss(functional, x_dict): self.assertAllClose(g_jax, g_tf) def test_grad_int_argument(self): - # Similar to https://github.com/google/jax/issues/6975 + # Similar to https://github.com/jax-ml/jax/issues/6975 # state is a pytree that contains an integer and a boolean. 
# The function returns an integer and a boolean. def f(param, state, x): @@ -862,6 +873,7 @@ def _transfer_guard(guard_level): class RoundTripToJaxTest(tf_test_util.JaxToTfTestCase): "Reloading output of jax2tf into JAX with call_tf" + def setUp(self): if tf is None: raise unittest.SkipTest("Test requires tensorflow") @@ -869,6 +881,17 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() def test_simple(self): f_jax = jnp.sin @@ -1136,6 +1159,16 @@ def tf_f(x): # Jit mode self.assertAllClose(jax.jit(grad_fun_jax)(x), jax.jit(grad_fun_jax_rt)(x)) + def test_grad_pytree_arg_with_none_leaf(self): + def tf_f(x, params): + return x * params["y"] + + x = jnp.array(1.0) + y = jnp.array(2.0) + actual = jax.grad( + jax2tf.call_tf(tf_f), argnums=(1,))(x, {"y": y, "other": None}) + self.assertDictEqual(actual[0], {"y": x, "other": None}) + class RoundTripToTfTest(tf_test_util.JaxToTfTestCase): "Reloading output of call_tf into TF with jax2tf." @@ -1147,6 +1180,17 @@ def setUp(self): # bug in TensorFlow. _ = tf.add(1, 1) super().setUp() + self.warning_ctx = jtu.ignore_warning( + message=( + "(jax2tf.convert with native_serialization=False is deprecated" + "|Calling from_dlpack with a DLPack tensor is deprecated)" + ) + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() def test_alternate(self): # Alternate sin/cos with sin in TF and cos in JAX diff --git a/jax/experimental/jax2tf/tests/flax_models/BUILD b/jax/experimental/jax2tf/tests/flax_models/BUILD index 19afb4a6877c..d3af9581ae02 100644 --- a/jax/experimental/jax2tf/tests/flax_models/BUILD +++ b/jax/experimental/jax2tf/tests/flax_models/BUILD @@ -19,7 +19,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax/experimental/jax2tf:__subpackages__"], + default_visibility = ["//jax/experimental/jax2tf:__subpackages__"], ) py_library( @@ -27,8 +27,8 @@ py_library( srcs = glob(["*.py"]), srcs_version = "PY3", deps = [ + "//jax", "//third_party/py/flax:core", - "//third_party/py/jax", "//third_party/py/jraph", "//third_party/py/numpy", "//third_party/py/typing_extensions", diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py index 896d0436e3c2..c3b9e96dc320 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py +++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py @@ -1066,9 +1066,9 @@ def nextafter(cls, harness: test_harnesses.Harness): @classmethod def qr(cls, harness: test_harnesses.Harness): - # See https://github.com/google/jax/pull/3775#issuecomment-659407824; + # See https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824; # # jit_compile=True breaks for complex types. - # TODO: see https://github.com/google/jax/pull/3775#issuecomment-659407824. + # TODO: see https://github.com/jax-ml/jax/pull/3775#issuecomment-659407824. # - for now, the performance of the HLO QR implementation called when # compiling with TF is expected to have worse performance than the # custom calls made in JAX. 
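The new `test_grad_pytree_arg_with_none_leaf` boils down to the following pattern: `jax2tf.call_tf` wraps a TensorFlow function so JAX can call and differentiate it, and `None` leaves in a pytree argument pass through untouched. A condensed sketch of what the test checks:

```python
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf

def tf_f(x, params):
  return x * params["y"]

x = jnp.array(1.0)
params = {"y": jnp.array(2.0), "other": None}

# d(x * y)/dy = x, and the None leaf is preserved in the gradient pytree.
grads = jax.grad(jax2tf.call_tf(tf_f), argnums=(1,))(x, params)
print(grads[0])  # expected: {"y": 1.0, "other": None}
```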
diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index ffa3a103e7e4..6411dc581424 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -76,6 +76,17 @@ def setUpClass(cls): super().setUpClass() + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_empty(self): f_jax = lambda x, y: x self.ConvertAndCompare(f_jax, 0.7, 1) @@ -595,7 +606,7 @@ def fn(x0, x1, x2, x3): @jtu.sample_product(with_function=[False, True]) def test_gradients_int_argument(self, with_function=False): - # https://github.com/google/jax/issues/6975 + # https://github.com/jax-ml/jax/issues/6975 # Also issue #6975. # An expanded version of test_gradients_unused_argument state = dict( @@ -965,11 +976,11 @@ def caller_jax(x): self.assertIn("my_test_function_jax/mul", self.TfToHlo(run_tf)) else: graph_def = str(tf.function(run_tf, autograph=False).get_concrete_function().graph.as_graph_def()) - if "my_test_function_jax/pjit_multiply_/Mul" not in graph_def: - self.assertIn("my_test_function_jax/jit_multiply_/Mul", graph_def) + if "my_test_function_jax/pjit__multiply_/Mul" not in graph_def: + self.assertIn("my_test_function_jax/jit__multiply_/Mul", graph_def) def test_bfloat16_constant(self): - # Re: https://github.com/google/jax/issues/3942 + # Re: https://github.com/jax-ml/jax/issues/3942 def jax_fn_scalar(x): x = x.astype(jnp.bfloat16) x *= 2. @@ -990,7 +1001,7 @@ def jax_fn_array(x): def test_shared_constants(self): # Check that the constants are shared properly in converted functions - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const = np.random.uniform(size=256).astype(np.float32) # A shared constant @@ -1002,7 +1013,7 @@ def f(x): def test_shared_constants_under_cond(self): # Check that the constants are shared properly in converted functions - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const_size = 512 @@ -1018,7 +1029,7 @@ def f2(x): self.assertLen(f2_consts, len(f1_consts)) def test_shared_constants_under_scan(self): - # See https://github.com/google/jax/issues/7992. + # See https://github.com/jax-ml/jax/issues/7992. 
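Several test classes in this change enter a warning filter in `setUp` and exit it in `tearDown`. A stdlib-only sketch of that pattern, using `warnings.catch_warnings` in place of JAX's internal `jtu.ignore_warning` helper:

```python
import unittest
import warnings

class QuietDeprecationsTest(unittest.TestCase):

  def setUp(self):
    super().setUp()
    # Enter a warnings context that lasts for the duration of each test.
    self._warning_ctx = warnings.catch_warnings()
    self._warning_ctx.__enter__()
    warnings.filterwarnings(
        "ignore",
        message="jax2tf.convert with native_serialization=False is deprecated")

  def tearDown(self):
    # Restore the previous warning filters.
    self._warning_ctx.__exit__(None, None, None)
    super().tearDown()

  def test_quiet(self):
    warnings.warn(
        "jax2tf.convert with native_serialization=False is deprecated",
        DeprecationWarning)  # filtered out by the setUp filter
```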
if config.jax2tf_default_native_serialization.value: raise unittest.SkipTest("shared constants tests not interesting for native serialization") const_size = 512 @@ -1092,7 +1103,7 @@ def test_weak_types(self): @jtu.sample_product(with_function=[False, True]) def test_kwargs(self, with_function=False): - # Re: https://github.com/google/jax/issues/6791 + # Re: https://github.com/jax-ml/jax/issues/6791 def f_jax(*, x): return jnp.sum(x) f_tf = jax2tf.convert(f_jax) @@ -1104,7 +1115,7 @@ def f_jax(*, x): @jtu.sample_product(with_function=[False, True]) def test_grad_kwargs(self, with_function=False): - # Re: https://github.com/google/jax/issues/6791 + # Re: https://github.com/jax-ml/jax/issues/6791 x = (np.zeros(3, dtype=np.float32), np.zeros(4, dtype=np.float32)) def f_jax(*, x=(1., 2.)): @@ -1621,6 +1632,8 @@ def f_jax(*many_args): res = jax2tf.convert(f_jax, native_serialization=True)(*many_args) self.assertAllClose(f_jax(*many_args), res) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def test_nested_convert(self): # Test call sequence: convert -> call_tf -> convert. @@ -1677,6 +1690,17 @@ def f_jax(x): @jtu.with_config(jax_enable_custom_prng=True) class Jax2tfWithCustomPRNGTest(tf_test_util.JaxToTfTestCase): + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_key_argument(self): func = lambda key: jax.random.uniform(key, ()) key = jax.random.PRNGKey(0) @@ -1709,6 +1733,9 @@ def setUp(self): self.use_max_serialization_version = False super().setUp() + @jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) def test_simple(self): self.ConvertAndCompare(jnp.sin, 0.7) diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py index 40af7959cd4b..78c24b7ea411 100644 --- a/jax/experimental/jax2tf/tests/primitives_test.py +++ b/jax/experimental/jax2tf/tests/primitives_test.py @@ -30,7 +30,7 @@ are captured as jax2tf_limitations.Jax2TfLimitation objects. From the limitations objects, we generate a -[report](https://github.com/google/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). +[report](https://github.com/jax-ml/jax/blob/main/jax/experimental/jax2tf/g3doc/primitives_with_limited_support.md). The report has instructions for how to re-generate it. If a harness run fails with error, and a limitation that matches the device @@ -119,6 +119,11 @@ def test_prim(self, harness: test_harnesses.Harness): device == "tpu"): raise unittest.SkipTest("b/264716764: error on tf.cast from c64 to f32") + if ("eigh" == harness.group_name and + device == "cpu"): + raise unittest.SkipTest( + "Equality comparisons on eigendecompositions are not stable.") + if (config.jax2tf_default_native_serialization.value and device == "gpu" and "lu" in harness.fullname): @@ -178,11 +183,13 @@ def test_primitive_coverage(self): if p.name == "debug_callback" or p.name == "debug_print": # TODO(sharadmv,necula): enable debug callbacks in TF continue - if p.name in ("max_contiguous", "multiple_of"): + if p.name in ("max_contiguous", "multiple_of", "run_scoped"): # Pallas-specific primitives are not supported. 
continue if p.name == "pallas_call": continue + if p.name == "ffi_call": + continue if p.name == "tpu_custom_call": continue if p.name == "custom_partitioning": diff --git a/jax/experimental/jax2tf/tests/savedmodel_test.py b/jax/experimental/jax2tf/tests/savedmodel_test.py index 8b71de7db30c..aee15883332a 100644 --- a/jax/experimental/jax2tf/tests/savedmodel_test.py +++ b/jax/experimental/jax2tf/tests/savedmodel_test.py @@ -30,6 +30,17 @@ class SavedModelTest(tf_test_util.JaxToTfTestCase): + def setUp(self): + super().setUp() + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def test_eval(self): f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x))) model = tf.Module() @@ -175,7 +186,7 @@ def model_jax(params, inputs): def test_save_grad_integers(self): - # https://github.com/google/jax/issues/7123 + # https://github.com/jax-ml/jax/issues/7123 # In the end this is a test that does not involve JAX at all batch_size = 5 state = np.array([1], dtype=np.int32) # Works if float32 diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index a34c431edab9..07bd9b5aed22 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -334,7 +334,6 @@ def f_jax(x): # x: i32[b] check_shape_poly(self, f_jax, arg_descriptors=[x], polymorphic_shapes=["b"]) - @jtu.parameterized_filterable( kwargs=[ dict(testcase_name=f"expr={name}", expr=expr) @@ -933,7 +932,7 @@ def f(x): kwargs=[dict(with_function=v) for v in [True, False]] ) def test_grad_int(self, with_function=False): - # https://github.com/google/jax/issues/7093 + # https://github.com/jax-ml/jax/issues/7093 # Also issue #6975. x_shape = (2, 3, 4) xi = np.arange(math.prod(x_shape), dtype=np.int16).reshape(x_shape) @@ -941,7 +940,7 @@ def test_grad_int(self, with_function=False): xi_yf = (xi, yf) zb = np.array([True, False], dtype=np.bool_) def f_jax(xi_yf, zb): # xi: s16[2, 3, 4], yf: f32[2, 3, 4], zb: bool[2] - # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] + # results: f32[2, 3, 4], s16[2, 3, 4], bool[2], f32[2, 3, 4] xi, yf = xi_yf # Return a tuple: # (1) float constant, with 0 tangent; @@ -1032,6 +1031,9 @@ def f_jax(x): # A function whose gradient is a constant f_tf, input_signature=[tf.TensorSpec([None], x.dtype)]) self.assertAllClose(f_jax(x), restored_f(x)) + @jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) def test_readme_examples(self): """Some of the examples from the README.""" @@ -2172,7 +2174,7 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10] (2, x.shape[0]), (1, 1), "VALID"), arg_descriptors=[RandArg((3, 8), _f32)], polymorphic_shapes=["b, ..."]), - # https://github.com/google/jax/issues/11804 + # https://github.com/jax-ml/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) 
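The SavedModel tests touched here follow the usual export recipe: convert the JAX function, wrap it in a `tf.Module` with a concrete `input_signature`, save, and reload. A hedged sketch (the path is illustrative):

```python
import jax
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

f_jax = jax.jit(lambda x: jnp.sin(jnp.cos(x)))

module = tf.Module()
module.f = tf.function(
    jax2tf.convert(f_jax),
    autograph=False,
    input_signature=[tf.TensorSpec([3], tf.float32)])

tf.saved_model.save(module, "/tmp/jax2tf_savedmodel")
restored = tf.saved_model.load("/tmp/jax2tf_savedmodel")
print(restored.f(tf.constant([0.1, 0.2, 0.3], tf.float32)))
```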
PolyHarness("reduce_window", "add_monoid_strides_window_size=static", @@ -2665,6 +2667,11 @@ def test_harness(self, harness: PolyHarness): if 0 < shape[-1] <= 32: harness.check_result = False + if harness.group_name == "vmap_eigh": + raise unittest.SkipTest( + "Should not compare eigendecompositions for equality directly" + "because eigenvalues are sorted.") + if harness.group_name == "vmap_tan": # Tan (b/274462307) require support for custom call stablehlo.tan. raise unittest.SkipTest( diff --git a/jax/experimental/jax2tf/tests/sharding_test.py b/jax/experimental/jax2tf/tests/sharding_test.py index d6349b4870d2..24713539512c 100644 --- a/jax/experimental/jax2tf/tests/sharding_test.py +++ b/jax/experimental/jax2tf/tests/sharding_test.py @@ -61,7 +61,8 @@ def setUpModule(): global topology if jtu.test_device_matches(["tpu"]): - resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') + with jtu.ignore_warning(message="the imp module is deprecated"): + resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='') tf.config.experimental_connect_to_cluster(resolver) # Do TPU init at beginning since it will wipe out all HBMs. topology = tf.tpu.experimental.initialize_tpu_system(resolver) @@ -84,6 +85,15 @@ def setUp(self): raise unittest.SkipTest("Test requires at least 2 local devices") self.devices = np.array(jax.devices()[:2]) # use 2 devices + self.warning_ctx = jtu.ignore_warning( + message="jax2tf.convert with native_serialization=False is deprecated" + ) + self.warning_ctx.__enter__() + + def tearDown(self): + self.warning_ctx.__exit__(None, None, None) + super().tearDown() + def log_jax_hlo(self, f_jax, args: Sequence[Any], *, num_replicas=1, num_partitions=2): """Log the HLO generated from JAX before and after optimizations""" @@ -437,7 +447,7 @@ def f_grad_tf(x_v, res_ct): def test_grad_sharding_different_mesh(self): # Convert with two similar meshes, the only difference being # the order of the devices. grad should not fail. - # https://github.com/google/jax/issues/21314 + # https://github.com/jax-ml/jax/issues/21314 devices = jax.local_devices()[:2] if len(devices) < 2: raise unittest.SkipTest("Test requires 2 local devices") diff --git a/jax/experimental/jet.py b/jax/experimental/jet.py index 1ed6183b1229..ffe362974dcb 100644 --- a/jax/experimental/jet.py +++ b/jax/experimental/jet.py @@ -45,11 +45,11 @@ and can thus be used for high-order automatic differentiation of :math:`f`. Details are explained in - `these notes `__. + `these notes `__. Note: Help improve :func:`jet` by contributing - `outstanding primitive rules `__. + `outstanding primitive rules `__. """ from collections.abc import Callable diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 473700024ad7..075e4e6eed48 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -14,794 +14,8 @@ # ============================================================================== """Utils for building a device mesh.""" -from __future__ import annotations - -import collections -from collections.abc import Callable, Generator, MutableMapping, Sequence -import itertools -import logging -import math -from typing import Any - -from jax._src import xla_bridge as xb -import numpy as np - -logger = logging.getLogger(__name__) - -_TPU_V2 = 'TPU v2' -_TPU_V3 = 'TPU v3' -_TPU_V4 = 'TPU v4' -_TPU_V5_LITE = "TPU v5 lite" - -# Maps physical topology -> mesh shape -> transpose to use for jekbradbury's -# famous contiguous mesh trick. 
-# -# The trick only works for certain topologies and mesh shapes. Trivial dims of -# size 1 can be added to the shapes listed, and they are also supported. -_TRANSPOSE_TRICKS: dict[ - tuple[int, ...], dict[tuple[int, ...], tuple[int, ...]] -] = { - (2, 2, 1): { - (2, 2): (0, 1, 2), - }, - (2, 2, 4): { - (4, 4): (0, 1, 2), - }, - (4, 4, 4): { - (16, 4): (0, 2, 1), - }, - (4, 8, 8): { - (64, 4): (0, 2, 1), - (4, 64): (0, 2, 1), - }, - (8, 8, 8): { - (64, 8): (0, 2, 1), - }, - (8, 16, 16): { - (256, 8): (0, 2, 1), - (8, 256): (0, 2, 1), - }, -} - -# Physical ordering of core IDs in a tray that creates a ring -_TRAY_RING_ORDER = (0, 1, 2, 3, 6, 7, 4, 5) -_TRAY_2x2_RING_ORDER = (0, 1, 3, 2) -_TRAY_4x4_RING_ORDER = (0, 1, 2, 3, 7, 6, 5, 9, 10, 11, 15, 14, 13, 12, 8, 4) - -def _tpu_v2_v3_create_device_mesh( - mesh_shape: Sequence[int], - devices: Sequence[Any], - **unused_kwargs, -) -> np.ndarray: - if len(devices) == 8: - logger.info( - 'Reordering mesh to physical ring order on single-tray TPU v2/v3.' - ) - device_mesh = np.asarray(devices) - device_mesh = device_mesh[np.array(_TRAY_RING_ORDER)] - device_mesh = device_mesh.reshape(mesh_shape) - return device_mesh - elif mesh_shape[-1] == 8: - device_mesh = np.asarray(devices).reshape(mesh_shape) - logger.info( - 'Reordering mesh to physical ring order on each TPU v2/v3 tray.' - ) - perm = np.array(_TRAY_RING_ORDER) - device_mesh = device_mesh[..., perm] - return device_mesh - else: - # TODO(skye): implement 2D mesh_shape logic here: - # https://github.com/tensorflow/lingvo/blob/0df40cf604dfcd14e28f7087d73687a0bd2fe5c6/lingvo/core/gshard_utils.py#L187 - # (possibly replaces above mesh_shape[-1] == 8 case) - return np.asarray(devices).reshape(mesh_shape) - - -def _vlc_create_device_mesh( - mesh_shape: Sequence[int], devices: Sequence[Any], **unused_kwargs -) -> np.ndarray | None: - """Creates rotated pincer device assignment for selected topologies. - - Args: - mesh_shape: Logical mesh shape used by the model. - devices: TPU devices. - **unused_kwargs: ... - - Returns: - None or reordered devices reshaped as `mesh_shape`. - """ - max_x, max_y, max_z = max(getattr(d, "coords", (0, 0, 0)) for d in devices) - bound_x, bound_y, bound_z = max_x + 1, max_y + 1, max_z + 1 - # Our ring re-ordering makes sense only if the passed-in devices are - # sequential, which may not always be the case. reversed() changes z-minor to - # x-minor. - sequential_devices = sorted( - devices, - key=lambda d: tuple(reversed(getattr(d, "coords", (0, 0, 0))))) - - if bound_x == bound_y == 2 and bound_z == 1 and len(devices) == 4: # VLC2x2 - device_mesh = np.asarray(sequential_devices) - device_mesh = device_mesh[np.array(_TRAY_2x2_RING_ORDER)] - device_mesh = device_mesh.reshape(mesh_shape) - return device_mesh - - if bound_x == bound_y == 4 and bound_z == 1 and len(devices) == 16: # VLP4x4 - # Only uses ring order if the whole mesh is a replica group. - if max(mesh_shape) == len(devices): - device_mesh = np.asarray(sequential_devices) - device_mesh = device_mesh[np.array(_TRAY_4x4_RING_ORDER)] - device_mesh = device_mesh.reshape(mesh_shape) - return device_mesh - - return None - - -# Registers functions to create device mesh for specific device kinds. Takes -# precedence over the more general logic in create_device_mesh(). Handler may -# return None; in that case, it will fall back to using the default logic. 
-device_kind_handler_dict: dict[ - str, - Callable[..., np.ndarray | None], -] = { - _TPU_V2: _tpu_v2_v3_create_device_mesh, - _TPU_V3: _tpu_v2_v3_create_device_mesh, - _TPU_V5_LITE: _vlc_create_device_mesh, -} - - -def _create_device_mesh_for_nd_torus( - physical_mesh: np.ndarray, - mesh_shape: Sequence[int], - *, - allow_split_physical_axes: bool = False, -) -> tuple[np.ndarray, np.ndarray]: - """Assigns logical parallelism axes to physical axes of an N-D torus network. - - Given logical parallelism axes with sizes in `mesh_shape` and devices in an - N-dimensional torus network represented by `physical_mesh`, maps each logical - axis to one or more physical axes. Prefer to map more-performance-sensitive - logical axes to larger numbers of physical axes to maximize the bandwidth - available to them. Also prefer to assign logical axes to multiple physical - axes of the same size (e.g., a 2D square) rather than multiple physical axes - of different sizes when possible. - - If allow_split_physical_axes = False (default), this routine will error out - instead of splitting a physical axis over more than one logical axis (which - would reduce total usable bandwidth). - - Let's use a concrete example to explain the concepts and considerations. - - As an example, suppose the logical mesh is [data, model], for data and model - parallelism respectively. Also suppose that data parallelism is less - performance sensitive than model parallelism. Consider a 3D TPU pod slice of - shape 4x4x16, represented by a physical mesh of shape (4, 4, 16). - - A TPU pod slice has equal bandwidth along all axes with wraparound links, but - a 2D plane of size 4x4 may have faster XLA collective implementations than a - non-square plane or a 1D subgroup. If the mesh_shape is [16, 16], we may want - the more performance sensitive `model` axis to be mapped to the 4x4 XY plane. - - Args: - physical_mesh: a np.ndarray of devices in the shape of the N-D torus - physical topology. - mesh_shape: shape of the logical mesh (size of the various logical - parallelism axes), with axes ordered by increasing network intensity. - allow_split_physical_axes: If True, we would split physical axes if - necessary to fit the desired mesh shape. - - Returns: - An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with - each logical parallelism axis mapped to one or more physical mesh axes. - The axis assignment matrix, which is a 2-d array mapping from - (physical_axis, logical_axis) to the size assigned, with the invariant - np.prod(assignment, axis=1) = physical_mesh_shape, and - np.prod(assignment, axis=0) = mesh_shape. - """ - # Remaining physical axes to be assigned to logical axes. - assignable_physical_mesh = list(physical_mesh.shape) - # Map each logical axis to a subset of physical axes. - assignment: list[tuple[int, ...]] = [() for _ in mesh_shape] - - # Assign logical axes from highest network intensity to lowest. - # `mesh_shape` is assumed to ordered by lowest network intensity first, so - # reverse it first. - for logical_axis_index, logical_axis_size in reversed( - list(enumerate(mesh_shape)) - ): - # Preferentially map to more physical axes first for higher bandwidth. - for num_axes in range(3, 0, -1): - # Try assign to any subset of size num_axes. Generate all candidates. 
- indices_and_axes = itertools.combinations( - enumerate(assignable_physical_mesh), num_axes - ) - for elem in indices_and_axes: - c_indices, c_axes = zip(*elem) - # TODO(zhangqiaorjc): Due to limitations in XLA, 2D collectives only - # implemented for square 2D plane. Mapping a physical axis to two - # logical axes might be slower for non-square 2D plane, e.g., map 32 to - # 4x8 or a single axis. If XLA 2D collectives support non-square plane - # soon, we can continue to preferentially map to 2D plane in general, - # otherwise, we should treat non-square 2D plane and 1D submesh equally. - if np.prod(c_axes) == logical_axis_size: - assignment[logical_axis_index] = c_indices - # Zero the assigned physical axes. - assignable_physical_mesh = [ - 0 if i in c_indices else v - for i, v in enumerate(assignable_physical_mesh) - ] - break - if assignment[logical_axis_index]: - # We already found an assignment from one candidate above. - break - else: - # If the num_axes for loop did not break, i.e. none of the candidates work - # goto here with this while-else construct. - if logical_axis_size > 1: - if not allow_split_physical_axes: - # Although this is now implemented, there are downstream tasks - # counting on this being a NotImplementedError. - raise NotImplementedError( - 'Failed to find assignment for logical_axis_index' - f' {logical_axis_index} of size {logical_axis_size} with' - f' remaining assignable mesh {assignable_physical_mesh}. The size' - ' of each axis in your logical mesh must be equal to the product' - ' of some subset of the physical mesh axis sizes. E.g. logical' - ' mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4' - ' and 16=4x4. If you want to split physical axes, set ' - ' allow_split_physical_axes to True.' - ) - else: - # We will try finding an assignment, even if that means splitting the - # physical axes, which requires a more sophisticated implementation. - return _create_device_mesh_for_nd_torus_splitting_axes( - physical_mesh, mesh_shape - ) - - # Flatten the assignment, e.g., [(), (2,), (0, 1)] -> (2, 0, 1). - transpose: list[int] = [] - assignment_array = np.ones( - [len(physical_mesh.shape), len(mesh_shape)], dtype=np.int64 - ) - for i, x in enumerate(assignment): - for y in x: - physical_mesh_axis = int(y) - assignment_array[physical_mesh_axis, i] = physical_mesh.shape[ - physical_mesh_axis - ] - transpose.append(physical_mesh_axis) - return ( - physical_mesh.transpose(transpose).reshape(mesh_shape), - assignment_array, - ) - - -def _create_device_mesh_for_nd_torus_splitting_axes( - physical_mesh: np.ndarray, - mesh_shape: Sequence[int], -) -> tuple[np.ndarray, np.ndarray]: - """Assigns logical parallelism axes to physical axes of an N-D torus network. - - This implementation allows creating meshes that requires splitting physical - axes, and thus one could produce logical mesh of any shape, as long as the - number of devices matches, e.g., - - - Creating 2x2x4 from 4x4; - - - Creating 2x2x16 from 8x8; - - Args: - physical_mesh: a np.ndarray of devices in the shape of the N-D torus - physical topology. - mesh_shape: shape of the logical mesh (size of the various logical - parallelism axes), with axes ordered by increasing network intensity. - - Returns: - An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with - each logical parallelism axis mapped to one or more physical mesh axes. 
- The axis assignment matrix, which is a 2-d array mapping from - (physical_axis, logical_axis) to the size assigned, with the invariant - np.prod(assignment, axis=1) = physical_mesh_shape, and - np.prod(assignment, axis=0) = mesh_shape. - """ - if np.prod(physical_mesh.shape) != np.prod(mesh_shape): - raise ValueError( - 'The number of devices in physical mesh' - f' {physical_mesh.shape} does not match the number of devices' - f' in logical mesh {mesh_shape}.' - ) - - physical_mesh_shape = physical_mesh.shape - logical_mesh_shape = tuple(mesh_shape) - - # (Partial) assignment map as an 2-d array [p_axis, l_axis] -> size. - assignment = np.ones( - [len(physical_mesh_shape), len(logical_mesh_shape)], dtype=np.int64 - ) - - # Process logical axes from highest network intensity to lowest. - # `mesh_shape` is assumed to ordered by lowest network intensity first, so - # reverse it. - for logical_axis, logical_axis_size in reversed( - list(enumerate(logical_mesh_shape)) - ): - # Go over all the possible assignment for the logical axis, including the - # one that splits multiple physical axes. - best_logical_axis_assignment = None - for logical_axis_assignment in _enumerate_feasible_logical_axis_assignments( - physical_mesh_shape, assignment, logical_axis_size - ): - # TODO(rosun): Instead of using heuristics, replace this with a proper - # scoring function reflecting the underlying hardware properties. - if ( - best_logical_axis_assignment is None - or _prefer_first_logical_axis_assignment( - logical_axis_assignment, - best_logical_axis_assignment, - physical_mesh_shape=physical_mesh_shape, - assignment=assignment, - ) - ): - best_logical_axis_assignment = logical_axis_assignment - assignment[:, logical_axis] = best_logical_axis_assignment - - # Read out the assignment. - logical_mesh = _generate_logical_mesh( - physical_mesh, logical_mesh_shape, assignment - ) - - return logical_mesh, assignment - - -def _get_prime_factors(x: int) -> list[int]: - """Returns a sorted list of prime factors for the given number.""" - assert x > 0 - factors = [] - for p in range(2, math.isqrt(x) + 2): - while x % p == 0: - factors.append(p) - x //= p - if x == 1: - return factors - else: - return [x] # x is a prime number. - - -def _enumerate_feasible_logical_axis_assignments( - physical_mesh_shape: Sequence[int], - assignment: np.ndarray, - logical_axis_size: int, -) -> Generator[np.ndarray, None, None]: - """Yields feasible assignments for a single logical axis. - - For a physical mesh of shape [x_1, ..., x_n], and the product of all previous - assignments on each physical axes [y_1, ..., y_n], this function yields all - possible assignments for the axis as 1-d arrays [z_1, ..., z_n], so that: - - - prod(z_1, ..., z_n) = logical_axis_size - - - x_i % (z_i * y_i) = 0 - - Args: - physical_mesh_shape: Physical mesh shape. - assignment: Existing assignment matrix. - logical_axis_size: Size of the logical axis to assign. - - Yields: - All valid assignments for the logical axis. Each assignment is represented - as an integer array of length len(physical_mesh_shape). - """ - logical_axis_factors: MutableMapping[int, int] = collections.defaultdict(int) - for factor in _get_prime_factors(logical_axis_size): - logical_axis_factors[factor] += 1 - - available_physical_mesh_shape = np.array(physical_mesh_shape) // np.prod( - assignment, axis=-1 - ) - - # To enable efficient enumerations, we first index physical axes by their - # prime factors. 
Since we know the prime factorization of the logical axis - # size, we could simply enumerate by picking the correct count for each - # prime factor. - physical_axes_by_factor: MutableMapping[int, list[int]] = ( - collections.defaultdict(list) - ) - for physical_axis, physical_axis_size in enumerate( - available_physical_mesh_shape - ): - for factor in _get_prime_factors(physical_axis_size): - if factor not in logical_axis_factors: - continue - physical_axes_by_factor[factor].append(physical_axis) - - factors = [] - assignments_by_factor = [] - for factor, multiplicity in logical_axis_factors.items(): - factors.append(factor) - assignments_by_factor.append( - set( - itertools.combinations( - physical_axes_by_factor[factor], multiplicity - ) - ) - ) - - for axis_assignment in itertools.product(*assignments_by_factor): - result = np.ones([len(physical_mesh_shape)], dtype=np.int64) - for factor_index, per_factor_assignment in enumerate(axis_assignment): - for physical_axis in per_factor_assignment: - result[physical_axis] *= factors[factor_index] - yield result - - -def _prefer_first_logical_axis_assignment( - x: np.ndarray, - y: np.ndarray, - *, - physical_mesh_shape: Sequence[int], - assignment: np.ndarray, -) -> bool: - """Returns True if the first axis assignment is preferred over the second. - - For now, this is implemented with some very simple heuristics. However, - it is possible to introduce e.g., a value function here based on a more - precise model of the underlying hardware. - - TODO(rosun): Use a proxy of network capacity to select the partitions. - - Args: - x: Logical axis assignment as [len(physical_mesh_shape)] array. - y: Logical axis assignment as [len(physical_mesh_shape)] array. - physical_mesh_shape: Physical mesh shape. - assignment: Assignment matrix. - - Returns: - True if x is preferred over y. - """ - # Prefer occupying complete physical axes. I don't have a good reason for - # this, except that it is compatible with the existing behavior. - # - # E.g., on 4 x 4 x 8, [4, 4, -] will be preferred over [4, -, 4], and then - # over [2, 2, 4]. - x_whole_axis_size = np.prod( - [s for i, s in enumerate(x) if s == physical_mesh_shape[i]] - ) - y_whole_axis_size = np.prod( - [s for i, s in enumerate(y) if s == physical_mesh_shape[i]] - ) - - if x_whole_axis_size != y_whole_axis_size: - return x_whole_axis_size > y_whole_axis_size - - # Prefer occupying more whole physical axes for better bandwidth. - # - # This is consistent with existing logic, i.e., 2 x 2 is preferred over 4. - x_num_whole_axes = len( - [1 for i, s in enumerate(x) if s == physical_mesh_shape[i] and s > 1] - ) - y_num_whole_axes = len( - [1 for i, s in enumerate(y) if s == physical_mesh_shape[i] and s > 1] - ) - - if x_num_whole_axes != y_num_whole_axes: - return x_num_whole_axes > y_num_whole_axes - - # Prefer taking physical axes that are not taken by logical axes of higher - # network intensity. E.g., for a 4 x 4 x 4, suppose that the previous - # assignments are 1 x 2 x 4, and we want to place a new logical axis of size - # 2, we will go for [2, 1, 1] instead of [1, 2, 1], as the latter choice will - # tap into bandwidth already taken by the higher intensity axis. 
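# A worked trace of the first two preferences, on the 4 x 4 x 8 example from
# the comment above (candidate placements for a logical axis of size 16):
#
#   x = [4, 4, 1]: occupies axes 0 and 1 completely -> whole_axis_size = 4 * 4 = 16
#   y = [4, 1, 4]: occupies only axis 0 completely  -> whole_axis_size = 4
#   z = [2, 2, 4]: occupies no axis completely      -> empty product, i.e. 1
#
# The first comparison alone already orders x over y over z. The tie-breaker
# described in the previous comment, computed below, only matters when both
# earlier checks come out equal.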
- assigned_physical_mesh_shape = np.prod(assignment, axis=-1) - - x_non_overlapping_axis_size = np.prod( - [s for i, s in enumerate(x) if assigned_physical_mesh_shape[i] > 1] - ) - y_non_overlapping_axis_size = np.prod( - [s for i, s in enumerate(y) if assigned_physical_mesh_shape[i] > 1] - ) - - if x_non_overlapping_axis_size != y_non_overlapping_axis_size: - return x_non_overlapping_axis_size > y_non_overlapping_axis_size - - # Otherwise sort by reverse lexical graphical order, to be consistent with - # existing behavior. - return tuple(x) > tuple(y) - - -def _generate_logical_mesh( - physical_mesh: np.ndarray, - logical_mesh_shape: Sequence[int], - assignment: np.ndarray, -) -> np.ndarray: - """Compute the logical mesh from assignment map. - - Args: - physical_mesh: Physical device mesh. - logical_mesh_shape: Logical mesh shape. - assignment: 2-d assignment matrix shape [physical_dims, logical_dims]. - - Returns: - Logical mesh reshaped from physical mesh. - """ - physical_indices = np.broadcast_to( - np.expand_dims( - np.arange(len(physical_mesh.shape), dtype=np.int64), axis=-1 - ), - assignment.shape, - ).reshape([-1]) - - logical_indices = np.broadcast_to( - np.expand_dims( - np.arange(len(logical_mesh_shape), dtype=np.int64), axis=0 - ), - assignment.shape, - ).reshape([-1]) - - # Axes of logical mesh is ordered by (physical_axis, logical_axis). - # - # Note that we sort for each physical_axis the logical_axis, so that higher - # intensity logical axes are replicated at inner (minor) dimensions. - # - # E.g., if a dimension size is 12 = 3x4, where 3 is higher intensity and 4 - # is lower, we want to reshape so that it becomes 12 = 4x3. Imagine in the - # 1-d case, this will allow more connections between the higher intensity - # axes. - logical_mesh = np.reshape(physical_mesh, assignment.reshape([-1])) - - # We will then group by l_axis as this is what is expected from output. - _, _, transpose_axes = zip( - *sorted( - zip(logical_indices, physical_indices, range(len(logical_indices))) - ) - ) - logical_mesh = np.transpose(logical_mesh, transpose_axes) - - # Reshape to add the trivial dimensions back. - logical_mesh = np.reshape(logical_mesh, logical_mesh_shape) - - return logical_mesh - - -def _bounds_from_last_device(last_device) -> Sequence[int]: - """Gets the bound from the given last device.""" - # Must be passed the device at the highest-coordinate corner of the - # relevant mesh, which is a requirement we know is satisfied by the last - # device in jax.devices(). - assert hasattr(last_device, 'coords'), 'Only TPU supported' - x, y, z = last_device.coords - return x + 1, y + 1, z + 1, last_device.core_on_chip + 1 - - -def _get_physical_tpu_mesh(jax_devices: Sequence[Any]) -> np.ndarray: - r"""Rearrange TPU devices in a slice into a physical mesh. - - Args: - jax_devices: A list of JAX devices in a TPU slice in process-tiled z, y, x, - core order, e.g. from jax.devices(). The coordinates of these devices - should constitute a cuboid with no holes; e.g., the coordinates can be - {(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1)} (a 1x2x2 cuboid); passing - only 3 of these devices would result in a "hole" in that cuboid, which is - an error. As in our example, the cuboid is not required to include the - point (0, 0, 0). - - Returns: - A np.ndarray of JAX devices with shape [global_x, global_y, global_z]. On - v2 and v3, global_z is instead cores_per_chip (i.e., 2). 
- """ - device_kind = jax_devices[0].device_kind - device_coords = [d.coords for d in jax_devices] - coord_size = len(device_coords[0]) - # Position-wise max and min coordinates: - max_coords = tuple( - max(dc[i] for dc in device_coords) for i in range(coord_size) - ) - min_coords = tuple( - min(dc[i] for dc in device_coords) for i in range(coord_size) - ) - dims = tuple(h - l + 1 for (h, l) in zip(max_coords, min_coords)) - - max_cores_per_chip = max(d.core_on_chip for d in jax_devices) - min_cores_per_chip = min(d.core_on_chip for d in jax_devices) - cores_per_chip = max_cores_per_chip - min_cores_per_chip + 1 - - assert len(dims) == 3, dims - assert ( - len(jax_devices) == np.prod(dims) * cores_per_chip - ), f'{jax_devices=} {dims=} {cores_per_chip=}' - - if device_kind in (_TPU_V2, _TPU_V3): - out = np.empty(dims[:2] + (cores_per_chip,), dtype=object) - for d in jax_devices: - coords = d.coords - assert coords[2] == 0, d - out[ - coords[0] - min_coords[0], - coords[1] - min_coords[1], - d.core_on_chip - min_cores_per_chip, - ] = d - else: - out = np.empty(dims, dtype=object) - for d in jax_devices: - coords = d.coords - if d.core_on_chip != 0: - raise AssertionError( - 'Creating meshes for TPU >v3 requires one device per chip' - f' ("megacore" mode). Got device id {d.core_on_chip} for a device' - f' of kind {device_kind}: {d}.' - ) - out[ - coords[0] - min_coords[0], - coords[1] - min_coords[1], - coords[2] - min_coords[2], - ] = d - - # Check there is no "hole" in the mesh we constructed. - if (out == None).any(): # pylint: disable=singleton-comparison - raise AssertionError( - 'Constructed mesh contains a "hole"; probable cause: coordinates ' - f'of jax_devices are not a contiguous cuboid: {jax_devices}' - ) - return out - - -# jekbradbury's famous trick for creating contiguous submeshes (where available) -def _transpose_trick( - physical_mesh: np.ndarray, mesh_shape: Sequence[int] -) -> np.ndarray: - mesh_shape = tuple(mesh_shape) - topology = physical_mesh.shape - if topology not in _TRANSPOSE_TRICKS: - raise ValueError( - 'create_device_mesh cannot create contiguous submeshes for ' - f'physical mesh topology {topology}' - ) - - mesh_shape_no_trivial_dims: tuple[int, ...] = () - for dim_size in mesh_shape: - if dim_size != 1: - mesh_shape_no_trivial_dims += (dim_size,) - - if mesh_shape_no_trivial_dims not in _TRANSPOSE_TRICKS[topology]: - raise ValueError( - 'create_device_mesh cannot create contiguous submeshes for ' - f'mesh_shape {mesh_shape} and physical mesh topology {topology}. ' - f'Available mesh_shapes: {list(_TRANSPOSE_TRICKS[topology].keys())}' - ) - - return physical_mesh.transpose( - *_TRANSPOSE_TRICKS[topology][mesh_shape_no_trivial_dims] - ) - - -def create_device_mesh( - mesh_shape: Sequence[int], - devices: Sequence[Any] | None = None, - *, - contiguous_submeshes: bool = False, - allow_split_physical_axes: bool = False, -) -> np.ndarray: - """Creates a performant device mesh for jax.sharding.Mesh. - - Args: - mesh_shape: shape of logical mesh, ordered by increasing network-intensity - e.g. [replica, data, mdl] where mdl has the most network communication - requirements. - devices: optionally, the devices to construct a mesh for. Defaults to - jax.devices(). - contiguous_submeshes: if True, this function will attempt to create a mesh - where each process's local devices form a contiguous submesh. A ValueError - will be raised if this function can't produce a suitable mesh. 
This - setting was sometimes necessary before the introduction of jax.Array to - ensure non-ragged local arrays; if using jax.Arrays, it's better to keep - this set to False. - allow_split_physical_axes: If True, we will split physical axes if necessary - to produce the desired device mesh. - - Raises: - ValueError: if the number of devices doesn't equal the product of - `mesh_shape`. - - Returns: - A np.ndarray of JAX devices with mesh_shape as its shape that can be fed - into jax.sharding.Mesh with good collective performance. - """ - if devices is None: - devices = xb.devices() - if np.prod(mesh_shape) != len(devices): - raise ValueError( - f'Number of devices {len(devices)} must equal the product ' - f'of mesh_shape {mesh_shape}' - ) - last_device = devices[-1] - - handler = device_kind_handler_dict.get(last_device.device_kind, None) - if handler is not None: - result = handler( - mesh_shape, devices, contiguous_submeshes=contiguous_submeshes - ) - if result is not None: - return result - - if last_device.platform == 'tpu': - physical_mesh = _get_physical_tpu_mesh(devices) - if contiguous_submeshes: - physical_mesh = _transpose_trick(physical_mesh, mesh_shape) - device_mesh, _ = _create_device_mesh_for_nd_torus( - physical_mesh, - mesh_shape, - allow_split_physical_axes=allow_split_physical_axes, - ) - return device_mesh - else: - device_mesh = np.asarray(devices).reshape(mesh_shape) - return device_mesh - - -def create_hybrid_device_mesh( - mesh_shape: Sequence[int], - dcn_mesh_shape: Sequence[int], - devices: Sequence[Any] | None = None, - *, - process_is_granule: bool = False, - should_sort_granules_by_key: bool = True, - allow_split_physical_axes: bool = False, -) -> np.ndarray: - """Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism. - - Args: - mesh_shape: shape of the logical mesh for the faster/inner network, ordered - by increasing network intensity, e.g. [replica, data, mdl] where mdl has - the most network communication requirements. - dcn_mesh_shape: shape of the logical mesh for the slower/outer network, in - the same order as mesh_shape. - devices: optionally, the devices to construct a mesh for. Defaults to - jax.devices(). - process_is_granule: if True, this function will treat processes as the units - of the slower/outer network. Otherwise it will look for slice_index - attributes on devices and use slices as the units. Enabling this is meant - as a fallback for platforms that don't set slice_index. - should_sort_granules_by_key: Whether device granules should be sorted by the - granule key, either slice or process index, depending on - process_is_granule. - allow_split_physical_axes: If True, we will split physical axes if necessary - to produce the desired device mesh. - - Raises: - ValueError: if the number of slices to which the `devices` belong doesn't - equal the product of `dcn_mesh_shape`, or if the number of devices - belonging to any single slice does not equal the product of `mesh_shape`. - - Returns: - A np.ndarray of JAX devices with mesh_shape * dcn_mesh_shape as its shape - that can be fed into jax.sharding.Mesh for hybrid parallelism. 
- """ - if devices is None: - devices = xb.devices() - attr = 'process_index' if process_is_granule else 'slice_index' - assert hasattr(devices[0], attr) - granule_dict = collections.defaultdict(list) - for dev in devices: - granule_dict[getattr(dev, attr)].append(dev) - granules = ( - [granule_dict[key] for key in sorted(granule_dict.keys())] - if should_sort_granules_by_key - else granule_dict.values() - ) - if np.prod(dcn_mesh_shape) != len(granules): - raise ValueError( - f'Number of slices {len(granules)} must equal the product of ' - f'dcn_mesh_shape {dcn_mesh_shape}' - ) - per_granule_meshes = [ - create_device_mesh( - mesh_shape, - granule, - allow_split_physical_axes=allow_split_physical_axes, - ) - for granule in granules - ] - # TODO(jekbradbury): handle non-uniform DCN topologies - granule_mesh = np.arange(len(granules)).reshape(dcn_mesh_shape) - blocks = np.vectorize(lambda i: per_granule_meshes[i], otypes=[object])( - granule_mesh - ) - device_mesh = np.block(blocks.tolist()) - return device_mesh +from jax._src.mesh_utils import ( + create_device_mesh as create_device_mesh, + create_hybrid_device_mesh as create_hybrid_device_mesh, + device_kind_handler_dict as device_kind_handler_dict, +) diff --git a/jax/experimental/mosaic/gpu/__init__.py b/jax/experimental/mosaic/gpu/__init__.py index 9d4068745d3a..f5944c862480 100644 --- a/jax/experimental/mosaic/gpu/__init__.py +++ b/jax/experimental/mosaic/gpu/__init__.py @@ -13,837 +13,54 @@ # limitations under the License. # ============================================================================== -from collections.abc import Callable, Sequence -import contextlib -import ctypes -import dataclasses -import functools -import itertools -import math -import os -import pathlib -import subprocess -import tempfile -import time -from typing import Any, Generic, TypeVar - -import jax -from jax._src import config -from jax._src import core as jax_core -from jax._src.interpreters import mlir -from jax._src.lib import xla_client -from jaxlib.mlir import ir -from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import builtin -from jaxlib.mlir.dialects import func -from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvvm -from jaxlib.mlir.passmanager import PassManager -import numpy as np - -from . import profiler -from . import utils - -# mypy: ignore-errors - -# MLIR can't find libdevice unless we point it to the CUDA path -# TODO(apaszke): Unify with jax._src.lib.cuda_path -CUDA_ROOT = "/usr/local/cuda" -if os.environ.get("CUDA_ROOT") is None: - os.environ["CUDA_ROOT"] = CUDA_ROOT -else: - CUDA_ROOT = os.environ["CUDA_ROOT"] - -PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") -NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") - -TMA_DESCRIPTOR_BYTES = 128 -TMA_DESCRIPTOR_ALIGNMENT = 64 - - -c = utils.c # This is too common to fully qualify. 
- - -RUNTIME_PATH = None -try: - from jax._src.lib import mosaic_gpu as mosaic_gpu_lib - - RUNTIME_PATH = ( - pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent - / "libmosaic_gpu_runtime.so" - ) -except ImportError: - pass - -if RUNTIME_PATH and RUNTIME_PATH.exists(): - # Set this so that the custom call can find it - os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) - - -mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") -mosaic_gpu_p.multiple_results = True - - -@mosaic_gpu_p.def_abstract_eval -def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes): - del module, gmem_scratch_bytes # Unused. - return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] - -# TODO(apaszke): Implement a proper system for managing kernel lifetimes -kernel_idx = itertools.count() - -def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes): - del out_types # Unused. - idx_bytes = next(kernel_idx).to_bytes(8, byteorder="little") - op = mlir.custom_call( - "mosaic_gpu", - result_types=[ - *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out), - mlir.aval_to_ir_type( - jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8) - ), - ], - operands=args, - operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], - result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out] - + [[0]], - backend_config=idx_bytes + module, - ) - return op.results[:-1] # Skip the scratch space. - -mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") - - -@dataclasses.dataclass(frozen=True) -class MemRefTransform: - def apply(self, ref: ir.Value) -> ir.Value: - raise NotImplementedError("Subclasses should override this method") - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - raise NotImplementedError("Subclasses should override this method") - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - raise NotImplementedError("Subclasses should override this method") - - -@dataclasses.dataclass(frozen=True) -class TileTransform(MemRefTransform): - """Tiles a suffix of memref dimensions. - - For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), - the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with - the tile shape, and the size of tiled dimensions is divided by the tile size. - This is especially useful for swizzled WGMMA, which expect tiled layouts in - shared memory. - """ - tiling: tuple[int, ...] - - def apply(self, ref: ir.Value) -> ir.Value: - untiled_rank = ir.MemRefType(ref.type).rank - tiling_rank = len(self.tiling) - tiled_rank = untiled_rank + tiling_rank - for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): - ref = utils.memref_unfold(ref, d, (None, t)) - permutation = ( - *range(untiled_rank - tiling_rank), - *range(untiled_rank - tiling_rank, tiled_rank, 2), - *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), - ) - return utils.memref_transpose(ref, permutation) - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - index = ir.IndexType.get() - tiling_rank = len(self.tiling) - return ( - *idx[:-tiling_rank], - *( - arith.divui(i, c(t, index)) - for i, t in zip(idx[-tiling_rank:], self.tiling) - ), - *( - arith.remui(i, c(t, index)) - for i, t in zip(idx[-tiling_rank:], self.tiling) - ), - ) - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - # Note that this also checks that tiled dims are not squeezed. Their slice - # size would be 1 if so. 
- tiling_rank = len(self.tiling) - for size, tile_size in zip(shape[-tiling_rank:], self.tiling): - if size % tile_size: - raise ValueError( - f"Expected GMEM slice shape {shape} suffix to be a multiple" - f" of tiling {self.tiling}" - ) - return ( - *shape[:-tiling_rank], - *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), - *self.tiling, - ) - - -@dataclasses.dataclass(frozen=True) -class TransposeTransform(MemRefTransform): - """Transposes memref dimensions.""" - permutation: tuple[int, ...] - - def __post_init__(self): - if len(self.permutation) != len(set(self.permutation)): - raise ValueError("Permutation must be a permutation") - - def apply(self, ref: ir.Value) -> ir.Value: - return utils.memref_transpose(ref, self.permutation) - - def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: - return tuple(idx[p] for p in self.permutation) - - def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: - return tuple(shape[p] for p in self.permutation) - - -OnDeviceProfiler = profiler.OnDeviceProfiler - - -@dataclasses.dataclass() -class LaunchContext: - launch_op: gpu.LaunchOp - gmem_scratch_ptr: ir.Value - cluster_size: tuple[int, int, int] - profiler: OnDeviceProfiler | None = None - next_scratch_offset: int = 0 - host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( - default_factory=list, init=False - ) - tma_descriptors: dict[ - tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], - ir.Value, - ] = dataclasses.field(default_factory=dict, init=False) - - @contextlib.contextmanager - def named_region(self, *args, **kwargs): - if self.profiler is not None: - with self.profiler.record(*args, **kwargs): - yield - else: - yield - - def _alloc_scratch( - self, - size: int, - alignment: int | None = None, - host_init: Callable[[ir.Value], None] = lambda _: None, - device_init: Callable[[ir.Value], Any] = lambda x: x, - ) -> ir.Value: - """Allocates a GMEM scratch buffer. - - The buffer is initialized on the host and then copied to GMEM before the - kernel launch. - """ - i8 = ir.IntegerType.get_signless(8) - ptr_ty = ir.Type.parse("!llvm.ptr") - if alignment is None: - alignment = size - if self.next_scratch_offset % alignment: - raise NotImplementedError # TODO(apaszke): Pad to match alignment - alloc_base = self.next_scratch_offset - self.next_scratch_offset += size - def host_init_wrapped(host_ptr): - host_init( - llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) - ) - self.host_scratch_init.append(host_init_wrapped) - # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): - # There is no way to create an insertion point after an operation... 
- gep = llvm.GEPOp( - ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 - ) - gep.move_after(self.gmem_scratch_ptr.owner) - return device_init(gep.result) - - def _get_tma_desc( - self, - gmem_ref, - gmem_transform: tuple[MemRefTransform, ...], - transformed_slice_shape: tuple[int, ...], - swizzle: int | None, - ): - tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) - if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: - i64 = ir.IntegerType.get_signless(64) - ptr_ty = ir.Type.parse("!llvm.ptr") - def init_tma_desc(host_ptr): - ref = gmem_ref - for t in gmem_transform: - ref = t.apply(ref) - ref_ty = ir.MemRefType(ref.type) - # TODO(apaszke): Use utils.memref_ptr to compute base_ptr - _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) - aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) - as_i64 = lambda i: arith.index_cast(i64, i) - alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) - llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... - base_ptr = llvm.getelementptr( - ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, - ) - rank = ref_ty.rank - assert rank * 2 == len(sizes_and_strides) - args = [ - host_ptr, - base_ptr, - c(utils.bytewidth(ref_ty.element_type), i64), - c(rank, i64), - utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), - utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), - c(0 if swizzle is None else swizzle, i64), - utils.pack_array([c(v, i64) for v in transformed_slice_shape]), - ] - func.call([], "mosaic_gpu_init_tma_desc", args) - def cast_tma_desc(device_ptr): - # TODO(apaszke): Investigate why prefetching can cause launch failures - # nvvm.prefetch_tensormap(device_ptr) - return device_ptr - tma_desc = self._alloc_scratch( - TMA_DESCRIPTOR_BYTES, - alignment=TMA_DESCRIPTOR_ALIGNMENT, - host_init=init_tma_desc, - device_init=cast_tma_desc, - ) - self.tma_descriptors[tma_desc_key] = tma_desc - return tma_desc - - def async_copy( - self, - *, - src_ref, - dst_ref, - gmem_slice: Any = (), - gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] 
= (), - barrier: utils.BarrierRef | None = None, - swizzle: int | None = None, - arrive: bool | None = None, - uniform: bool = True, - collective: gpu.Dimension | None = None, - ): - index = ir.IndexType.get() - i16 = ir.IntegerType.get_signless(16) - i32 = ir.IntegerType.get_signless(32) - smem = ir.Attribute.parse("#gpu.address_space") - src_ref_ty = ir.MemRefType(src_ref.type) - dst_ref_ty = ir.MemRefType(dst_ref.type) - element_type = src_ref_ty.element_type - element_bytewidth = utils.bytewidth(element_type) - if element_type != dst_ref_ty.element_type: - raise ValueError( - f"Expected same element type, got {element_type} and" - f" {dst_ref_ty.element_type}" - ) - if not isinstance(gmem_transform, tuple): - gmem_transform = (gmem_transform,) - - if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: - gmem_ref, smem_ref = src_ref, dst_ref - if barrier is None: - raise ValueError("Barriers are required for GMEM -> SMEM copies") - if arrive is None: - arrive = True # Arrive by default - elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: - gmem_ref, smem_ref = dst_ref, src_ref - if barrier is not None: - raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") - if arrive is not None: - raise ValueError("arrive is unsupported for SMEM -> GMEM copies") - else: - raise ValueError("Only SMEM <-> GMEM copies supported") - # TODO(apaszke): This is a very approximate check. Improve it! - expected_name = "builtin.unrealized_conversion_cast" - if ( - gmem_ref.owner is None - or gmem_ref.owner.opview.OPERATION_NAME != expected_name - ): - raise ValueError("GMEM reference in async_copy must be a kernel argument") - - base_indices, slice_shape, is_squeezed = utils.parse_indices( - gmem_slice, ir.MemRefType(gmem_ref.type).shape - ) - dyn_base_indices = tuple( - c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices - ) - slice_shape = tuple(slice_shape) - for t in gmem_transform: - dyn_base_indices = t.transform_index(dyn_base_indices) - slice_shape = t.transform_shape(slice_shape) - for dim, squeezed in enumerate(is_squeezed): - if squeezed: - smem_ref = utils.memref_unsqueeze(smem_ref, dim) - smem_ref_ty = ir.MemRefType(smem_ref.type) - - if slice_shape != tuple(smem_ref_ty.shape): - raise ValueError( - "Expected the SMEM reference to have the same shape as the" - f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" - ) - - dyn_base_indices = list(dyn_base_indices) - slice_shape = list(slice_shape) - collective_size = 1 if collective is None else self.cluster_size[collective] - if collective_size > 1: - def partition_dim(dim: int, idx: ir.Value, num_chunks: int): - nonlocal smem_ref - slice_shape[dim] //= num_chunks - block_offset = arith.muli(idx, c(slice_shape[dim], index)) - dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) - smem_ref = utils.memref_slice( - smem_ref, - (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) - ) - idx = gpu.cluster_block_id(collective) - rem_collective_size = collective_size - for dim, slice_size in enumerate(slice_shape[:-1]): - if slice_size % rem_collective_size == 0: - partition_dim(dim, idx, rem_collective_size) - break - elif collective_size % slice_size == 0: - dim_idx = arith.remui(idx, c(slice_size, index)) - partition_dim(dim, dim_idx, slice_size) - idx = arith.divui(idx, c(slice_size, index)) - rem_collective_size //= slice_size - else: - raise ValueError( - "None of the leading dimensions in the transformed slice shape" - f" 
{slice_shape} is divisible by the collective size" - f" {collective_size}" - ) - # Make each block load a smaller slice, adjust the GMEM indices and slice - # the SMEM reference accordingly. - multicast_mask = arith.trunci( - i16, utils.cluster_collective_mask(self.cluster_size, collective) - ) - else: - multicast_mask = None - - tma_desc = self._get_tma_desc( - gmem_ref, gmem_transform, tuple(slice_shape), swizzle, - ) - - # We constuct TMA descriptors in column-major order. - rev_dyn_base_indices = [ - arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) - ] - - uniform_ctx = ( - functools.partial(utils.single_thread, per_block=False) - if uniform - else contextlib.nullcontext - ) - - rank = len(slice_shape) - if rank > 5: # TODO: apaszke - Implement stride compression - raise ValueError("Async copies only support striding up to 5 dimensions") - if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: - raise ValueError( - f"Async copies with {swizzle=} require last dimension of the slice to" - f" be exactly {swizzle} bytes" - f" ({swizzle // element_bytewidth} elements), but got" - f" {slice_shape[-1]}" - ) - smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) - if gmem_ref is src_ref: - assert barrier is not None # for pytype - transfer_bytes = c( - np.prod(slice_shape) * element_bytewidth * collective_size, i32 - ) - barrier_ptr = barrier.get_ptr() - with uniform_ctx(): - if arrive: - nvvm.mbarrier_arrive_expect_tx_shared(barrier_ptr, transfer_bytes) - nvvm.cp_async_bulk_tensor_shared_cluster_global( - smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], multicast_mask=multicast_mask, - ) - else: - with uniform_ctx(): - nvvm.cp_async_bulk_tensor_global_shared_cta( - tma_desc, smem_ptr, rev_dyn_base_indices - ) - nvvm.cp_async_bulk_commit_group() - - def await_async_copy( - self, allow_groups: int, await_read_only: bool = False - ): - nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) - # TODO(apaszke): Use a warpgroup barrier!!! - gpu.barrier() # Groups are supposedly tracked per-thread - - -# ShapeTrees currently can not contain unions. 
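# The two methods above are typically paired inside a kernel body. A minimal
# sketch, assuming `ctx` is the LaunchContext, `x_gmem`/`y_gmem` are kernel
# arguments, `x_smem`/`y_smem` are SMEM refs, and `barrier` is a
# utils.BarrierRef handed out by the SMEM allocator:
#
#   ctx.async_copy(src_ref=x_gmem, dst_ref=x_smem, barrier=barrier)  # GMEM -> SMEM
#   barrier.wait()                  # block until the expected bytes have landed
#   ...                             # compute into y_smem
#   ctx.async_copy(src_ref=y_smem, dst_ref=y_gmem)                   # SMEM -> GMEM
#   ctx.await_async_copy(0)         # drain all outstanding bulk copies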
-ShapeTree = Any -RefTree = Any -T = TypeVar('T') - - -@dataclasses.dataclass(frozen=True) -class Union(Generic[T]): - members: Sequence[T] - - def __iter__(self): - return iter(self.members) - -@dataclasses.dataclass(frozen=True) -class TMABarrier: - num_barriers: int = 1 - -@dataclasses.dataclass(frozen=True) -class Barrier: - arrival_count: int - num_barriers: int = 1 - -@dataclasses.dataclass(frozen=True) -class ClusterBarrier: - collective_dims: Sequence[gpu.Dimension] - num_barriers: int = 1 - - -def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: - return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize - - -def _construct_smem_reftree( - cluster_shape: tuple[int, int, int], - dynamic_smem: ir.Value, - smem_buffers: ShapeTree, - dynamic_smem_offset: int = 0, -) -> RefTree: - index = ir.IndexType.get() - i8 = ir.IntegerType.get_signless(8) - ptr = ir.Type.parse("!llvm.ptr") - smem = ir.Attribute.parse("#gpu.address_space") - flat_ref_tys, smem_buffer_tree = jax.tree.flatten( - smem_buffers, is_leaf=lambda x: isinstance(x, Union) - ) - smem_refs = [] - for ref_ty in flat_ref_tys: - def get_barrier_ptr(num_barriers: int) -> ir.Value: - nonlocal dynamic_smem_offset - smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3) - barrier_base_ptr = llvm.getelementptr( - ptr, smem_base_ptr, [], [dynamic_smem_offset], i8 - ) - dynamic_smem_offset += num_barriers * MBARRIER_BYTES - return barrier_base_ptr - match ref_ty: - case Union(members): - member_trees = [ - _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset) - for m in members - ] - # TODO(apaszke): This is quadratic, but it shouldn't matter for now... - dynamic_smem_offset += _smem_tree_size(ref_ty) - ref = Union(member_trees) - case TMABarrier(num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 - ) - case Barrier(arrival_count, num_barriers): - ref = utils.BarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - arrival_count=arrival_count, - ) - case ClusterBarrier(collective_dims, num_barriers): - ref = utils.CollectiveBarrierRef.initialize( - get_barrier_ptr(num_barriers), - num_barriers, - collective_dims, - cluster_shape, - ) - case _: - mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype) - tile_smem = memref.view( - ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), - dynamic_smem, c(dynamic_smem_offset, index), [], - ) - dynamic_smem_offset += _count_buffer_bytes(ref_ty) - ref = tile_smem - smem_refs.append(ref) - return jax.tree.unflatten(smem_buffer_tree, smem_refs) - - -MBARRIER_BYTES = 8 - - -def _smem_tree_size(smem_buffers: ShapeTree) -> int: - leaves = jax.tree.leaves( - smem_buffers, is_leaf=lambda x: isinstance(x, Union) - ) - size = 0 - for l in leaves: - match l: - case Union(members): - size += max(_smem_tree_size(s) for s in members) - case ( - TMABarrier(num_barriers) - | ClusterBarrier(_, num_barriers=num_barriers) - | Barrier(_, num_barriers=num_barriers) - ): - if size % MBARRIER_BYTES: - raise NotImplementedError("Misaligned barrier allocation") - size += num_barriers * MBARRIER_BYTES - case _: - size += _count_buffer_bytes(l) - return size - - -# TODO(apaszke): Inline this -@contextlib.contextmanager -def _launch( - token, - grid: tuple[int, int, int], - cluster: tuple[int, int, int], - block: tuple[int, int, int], - scratch_arr, - smem_buffers: ShapeTree | Union[ShapeTree], - profiler_spec: profiler.ProfilerSpec | None = None, - maybe_prof_buffer: ir.Value | 
None = None, -): - if (profiler_spec is None) != (maybe_prof_buffer is None): - raise ValueError - index = ir.IndexType.get() - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - grid_vals = [c(i, index) for i in grid] - block_vals = [c(i, index) for i in block] - - user_smem_bytes = _smem_tree_size(smem_buffers) - - smem_bytes = user_smem_bytes - if profiler_spec is not None: - smem_bytes += profiler_spec.smem_bytes(block=block) - - # TODO(cperivol): Query the shared memory size programmatically. - if smem_bytes > 228 * 1024: - raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") - if math.prod(cluster) != 1: - if len(cluster) != 3: - raise ValueError("Clusters must be 3D") - cluster_kwargs = { - "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ") - } - for d, grid_size, cluster_size in zip("xyz", grid, cluster): - if grid_size % cluster_size != 0: - raise ValueError( - f"Grid dimension {d} must be divisible by cluster dimension:" - f" {grid_size} % {cluster_size} != 0" - ) - else: - cluster_kwargs = {} - launch_op = gpu.LaunchOp( - token.type, [token], *grid_vals, *block_vals, - dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs) - launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block - smem = ir.Attribute.parse("#gpu.address_space") - with ir.InsertionPoint(launch_op.body.blocks[0]): - dynamic_smem = gpu.dynamic_shared_memory( - ir.MemRefType.get( - (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem - ) - ) - - smem_ref_tree = _construct_smem_reftree( - cluster, dynamic_smem, smem_buffers - ) - # TODO(apaszke): Skip the following if no barriers were initialized. - nvvm.fence_mbarrier_init() - if math.prod(cluster) != 1: - nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) - nvvm.cluster_wait(aligned=ir.UnitAttr.get()) - gpu.barrier() - - if profiler_spec: - prof_smem = memref.view( - ir.MemRefType.get( - (profiler_spec.smem_i32_elements(block=block),), - i32, memory_space=smem, - ), - dynamic_smem, c(user_smem_bytes, index), [], - ) - prof = profiler.OnDeviceProfiler( - profiler_spec, prof_smem, maybe_prof_buffer - ) - else: - prof = None - - ptr_ty = ir.Type.parse("!llvm.ptr") - scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) - yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree - if prof is not None: - prof.finalize(grid=grid, block=block) - gpu.terminator() - - -def _lower_as_gpu_kernel( - body, - grid: tuple[int, int, int], - cluster: tuple[int, int, int], - block: tuple[int, int, int], - in_shapes: tuple[Any, ...], - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - module_name: str, - prof_spec: profiler.ProfilerSpec | None = None, -): - ptr_ty = ir.Type.parse("!llvm.ptr") - token_ty = ir.Type.parse("!gpu.async.token") - i32 = ir.IntegerType.get_signless(32) - i64 = ir.IntegerType.get_signless(64) - - def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: - return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype)) - - in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] - - unwrap_output_tuple = False - if isinstance(out_shape, list): - out_shape = tuple(out_shape) - elif not isinstance(out_shape, tuple): - out_shape = (out_shape,) - unwrap_output_tuple = True - out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] - if prof_spec is not None: - out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) - 
out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) - - module = ir.Module.create() - attrs = module.operation.attributes - attrs["sym_name"] = ir.StringAttr.get(module_name) - with ir.InsertionPoint(module.body): - _declare_runtime_functions() - gmem_scratch_bytes = 0 - global_scratch = llvm.GlobalOp( - ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. - "global_scratch", - ir.Attribute.parse("#llvm.linkage"), - addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. - ) - @func.FuncOp.from_py_func(ptr_ty, ptr_ty, ptr_ty) - def main(token_ptr, buffers, gmem_scratch_ptr): - nonlocal gmem_scratch_bytes - token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) - arg_refs = [] - for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): - ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) - arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) - in_refs = arg_refs[:len(in_ref_tys)] - out_refs = arg_refs[len(in_ref_tys):] - prof_buffer = out_refs.pop() if prof_spec is not None else None - empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") - scratch_alloc = llvm.AllocaOp( - ptr_ty, c(1, i64), empty_arr_ty, alignment=TMA_DESCRIPTOR_ALIGNMENT - ) - scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) - with _launch( - token, grid, cluster, block, scratch_arr, smem_scratch_shape, - prof_spec, prof_buffer - ) as (launch_ctx, smem_refs): - body(launch_ctx, *in_refs, *out_refs, smem_refs) - gmem_scratch_bytes = launch_ctx.next_scratch_offset - # Allocate and initialize the host buffer right before the launch. - # Note that we couldn't do that before, because we had to run the body - # to learn what the scratch contains. - with ir.InsertionPoint(scratch_arr.owner): - scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") - scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty) - scratch_arr.set_type(scratch_arr_ty) - for init_callback in launch_ctx.host_scratch_init: - init_callback(scratch_alloc.result) - main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() - sym_tab = ir.SymbolTable(module.operation) - sym_tab.insert(main.func_op) - sym_tab.insert(global_scratch) - module.operation.verify() - - return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple - - -def as_gpu_kernel( - body, - grid: tuple[int, int, int], - block: tuple[int, int, int], - in_shape, - out_shape, - smem_scratch_shape: ShapeTree | Union[ShapeTree], - prof_spec: profiler.ProfilerSpec | None = None, - cluster: tuple[int, int, int] = (1, 1, 1), - module_name: str = "unknown", -): - if isinstance(in_shape, list): - in_shape = tuple(in_shape) - elif not isinstance(in_shape, tuple): - in_shape = (in_shape,) - - module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = ( - _lower_as_gpu_kernel( - body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, - module_name, prof_spec - ) - ) - - expected_arg_treedef = jax.tree.structure(in_shape) - def _check_args(*args): - arg_treedef = jax.tree.structure(args) - if arg_treedef != expected_arg_treedef: - raise ValueError( - f"Invalid argument structure: expected {expected_arg_treedef}, got" - f" {arg_treedef}, ({args=})" - ) - - module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) - def bind(*args): - return mosaic_gpu_p.bind( - *args, - out_types=out_shape, - module=module_asm, - gmem_scratch_bytes=gmem_scratch_bytes, - ) - - if prof_spec is not None: - @jax.jit - def prof_kernel(*args): - _check_args(*args) - *results, 
prof_buffer = bind(*args) - def dump_profile(prof_buffer): - out_file = os.path.join( - os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), - f"{time.time_ns()}-trace.json", - ) - try: - with open(out_file, "x") as f: - prof_spec.dump(prof_buffer, f, grid=grid, block=block) - except FileExistsError: - pass # TODO: Retry - jax.debug.callback(dump_profile, prof_buffer) - return results[0] if unwrap_output_tuple else results - return prof_kernel - else: - @jax.jit - def kernel(*args): - _check_args(*args) - results = bind(*args) - return results[0] if unwrap_output_tuple else results - return kernel - - -def _declare_runtime_functions(): - """Declares the runtime functions that can be used by the generated code.""" - ptr_ty = ir.Type.parse("!llvm.ptr") - i64 = ir.IntegerType.get_signless(64) - arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] - init_tma_desc_type = ir.FunctionType.get(arg_tys, []) - func.FuncOp( - "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" - ) - memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) - func.FuncOp( - "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" - ) +from jax import ShapeDtypeStruct +from .core import ( + Barrier, + ClusterBarrier, + LaunchContext, + MemRefTransform, + TMABarrier, + TileTransform, + TransposeTransform, + Union, + as_gpu_kernel, +) +from .fragmented_array import ( + FragmentedArray, + FragmentedLayout, + WGMMA_LAYOUT, + WGMMA_ROW_LAYOUT, + WGSplatFragLayout, + WGStridedFragLayout, +) +from .utils import ( + BarrierRef, + CollectiveBarrierRef, + DynamicSlice, + Partition, + Partition1D, + bytewidth, + c, + commit_shared, + debug_print, + ds, + fori, + memref_fold, + memref_slice, + memref_transpose, + memref_unfold, + memref_unsqueeze, + single_thread, + single_thread_predicate, + thread_idx, + tile_shape, + warp_idx, + warpgroup_barrier, + warpgroup_idx, + when, +) +from .wgmma import ( + WGMMAAccumulator, + WGMMALayout, + wgmma, +) diff --git a/jax/experimental/mosaic/gpu/core.py b/jax/experimental/mosaic/gpu/core.py new file mode 100644 index 000000000000..bf5ec0dfc8af --- /dev/null +++ b/jax/experimental/mosaic/gpu/core.py @@ -0,0 +1,985 @@ +# Copyright 2024 The JAX Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from collections.abc import Callable, Sequence +import contextlib +import ctypes +import dataclasses +import functools +import hashlib +import math +import os +import pathlib +import time +from typing import Any, Generic, TypeVar +import weakref + +import jax +from jax._src.interpreters import mlir +from jaxlib.mlir import ir +from jaxlib.mlir.dialects import arith +from jaxlib.mlir.dialects import builtin +from jaxlib.mlir.dialects import func +from jaxlib.mlir.dialects import gpu +from jaxlib.mlir.dialects import llvm +from jaxlib.mlir.dialects import memref +from jaxlib.mlir.dialects import nvvm +import numpy as np + +from . 
import profiler +from . import utils + +# mypy: ignore-errors + +# MLIR can't find libdevice unless we point it to the CUDA path +# TODO(apaszke): Unify with jax._src.lib.cuda_path +CUDA_ROOT = "/usr/local/cuda" +if os.environ.get("CUDA_ROOT") is None: + os.environ["CUDA_ROOT"] = CUDA_ROOT +else: + CUDA_ROOT = os.environ["CUDA_ROOT"] + +PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas") +NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm") + +TMA_DESCRIPTOR_BYTES = 128 +TMA_DESCRIPTOR_ALIGNMENT = 64 + + +c = utils.c # This is too common to fully qualify. + + +RUNTIME_PATH = None +try: + from jax._src.lib import mosaic_gpu as mosaic_gpu_lib + + RUNTIME_PATH = ( + pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent + / "libmosaic_gpu_runtime.so" + ) +except ImportError: + pass + +if RUNTIME_PATH and RUNTIME_PATH.exists(): + # Set this so that the custom call can find it + os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH) + + +mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p") +mosaic_gpu_p.multiple_results = True + + +@mosaic_gpu_p.def_abstract_eval +def _mosaic_gpu_abstract_eval(*_, module, out_types): + del module # Unused. + return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types] + +# TODO(apaszke): Implement a proper system for managing kernel lifetimes +KNOWN_KERNELS = {} + + +def _mosaic_gpu_lowering_rule( + ctx, + *args, + module, + out_types, + input_output_aliases: tuple[tuple[int, int], ...] = (), +): + del out_types # Unused. + kernel_id = hashlib.sha256(module).digest() + # Note that this is technically only a half measure. Someone might load a + # compiled module with a hash collision from disk. But that's so unlikely with + # SHA256 that it shouldn't be a problem. + if (kernel_text := KNOWN_KERNELS.get(kernel_id, None)) is not None: + if kernel_text != module: + raise RuntimeError("Hash collision!") + else: + KNOWN_KERNELS[kernel_id] = module + op = mlir.custom_call( + "mosaic_gpu", + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands=args, + operand_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_in], + result_layouts=[list(reversed(range(a.ndim))) for a in ctx.avals_out], + backend_config=kernel_id + module, + operand_output_aliases=dict(input_output_aliases), + ) + return op.results + + +mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda") + + +@dataclasses.dataclass(frozen=True) +class MemRefTransform: + def apply(self, ref: ir.Value) -> ir.Value: + raise NotImplementedError("Subclasses should override this method") + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + raise NotImplementedError("Subclasses should override this method") + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + raise NotImplementedError("Subclasses should override this method") + + +@dataclasses.dataclass(frozen=True) +class TileTransform(MemRefTransform): + """Tiles a suffix of memref dimensions. + + For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32), + the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with + the tile shape, and the size of tiled dimensions is divided by the tile size. + This is especially useful for swizzled WGMMA, which expect tiled layouts in + shared memory. + """ + tiling: tuple[int, ...] 
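  # In index space, a tiling of (64, 32) maps a logical index (i, j, k) to
  # (i, j // 64, k // 32, j % 64, k % 32); transform_index below computes
  # exactly this with arith.divui / arith.remui, and transform_shape applies
  # the same split to a slice shape, e.g. (5, 128, 128) -> (5, 2, 4, 64, 32).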
+ + def apply(self, ref: ir.Value) -> ir.Value: + untiled_rank = ir.MemRefType(ref.type).rank + tiling_rank = len(self.tiling) + tiled_rank = untiled_rank + tiling_rank + for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]): + s = ir.MemRefType(ref.type).shape[d] + if s % t and s > t: + raise ValueError( + f"Dimension {d} must have size smaller or a multiple of its tiling" + f" {t}, but got {s}" + ) + ref = utils.memref_unfold(ref, d, (None, min(t, s))) + permutation = ( + *range(untiled_rank - tiling_rank), + *range(untiled_rank - tiling_rank, tiled_rank, 2), + *range(untiled_rank - tiling_rank + 1, tiled_rank, 2), + ) + return utils.memref_transpose(ref, permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + index = ir.IndexType.get() + tiling_rank = len(self.tiling) + return ( + *idx[:-tiling_rank], + *( + arith.divui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + *( + arith.remui(i, c(t, index)) + for i, t in zip(idx[-tiling_rank:], self.tiling) + ), + ) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + # Note that this also checks that tiled dims are not squeezed. Their slice + # size would be 1 if so. + tiling_rank = len(self.tiling) + for size, tile_size in zip(shape[-tiling_rank:], self.tiling): + if size % tile_size: + raise ValueError( + f"Expected GMEM slice shape {shape} suffix to be a multiple of" + f" tiling {self.tiling}.\nIf you're using padded async copies, your" + " slice might need to extend out of bounds of the GMEM buffer (OOB" + " accesses will be skipped)." + ) + return ( + *shape[:-tiling_rank], + *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)), + *self.tiling, + ) + + +@dataclasses.dataclass(frozen=True) +class TransposeTransform(MemRefTransform): + """Transposes memref dimensions.""" + permutation: tuple[int, ...] + + def __post_init__(self): + if len(self.permutation) != len(set(self.permutation)): + raise ValueError("Permutation must be a permutation") + + def apply(self, ref: ir.Value) -> ir.Value: + return utils.memref_transpose(ref, self.permutation) + + def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]: + return tuple(idx[p] for p in self.permutation) + + def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]: + return tuple(shape[p] for p in self.permutation) + + +OnDeviceProfiler = profiler.OnDeviceProfiler + + +@dataclasses.dataclass() +class LaunchContext: + launch_op: gpu.LaunchOp + gmem_scratch_ptr: ir.Value + cluster_size: tuple[int, int, int] + profiler: OnDeviceProfiler | None = None + next_scratch_offset: int = 0 + host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field( + default_factory=list, init=False + ) + tma_descriptors: dict[ + tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]], + ir.Value, + ] = dataclasses.field(default_factory=dict, init=False) + + @contextlib.contextmanager + def named_region(self, *args, **kwargs): + if self.profiler is not None: + with self.profiler.record(*args, **kwargs): + yield + else: + yield + + def _alloc_scratch( + self, + size: int, + alignment: int | None = None, + host_init: Callable[[ir.Value], None] = lambda _: None, + device_init: Callable[[ir.Value], Any] = lambda x: x, + ) -> ir.Value: + """Allocates a GMEM scratch buffer. + + The buffer is initialized on the host and then copied to GMEM before the + kernel launch. 
+ """ + i8 = ir.IntegerType.get_signless(8) + ptr_ty = ir.Type.parse("!llvm.ptr") + if alignment is None: + alignment = size + if self.next_scratch_offset % alignment: + raise NotImplementedError # TODO(apaszke): Pad to match alignment + alloc_base = self.next_scratch_offset + self.next_scratch_offset += size + def host_init_wrapped(host_ptr): + host_init( + llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8) + ) + self.host_scratch_init.append(host_init_wrapped) + # with ir.InsertionPoint(self.gmem_scratch_ptr.owner): + # There is no way to create an insertion point after an operation... + gep = llvm.GEPOp( + ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8 + ) + gep.move_after(self.gmem_scratch_ptr.owner) + return device_init(gep.result) + + def _get_tma_desc( + self, + gmem_ref, + gmem_transform: tuple[MemRefTransform, ...], + transformed_slice_shape: tuple[int, ...], + swizzle: int | None, + ): + tma_desc_key = (gmem_ref, transformed_slice_shape, swizzle, gmem_transform) + if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None: + i64 = ir.IntegerType.get_signless(64) + ptr_ty = ir.Type.parse("!llvm.ptr") + def init_tma_desc(host_ptr): + ref = gmem_ref + for t in gmem_transform: + ref = t.apply(ref) + ref_ty = ir.MemRefType(ref.type) + # TODO(apaszke): Use utils.memref_ptr to compute base_ptr + _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref) + aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref) + as_i64 = lambda i: arith.index_cast(i64, i) + alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx)) + llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings... + base_ptr = llvm.getelementptr( + ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type, + ) + rank = ref_ty.rank + assert rank * 2 == len(sizes_and_strides) + args = [ + host_ptr, + base_ptr, + c(utils.bytewidth(ref_ty.element_type), i64), + c(rank, i64), + utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]), + utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]), + c(0 if swizzle is None else swizzle, i64), + utils.pack_array([c(v, i64) for v in transformed_slice_shape]), + ] + func.call([], "mosaic_gpu_init_tma_desc", args) + def cast_tma_desc(device_ptr): + # TODO(apaszke): Investigate why prefetching can cause launch failures + # nvvm.prefetch_tensormap(device_ptr) + return device_ptr + tma_desc = self._alloc_scratch( + TMA_DESCRIPTOR_BYTES, + alignment=TMA_DESCRIPTOR_ALIGNMENT, + host_init=init_tma_desc, + device_init=cast_tma_desc, + ) + self.tma_descriptors[tma_desc_key] = tma_desc + return tma_desc + + def async_copy( + self, + *, + src_ref, + dst_ref, + gmem_slice: Any = (), + gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] 
= (), + barrier: utils.BarrierRef | None = None, + swizzle: int | None = None, + arrive: bool | None = None, + uniform: bool = True, + collective: Sequence[gpu.Dimension] | gpu.Dimension | None = None, + predicate: ir.Value | None = None, + ): + index = ir.IndexType.get() + i16 = ir.IntegerType.get_signless(16) + i32 = ir.IntegerType.get_signless(32) + smem = ir.Attribute.parse("#gpu.address_space") + src_ref_ty = ir.MemRefType(src_ref.type) + dst_ref_ty = ir.MemRefType(dst_ref.type) + element_type = src_ref_ty.element_type + element_bytewidth = utils.bytewidth(element_type) + if element_type != dst_ref_ty.element_type: + raise ValueError( + f"Expected same element type, got {element_type} and" + f" {dst_ref_ty.element_type}" + ) + if not isinstance(gmem_transform, tuple): + gmem_transform = (gmem_transform,) + + if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem: + gmem_ref, smem_ref = src_ref, dst_ref + if barrier is None: + raise ValueError("Barriers are required for GMEM -> SMEM copies") + if arrive is None: + arrive = True # Arrive by default + elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None: + gmem_ref, smem_ref = dst_ref, src_ref + if barrier is not None: + raise ValueError("Barriers are unsupported for SMEM -> GMEM copies") + if arrive is not None: + raise ValueError("arrive is unsupported for SMEM -> GMEM copies") + else: + raise ValueError("Only SMEM <-> GMEM copies supported") + # TODO(apaszke): This is a very approximate check. Improve it! + expected_name = "builtin.unrealized_conversion_cast" + if ( + gmem_ref.owner is None + or gmem_ref.owner.opview.OPERATION_NAME != expected_name + ): + raise ValueError("GMEM reference in async_copy must be a kernel argument") + + base_indices, slice_shape, is_squeezed = utils.parse_indices( + gmem_slice, ir.MemRefType(gmem_ref.type).shape + ) + dyn_base_indices = tuple( + c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices + ) + slice_shape = tuple(slice_shape) + for t in gmem_transform: + dyn_base_indices = t.transform_index(dyn_base_indices) + slice_shape = t.transform_shape(slice_shape) + for dim, squeezed in enumerate(is_squeezed): + if squeezed: + smem_ref = utils.memref_unsqueeze(smem_ref, dim) + smem_ref_ty = ir.MemRefType(smem_ref.type) + + if slice_shape != tuple(smem_ref_ty.shape): + raise ValueError( + "Expected the SMEM reference to have the same shape as the" + f" transformed slice: {tuple(smem_ref_ty.shape)} != {slice_shape}" + ) + + dyn_base_indices = list(dyn_base_indices) + slice_shape = list(slice_shape) + collective_size = 1 + if collective is not None: + if isinstance(collective, gpu.Dimension): + collective = (collective,) + collective_size = math.prod(self.cluster_size[d] for d in collective) + if collective_size > 1: + def partition_dim(dim: int, idx: ir.Value, num_chunks: int): + nonlocal smem_ref + slice_shape[dim] //= num_chunks + block_offset = arith.muli(idx, c(slice_shape[dim], index)) + dyn_base_indices[dim] = arith.addi(dyn_base_indices[dim], block_offset) + smem_ref = utils.memref_slice( + smem_ref, + (slice(None),) * dim + (utils.ds(block_offset, slice_shape[dim]),) + ) + stride = 1 + idx = c(0, index) + for d in sorted(collective): + if self.cluster_size[d] == 1: # Optimize a multiply by 0. 
+ continue + idx = arith.addi(idx, arith.muli(gpu.cluster_block_id(d), c(stride, index))) + stride *= self.cluster_size[d] + rem_collective_size = collective_size + for dim, slice_size in enumerate(slice_shape[:-1]): + if slice_size % rem_collective_size == 0: + partition_dim(dim, idx, rem_collective_size) + rem_collective_size = 1 + break + elif rem_collective_size % slice_size == 0: + dim_idx = arith.remui(idx, c(slice_size, index)) + partition_dim(dim, dim_idx, slice_size) + idx = arith.divui(idx, c(slice_size, index)) + rem_collective_size //= slice_size + else: + break # We failed to partition the leading dimensions. + del idx # We overwrote the block index in the loop. + if rem_collective_size > 1: + raise ValueError( + "None of the leading dimensions in the transformed slice shape" + f" {slice_shape} is divisible by the collective size" + f" {collective_size}" + ) + # Make each block load a smaller slice, adjust the GMEM indices and slice + # the SMEM reference accordingly. + multicast_mask = arith.trunci( + i16, utils.cluster_collective_mask(self.cluster_size, collective) + ) + else: + multicast_mask = None + + tma_desc = self._get_tma_desc( + gmem_ref, gmem_transform, tuple(slice_shape), swizzle, + ) + + # We construct TMA descriptors in column-major order. + rev_dyn_base_indices = [ + arith.index_cast(i32, idx) for idx in reversed(dyn_base_indices) + ] + + uniform_ctx = ( + functools.partial(utils.single_thread, per_block=False) + if uniform + else contextlib.nullcontext + ) + + rank = len(slice_shape) + if rank > 5: # TODO: apaszke - Implement stride compression + raise ValueError("Async copies only support striding up to 5 dimensions") + if max(slice_shape) > 256: + raise ValueError( + "Async copies only support copying <=256 elements along each" + " dimension" + ) + if (zeroth_bw := slice_shape[-1] * element_bytewidth) % 16 != 0: + raise ValueError( + "Async copies require the number of bytes copied along the last" + f" dimension to be divisible by 16, but got {zeroth_bw}" + ) + if swizzle is not None and slice_shape[-1] != swizzle // element_bytewidth: + raise ValueError( + f"Async copies with {swizzle=} require last dimension of the slice to" + f" be exactly {swizzle} bytes" + f" ({swizzle // element_bytewidth} elements), but got" + f" {slice_shape[-1]}" + ) + smem_ptr = utils.memref_ptr(smem_ref, memory_space=3) + if gmem_ref is src_ref: + assert barrier is not None # for pytype + transfer_bytes = c( + np.prod(slice_shape) * element_bytewidth * collective_size, i32 + ) + barrier_ptr = barrier.get_ptr() + with uniform_ctx(): + if arrive: + nvvm.mbarrier_arrive_expect_tx_shared( + barrier_ptr, transfer_bytes, predicate=predicate + ) + nvvm.cp_async_bulk_tensor_shared_cluster_global( + smem_ptr, tma_desc, rev_dyn_base_indices, barrier_ptr, [], + multicast_mask=multicast_mask, predicate=predicate + ) + else: + with uniform_ctx(): + nvvm.cp_async_bulk_tensor_global_shared_cta( + tma_desc, smem_ptr, rev_dyn_base_indices, predicate=predicate + ) + nvvm.cp_async_bulk_commit_group() + + def await_async_copy( + self, allow_groups: int, await_read_only: bool = False + ): + nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only) + utils.warpgroup_barrier() + + +# ShapeTrees currently can not contain unions.
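# An illustrative, host-side restatement of the collective-copy partitioning
# above, assuming plain Python integers for the block index and slice shape.
# `partition_slice_demo` is a hypothetical helper, not part of this module; it
# only shows how the GMEM slice is split across the blocks of a cluster.
def partition_slice_demo(slice_shape, block_idx, collective_size):
  shape = list(slice_shape)
  offsets = [0] * len(shape)
  idx, rem = block_idx, collective_size
  for dim, size in enumerate(shape[:-1]):
    if size % rem == 0:
      shape[dim] = size // rem  # This dimension absorbs all remaining blocks.
      offsets[dim] = idx * shape[dim]
      rem = 1
      break
    elif rem % size == 0:
      offsets[dim] = idx % size  # Fully split; later dims absorb the rest.
      shape[dim] = 1
      idx //= size
      rem //= size
    else:
      break  # Failed to partition the leading dimensions.
  if rem != 1:
    raise ValueError("Leading dimensions are not divisible by the collective size")
  return tuple(shape), tuple(offsets)

# Two blocks cooperating on a (4, 64, 128) slice each copy a (2, 64, 128) half.
assert partition_slice_demo((4, 64, 128), 1, 2) == ((2, 64, 128), (2, 0, 0))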
+ShapeTree = Any +RefTree = Any +T = TypeVar('T') + + +@dataclasses.dataclass(frozen=True) +class Union(Generic[T]): + members: Sequence[T] + + def __iter__(self): + return iter(self.members) + +@dataclasses.dataclass(frozen=True) +class TMABarrier: + num_barriers: int = 1 + +@dataclasses.dataclass(frozen=True) +class Barrier: + arrival_count: int + num_barriers: int = 1 + +@dataclasses.dataclass(frozen=True) +class ClusterBarrier: + collective_dims: Sequence[gpu.Dimension] + num_barriers: int = 1 + + +def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int: + return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize + + +def _construct_smem_reftree( + cluster_shape: tuple[int, int, int], + dynamic_smem: ir.Value, + smem_buffers: ShapeTree, + dynamic_smem_offset: int = 0, +) -> RefTree: + index = ir.IndexType.get() + i8 = ir.IntegerType.get_signless(8) + ptr = ir.Type.parse("!llvm.ptr") + smem = ir.Attribute.parse("#gpu.address_space") + flat_ref_tys, smem_buffer_tree = jax.tree.flatten( + smem_buffers, is_leaf=lambda x: isinstance(x, Union) + ) + smem_refs = [] + for ref_ty in flat_ref_tys: + def get_barrier_ptr(num_barriers: int) -> ir.Value: + nonlocal dynamic_smem_offset + smem_base_ptr = utils.memref_ptr(dynamic_smem, memory_space=3) + barrier_base_ptr = llvm.getelementptr( + ptr, smem_base_ptr, [], [dynamic_smem_offset], i8 + ) + dynamic_smem_offset += num_barriers * MBARRIER_BYTES + return barrier_base_ptr + match ref_ty: + case Union(members): + member_trees = [ + _construct_smem_reftree(cluster_shape, dynamic_smem, m, dynamic_smem_offset) + for m in members + ] + # TODO(apaszke): This is quadratic, but it shouldn't matter for now... + dynamic_smem_offset += _smem_tree_size(ref_ty) + ref = Union(member_trees) + case TMABarrier(num_barriers): + ref = utils.BarrierRef.initialize( + get_barrier_ptr(num_barriers), num_barriers, arrival_count=1 + ) + case Barrier(arrival_count, num_barriers): + ref = utils.BarrierRef.initialize( + get_barrier_ptr(num_barriers), + num_barriers, + arrival_count=arrival_count, + ) + case ClusterBarrier(collective_dims, num_barriers): + ref = utils.CollectiveBarrierRef.initialize( + get_barrier_ptr(num_barriers), + num_barriers, + collective_dims, + cluster_shape, + ) + case _: + mlir_dtype = utils.dtype_to_ir_type(ref_ty.dtype) + tile_smem = memref.view( + ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem), + dynamic_smem, c(dynamic_smem_offset, index), [], + ) + dynamic_smem_offset += _count_buffer_bytes(ref_ty) + ref = tile_smem + smem_refs.append(ref) + return jax.tree.unflatten(smem_buffer_tree, smem_refs) + + +MBARRIER_BYTES = 8 + + +def _smem_tree_size(smem_buffers: ShapeTree) -> int: + leaves = jax.tree.leaves( + smem_buffers, is_leaf=lambda x: isinstance(x, Union) + ) + size = 0 + for l in leaves: + match l: + case Union(members): + size += max(_smem_tree_size(s) for s in members) + case ( + TMABarrier(num_barriers) + | ClusterBarrier(_, num_barriers=num_barriers) + | Barrier(_, num_barriers=num_barriers) + ): + if size % MBARRIER_BYTES: + raise NotImplementedError("Misaligned barrier allocation") + size += num_barriers * MBARRIER_BYTES + case _: + size += _count_buffer_bytes(l) + return size + + +# TODO(apaszke): Inline this +@contextlib.contextmanager +def _launch( + token, + grid: tuple[int, int, int], + cluster: tuple[int, int, int], + block: tuple[int, int, int], + scratch_arr, + smem_buffers: ShapeTree | Union[ShapeTree], + profiler_spec: profiler.ProfilerSpec | None = None, + maybe_prof_buffer: ir.Value | 
None = None, +): + if (profiler_spec is None) != (maybe_prof_buffer is None): + raise ValueError + index = ir.IndexType.get() + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + grid_vals = [c(i, index) for i in grid] + block_vals = [c(i, index) for i in block] + + user_smem_bytes = _smem_tree_size(smem_buffers) + + smem_bytes = user_smem_bytes + if profiler_spec is not None: + smem_bytes += profiler_spec.smem_bytes(block=block) + + # TODO(cperivol): Query the shared memory size programmatically. + if smem_bytes > 228 * 1024: + raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000") + if math.prod(cluster) != 1: + if len(cluster) != 3: + raise ValueError("Clusters must be 3D") + cluster_kwargs = { + "clusterSize" + d: c(s, index) for s, d in zip(cluster, "XYZ") + } + for d, grid_size, cluster_size in zip("xyz", grid, cluster): + if grid_size % cluster_size != 0: + raise ValueError( + f"Grid dimension {d} must be divisible by cluster dimension:" + f" {grid_size} % {cluster_size} != 0" + ) + else: + cluster_kwargs = {} + launch_op = gpu.LaunchOp( + token.type, [token], *grid_vals, *block_vals, + dynamicSharedMemorySize=c(smem_bytes, i32), **cluster_kwargs) + launch_op.body.blocks.append(*([index] * (12 + 2 * len(cluster_kwargs)))) # Append an empty block + smem = ir.Attribute.parse("#gpu.address_space") + with ir.InsertionPoint(launch_op.body.blocks[0]): + dynamic_smem = gpu.dynamic_shared_memory( + ir.MemRefType.get( + (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem + ) + ) + + smem_ref_tree = _construct_smem_reftree( + cluster, dynamic_smem, smem_buffers + ) + # TODO(apaszke): Skip the following if no barriers were initialized. + nvvm.fence_mbarrier_init() + if math.prod(cluster) != 1: + nvvm.cluster_arrive_relaxed(aligned=ir.UnitAttr.get()) + nvvm.cluster_wait(aligned=ir.UnitAttr.get()) + gpu.barrier() + + if profiler_spec: + prof_smem = memref.view( + ir.MemRefType.get( + (profiler_spec.smem_i32_elements(block=block),), + i32, memory_space=smem, + ), + dynamic_smem, c(user_smem_bytes, index), [], + ) + prof = profiler.OnDeviceProfiler( + profiler_spec, prof_smem, maybe_prof_buffer + ) + else: + prof = None + + ptr_ty = ir.Type.parse("!llvm.ptr") + scratch_ptr = builtin.unrealized_conversion_cast([ptr_ty], [scratch_arr]) + yield LaunchContext(launch_op, scratch_ptr, cluster, prof), smem_ref_tree + if prof is not None: + prof.finalize(grid=grid, block=block) + gpu.terminator() + + +def _lower_as_gpu_kernel( + body, + grid: tuple[int, int, int], + cluster: tuple[int, int, int], + block: tuple[int, int, int], + in_shapes: tuple[Any, ...], + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + module_name: str, + prof_spec: profiler.ProfilerSpec | None = None, +): + ptr_ty = ir.Type.parse("!llvm.ptr") + token_ty = ir.Type.parse("!gpu.async.token") + i32 = ir.IntegerType.get_signless(32) + i64 = ir.IntegerType.get_signless(64) + + def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType: + return ir.MemRefType.get(shape.shape, utils.dtype_to_ir_type(shape.dtype)) + + in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes] + + unwrap_output_tuple = False + if isinstance(out_shape, list): + out_shape = tuple(out_shape) + elif not isinstance(out_shape, tuple): + out_shape = (out_shape,) + unwrap_output_tuple = True + out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape] + if prof_spec is not None: + out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block)) + 
out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block)) + + module = ir.Module.create() + attrs = module.operation.attributes + attrs["sym_name"] = ir.StringAttr.get(module_name) + with ir.InsertionPoint(module.body): + _declare_runtime_functions() + gmem_scratch_bytes = 0 + global_scratch = llvm.GlobalOp( + ir.Type.parse("!llvm.array<0 x i8>"), # We don't know the shape yet. + "global_scratch", + ir.Attribute.parse("#llvm.linkage"), + addr_space=ir.IntegerAttr.get(i32, 4), # GPU constant memory. + ) + @func.FuncOp.from_py_func(ptr_ty, ptr_ty) + def main(token_ptr, buffers): + nonlocal gmem_scratch_bytes + token = builtin.unrealized_conversion_cast([token_ty], [token_ptr]) + arg_refs = [] + for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]): + ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty)) + arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty))) + in_refs = arg_refs[:len(in_ref_tys)] + out_refs = arg_refs[len(in_ref_tys):] + prof_buffer = out_refs.pop() if prof_spec is not None else None + empty_arr_ty = ir.Type.parse("!llvm.array<0 x i8>") + scratch_alloc = llvm.AllocaOp( + ptr_ty, c(1, i64), empty_arr_ty, alignment=TMA_DESCRIPTOR_ALIGNMENT + ) + scratch_arr = llvm.load(empty_arr_ty, scratch_alloc.result) + with _launch( + token, grid, cluster, block, scratch_arr, smem_scratch_shape, + prof_spec, prof_buffer + ) as (launch_ctx, smem_refs): + body(launch_ctx, *in_refs, *out_refs, smem_refs) + gmem_scratch_bytes = launch_ctx.next_scratch_offset + # Allocate and initialize the host buffer right before the launch. + # Note that we couldn't do that before, because we had to run the body + # to learn what the scratch contains. + with ir.InsertionPoint(scratch_arr.owner): + scratch_arr_ty = ir.Type.parse(f"!llvm.array<{gmem_scratch_bytes} x i8>") + scratch_alloc.elem_type = ir.TypeAttr.get(scratch_arr_ty) + scratch_arr.set_type(scratch_arr_ty) + for init_callback in launch_ctx.host_scratch_init: + init_callback(scratch_alloc.result) + main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get() + sym_tab = ir.SymbolTable(module.operation) + sym_tab.insert(main.func_op) + sym_tab.insert(global_scratch) + module.operation.verify() + + return module, out_shape, unwrap_output_tuple + + +def _declare_runtime_functions(): + """Declares the runtime functions that can be used by the generated code.""" + ptr_ty = ir.Type.parse("!llvm.ptr") + i64 = ir.IntegerType.get_signless(64) + arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty] + init_tma_desc_type = ir.FunctionType.get(arg_tys, []) + func.FuncOp( + "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private" + ) + memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], []) + func.FuncOp( + "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private" + ) + + +def as_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) + ) + + expected_arg_treedef = jax.tree.structure(in_shape) + def _check_args(*args): + arg_treedef = 
jax.tree.structure(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + def bind(*args): + return mosaic_gpu_p.bind( + *args, + out_types=out_shape, + module=module_asm, + ) + + if prof_spec is not None: + @jax.jit + def prof_kernel(*args): + _check_args(*args) + *results, prof_buffer = bind(*args) + def dump_profile(prof_buffer): + out_file = os.path.join( + os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"), + f"{time.time_ns()}-trace.json", + ) + try: + with open(out_file, "x") as f: + prof_spec.dump(prof_buffer, f, grid=grid, block=block) + except FileExistsError: + pass # TODO: Retry + jax.debug.callback(dump_profile, prof_buffer) + return results[0] if unwrap_output_tuple else results + return prof_kernel + else: + @jax.jit + def kernel(*args): + _check_args(*args) + results = bind(*args) + return results[0] if unwrap_output_tuple else results + return kernel + + +def as_torch_gpu_kernel( + body, + grid: tuple[int, int, int], + block: tuple[int, int, int], + in_shape, + out_shape, + smem_scratch_shape: ShapeTree | Union[ShapeTree], + prof_spec: profiler.ProfilerSpec | None = None, + cluster: tuple[int, int, int] = (1, 1, 1), + module_name: str = "unknown", +): + try: + import torch + except ImportError: + raise RuntimeError("as_torch_gpu_kernel requires PyTorch") + torch.cuda.init() # Make sure CUDA context is set up. + + if isinstance(in_shape, list): + in_shape = tuple(in_shape) + elif not isinstance(in_shape, tuple): + in_shape = (in_shape,) + + flat_out_types, out_treedef = jax.tree.flatten(out_shape) + expected_arg_treedef = jax.tree.structure(in_shape) + + module, out_shape, unwrap_output_tuple = ( + _lower_as_gpu_kernel( + body, grid, cluster, block, in_shape, out_shape, smem_scratch_shape, + module_name, prof_spec + ) + ) + + # Get our hands on the compilation and unload functions + try: + import jax_plugins.xla_cuda12 as cuda_plugin + except ImportError: + raise RuntimeError("as_torch_gpu_kernel only works with recent jaxlib builds " + "that use backend plugins") + dll = ctypes.CDLL(cuda_plugin._get_library_path()) + compile_func = dll.MosaicGpuCompile + compile_func.argtypes = [ctypes.c_void_p] + compile_func.restype = ctypes.POINTER(ctypes.c_void_p) + unload_func = dll.MosaicGpuUnload + unload_func.argtypes = [compile_func.restype] + unload_func.restype = None + + module_asm = module.operation.get_asm(binary=True, enable_debug_info=True) + compiled = compile_func(ctypes.c_char_p(module_asm)) + if compiled is None: + raise RuntimeError("Failed to compile the module") + ctx, launch_ptr = compiled[0], compiled[1] + ctx_ptr_ptr = ctypes.pointer(ctypes.c_void_p(ctx)) + launch = ctypes.CFUNCTYPE(None, ctypes.c_void_p)(launch_ptr) + + def as_torch_dtype(dtype): + # torch contains NumPy-compatible dtypes in its top namespace + return getattr(torch, np.dtype(dtype).name) + + def apply(*args): + flat_args, arg_treedef = jax.tree.flatten(args) + if arg_treedef != expected_arg_treedef: + raise ValueError( + f"Invalid argument structure: expected {expected_arg_treedef}, got" + f" {arg_treedef}, ({args=})" + ) + + # Construct a device pointer list like in the XLA calling convention + buffers = (ctypes.c_void_p * (arg_treedef.num_leaves + out_treedef.num_leaves))() + i = -1 # Define i in case there are no args + device = 'cuda' + for i, arg in enumerate(flat_args): + buffers[i] = arg.data_ptr() + 
device = arg.device + flat_outs = [] + for i, t in enumerate(flat_out_types, i + 1): + out = torch.empty(t.shape, dtype=as_torch_dtype(t.dtype), device=device) + flat_outs.append(out) + buffers[i] = out.data_ptr() + # Allocate another buffer for args of the host-side program. This is sadly + # the default MLIR calling convention. + args_ptr = (ctypes.POINTER(ctypes.c_void_p) * 3)() + args_ptr[0] = ctx_ptr_ptr + args_ptr[1] = ctypes.pointer(torch.cuda.default_stream(device)._as_parameter_) + args_ptr[2] = ctypes.cast(ctypes.pointer(ctypes.pointer(buffers)), + ctypes.POINTER(ctypes.c_void_p)) + launch(args_ptr) + return jax.tree.unflatten(out_treedef, flat_outs) + + # Unload the compiled code when the Python function is destroyed. + def unload(_): + unload_func(compiled) + apply.destructor = weakref.ref(apply, unload) + + return apply diff --git a/jax/experimental/mosaic/gpu/dsl.py b/jax/experimental/mosaic/gpu/dsl.py deleted file mode 100644 index 82e0aa4abb12..000000000000 --- a/jax/experimental/mosaic/gpu/dsl.py +++ /dev/null @@ -1,56 +0,0 @@ -# Copyright 2024 The JAX Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from . import ( - Barrier, - ClusterBarrier, - TMABarrier, - Union, -) -from .fragmented_array import ( - FragmentedArray, - FragmentedLayout, - WGMMA_LAYOUT, - WGMMA_ROW_LAYOUT, - WGStridedFragLayout, -) -from .utils import ( - BarrierRef, - CollectiveBarrierRef, - DynamicSlice, - Partition, - Partition1D, - bytewidth, - c, - commit_shared, - debug_print, - ds, - fori, - memref_fold, - memref_slice, - memref_transpose, - memref_unfold, - memref_unsqueeze, - single_thread, - thread_idx, - tile_shape, - warp_idx, - warpgroup_idx, -) -from .wgmma import ( - WGMMAAccumulator, - WGMMALayout, - wgmma, -) diff --git a/jax/experimental/mosaic/gpu/examples/BUILD b/jax/experimental/mosaic/gpu/examples/BUILD index 3f9496b38376..fe1a7e9180ac 100644 --- a/jax/experimental/mosaic/gpu/examples/BUILD +++ b/jax/experimental/mosaic/gpu/examples/BUILD @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-load("//jaxlib:jax.bzl", "py_deps") -load("@rules_python//python:defs.bzl", "py_library", "py_test") +load("@rules_python//python:defs.bzl", "py_library") +load("//jaxlib:jax.bzl", "jax_multiplatform_test", "py_deps") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//third_party/py/jax:mosaic_gpu_users"], + default_visibility = ["//jax:mosaic_gpu_users"], ) exports_files( @@ -27,15 +27,15 @@ exports_files( "flash_attention.py", "matmul.py", ], - visibility = ["//third_party/py/jax:internal"], + visibility = ["//jax:internal"], ) py_library( name = "matmul", srcs = ["matmul.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax", + "//jax:mosaic_gpu", ], ) @@ -43,23 +43,22 @@ py_library( name = "flash_attention", srcs = ["flash_attention.py"], deps = [ - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax", + "//jax:mosaic_gpu", ], ) -py_test( +jax_multiplatform_test( name = "run_matmul", srcs = ["matmul.py"], + enable_backends = [], + enable_configs = ["gpu_h100"], main = "matmul.py", tags = [ "manual", "notap", - "requires-gpu-sm90-only", ], deps = [ - "//learning/brain/research/jax:gpu_support", - "//third_party/py/jax", - "//third_party/py/jax:mosaic_gpu", + "//jax:mosaic_gpu", ] + py_deps("numpy"), ) diff --git a/jax/experimental/mosaic/gpu/examples/flash_attention.py b/jax/experimental/mosaic/gpu/examples/flash_attention.py index 0675844227ba..99586875ae90 100644 --- a/jax/experimental/mosaic/gpu/examples/flash_attention.py +++ b/jax/experimental/mosaic/gpu/examples/flash_attention.py @@ -24,9 +24,8 @@ from jax import random from jax._src.interpreters import mlir from jax._src import test_util as jtu -from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import profiler -from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith @@ -144,7 +143,7 @@ def c(value, ty=index): return _utils_c(value, ty) def tma_wg_kernel( - ctx: mosaic_gpu.LaunchContext, + ctx: LaunchContext, q_gmem, k_gmem, v_gmem, @@ -190,7 +189,7 @@ def only_wg(idx): ctx.async_copy( src_ref=q_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), dst_ref=qo_smem, barrier=q_barriers[wg_idx], swizzle=128, @@ -287,17 +286,14 @@ def kv_loop(kv_step, carry): with ctx.named_region("Acc store"): acc.astype(f16).store_tiled(qo_smem, swizzle=128) - gpu.barrier() - nvvm.fence_proxy( - nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta - ) # Make sure the store is visible to the TMA. + commit_shared() # Make sure the store is visible to the TMA. 
with ctx.named_region("GMEM store"): ctx.async_copy( src_ref=qo_smem, dst_ref=out_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), swizzle=128, ) ctx.await_async_copy(0) @@ -307,10 +303,9 @@ def kv_loop(kv_step, carry): nvvm.setmaxregister(40, nvvm.SetMaxRegisterAction.decrease) with single_thread(per_block=False): k_tr = ( - mosaic_gpu.TileTransform(tiling), - mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + TileTransform(tiling), TransposeTransform((0, 2, 1, 3, 4)), ) - v_tr = mosaic_gpu.TileTransform(tiling) + v_tr = TileTransform(tiling) kv_head_idx = arith.divui(q_head_idx, c(q_heads_per_kv_head)) def start_kv_copy(slot, kv_seq_base, smem, gmem, barrier, transform): ctx.async_copy( @@ -353,7 +348,7 @@ def _kv_loop_memory(i, _): scf.yield_([]) def compute_only_kernel( - ctx: mosaic_gpu.LaunchContext, + ctx: LaunchContext, q_gmem, k_gmem, v_gmem, @@ -391,7 +386,7 @@ def only_wg(idx): ctx.async_copy( src_ref=q_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), dst_ref=qo_smem, barrier=barriers[q_barrier], swizzle=128, @@ -404,10 +399,10 @@ def kv_copy_init(slot, kv_seq_base): txcount = 2 * blocks.kv * head_dim * bytewidth(f16) barriers[slot].arrive_expect_tx(txcount) k_tr = ( - mosaic_gpu.TileTransform(tiling), - mosaic_gpu.TransposeTransform((0, 2, 1, 3, 4)), + TileTransform(tiling), + TransposeTransform((0, 2, 1, 3, 4)), ) - v_tr = mosaic_gpu.TileTransform(tiling) + v_tr = TileTransform(tiling) for smem, gmem, t in ((k_smem, k_gmem, k_tr), (v_smem, v_gmem, v_tr)): ctx.async_copy( dst_ref=memref_slice(smem, slot), @@ -529,7 +524,7 @@ def kv_loop(kv_step, carry): src_ref=qo_smem, dst_ref=out_gmem, gmem_slice=(q_head_idx, ds(q_seq_base, blocks.q)), - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=TileTransform(tiling), swizzle=128, ) ctx.await_async_copy(0) @@ -554,7 +549,7 @@ def kv_loop(kv_step, carry): Barrier(arrival_count=256, num_barriers=2), Barrier(arrival_count=256, num_barriers=1), ) - return mosaic_gpu.as_gpu_kernel( + return as_gpu_kernel( kernel, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec ) diff --git a/jax/experimental/mosaic/gpu/examples/matmul.py b/jax/experimental/mosaic/gpu/examples/matmul.py index a0d008bf1b43..c56c5cd6b982 100644 --- a/jax/experimental/mosaic/gpu/examples/matmul.py +++ b/jax/experimental/mosaic/gpu/examples/matmul.py @@ -15,24 +15,21 @@ """Matmul kernels for H100.""" import dataclasses -import functools -from typing import Any +import itertools import math +from typing import Any import jax from jax import random from jax._src.interpreters import mlir -from jax.experimental.mosaic import gpu as mosaic_gpu from jax.experimental.mosaic.gpu import profiler -from jax.experimental.mosaic.gpu.dsl import * # noqa: F403 +from jax.experimental.mosaic.gpu import * # noqa: F403 import jax.numpy as jnp from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import gpu -from jaxlib.mlir.dialects import memref from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import scf -from jaxlib.mlir.dialects import vector import numpy as np # mypy: ignore-errors @@ -87,13 +84,14 @@ def wgmma( b_order: WGMMALayout, a_slice: SmemRef, b_slice: SmemRef, + swizzle: int, ) -> dict[str, WGMMAAccumulator]: """Perform a matrix multiplication. 
This function must guarantee that all WGMMA operations queued before it was called have completed before returning. """ - acc = wgmma(acc, a_slice, b_slice, b_order=b_order) + acc = wgmma(acc, a_slice, b_slice, b_order=b_order, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(1) return acc @@ -109,45 +107,64 @@ def wrap(*args, **kw): @mlir_context def build_kernel( m, n, k, - lhs_dtype, rhs_dtype, + lhs_dtype, rhs_dtype, out_dtype, stages: int = 2, tile_m: int = 128, tile_n: int = 128, - cluster: tuple[int, int] = (1, 1), + swizzle: int = 128, + cluster_m: int = 1, + cluster_n: int = 1, + grid_tile_n: int = 1, rhs_transpose: bool = False, wgmma_impl=WGMMADefaultImpl, profiler_spec: profiler.ProfilerSpec | None = None, ): f32 = ir.F32Type.get() - out_128b_elems = 128 // bytewidth(f32) - out_tiling = (64, out_128b_elems) - out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), jnp.float32) if tile_m % 64 != 0: raise ValueError(f"{tile_m=} must be divisible by 64") if m % tile_m != 0: raise ValueError(f"{m=} must be divisible by {tile_m=}") - if n % 64 != 0: - raise ValueError(f"n must be divisible by 64, but got {n=}") + if n % tile_n != 0: + raise ValueError(f"{n=} must be divisible by {tile_n=}") if stages < 2: raise ValueError(f"Need at least 2 stages, but got {stages=}") + if not rhs_transpose and jnp.dtype(rhs_dtype).itemsize != 2: + raise ValueError(f"Transpose only supported for 16bit types (got: {rhs_transpose=}, {rhs_dtype=})") + if swizzle not in {32, 64, 128}: + raise ValueError(f"swizzle must be 32, 64, or 128, but got {swizzle=}") + + out_mlir_dtype = mlir.dtype_to_ir_type(out_dtype) + out_swizzle = swizzle + if bytewidth(out_mlir_dtype) == 4: + if tile_n % 32 == 0: + out_swizzle = 128 + elif tile_n % 16 == 0: + out_swizzle = 64 + else: + raise NotImplementedError( + f"{tile_n=} must by divisible by 16 for 32-bit output" + ) + out_swizzle_elems = out_swizzle // bytewidth(out_mlir_dtype) + out_tiling = (64, out_swizzle_elems) + out_tile = jax.ShapeDtypeStruct(tile_shape((tile_m, tile_n), out_tiling), out_dtype) lhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(lhs_dtype)) rhs_elem_bytes = bytewidth(mlir.dtype_to_ir_type(rhs_dtype)) - lhs_128b_elems = 128 // lhs_elem_bytes - rhs_128b_elems = 128 // rhs_elem_bytes - tile_k = max(lhs_128b_elems, rhs_128b_elems) + lhs_swizzle_elems = swizzle // lhs_elem_bytes + rhs_swizzle_elems = swizzle // rhs_elem_bytes + tile_k = max(lhs_swizzle_elems, rhs_swizzle_elems) - if tile_n % rhs_128b_elems != 0: + if tile_n % rhs_swizzle_elems != 0: raise ValueError( - f"{tile_n=} must be divisible by 128 bytes =" - f" {((lhs_128b_elems, lhs_dtype), (rhs_128b_elems, rhs_dtype))}" + f"{tile_n=} must be divisible by {swizzle} bytes =" + f" {((lhs_swizzle_elems, lhs_dtype), (rhs_swizzle_elems, rhs_dtype))}" ) if k % tile_k != 0: raise ValueError(f"k must be divisible by {tile_k=}, but got {k=}") block_tiling = Tiling(m=tile_m, n=tile_n, k=tile_k) - tma_tiling = Tiling(m=64, n=rhs_128b_elems, k=lhs_128b_elems) + tma_tiling = Tiling(m=64, n=rhs_swizzle_elems, k=lhs_swizzle_elems) k_steps = k // block_tiling.k stages = min(stages, k_steps) @@ -155,7 +172,11 @@ def safe_div(x, y): assert x % y == 0, (x, y) return x // y - grid = (safe_div(m, block_tiling.m), safe_div(n, block_tiling.n), 1) + grid = ( + grid_tile_n, + safe_div(m, block_tiling.m), + safe_div(n, block_tiling.n * grid_tile_n), + ) block = (128, 1, 1) c = arith.ConstantOp.create_index @@ -166,16 +187,18 @@ def safe_div(x, y): 
wgmma_impl.smem_shape_extra(block_tiling, tma_tiling, lhs_dtype, rhs_dtype, rhs_transpose), ) epilogue_scratch_shape = jax.ShapeDtypeStruct(out_tile.shape, out_tile.dtype) - smem_shape = mosaic_gpu.Union([compute_scratch_shape, epilogue_scratch_shape]) + smem_shape = Union([compute_scratch_shape, epilogue_scratch_shape]) def _main(ctx, a_device, b_device, c_device, smem): ((lhs_smem, rhs_smem, impl_smem), epilogue_smem), *barriers = smem tma_barriers, cluster_barrier = barriers - memref.assume_alignment(c_device, 16) - - m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.x)) - n_start = arith.muli(c(block_tiling.n), gpu.block_id(gpu.Dimension.y)) + m_start = arith.muli(c(block_tiling.m), gpu.block_id(gpu.Dimension.y)) + n_block_idx = arith.addi( + gpu.block_id(gpu.Dimension.x), + arith.muli(gpu.block_id(gpu.Dimension.z), c(grid_tile_n)), + ) + n_start = arith.muli(c(block_tiling.n), n_block_idx) def fetch(slot, ki): barrier = tma_barriers[slot] @@ -184,7 +207,7 @@ def fetch(slot, ki): rhs_tma_tile_bytes = int(np.prod(block_tiling.kn) * rhs_elem_bytes) txcount = lhs_tma_tile_bytes + rhs_tma_tile_bytes common_copy_args = dict( - swizzle=128, barrier=barrier, arrive=False, uniform=False, + swizzle=swizzle, barrier=barrier, arrive=False, uniform=False, ) with single_thread(): barrier.arrive_expect_tx(txcount) @@ -192,22 +215,22 @@ def fetch(slot, ki): src_ref=a_device, dst_ref=memref_slice(lhs_smem, slot), gmem_slice=(ds(m_start, block_tiling.m), ds(k_start, block_tiling.k)), - gmem_transform=mosaic_gpu.TileTransform(tma_tiling.mk), - collective=gpu.Dimension.y, + gmem_transform=TileTransform(tma_tiling.mk), + collective=(gpu.Dimension.x, gpu.Dimension.z), **common_copy_args, ) rhs_slice = (ds(k_start, block_tiling.k), ds(n_start, block_tiling.n)) - rhs_transform = (mosaic_gpu.TileTransform(tma_tiling.kn),) + rhs_transform = (TileTransform(tma_tiling.kn),) if rhs_transpose: rhs_slice = rhs_slice[::-1] - rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform += (TransposeTransform((1, 0, 2, 3)),) assert tma_tiling.n == tma_tiling.k, block_tiling # No need to flip the tiling. ctx.async_copy( src_ref=b_device, dst_ref=memref_slice(rhs_smem, slot), gmem_slice=rhs_slice, gmem_transform=rhs_transform, - collective=gpu.Dimension.x, + collective=gpu.Dimension.y, **common_copy_args, ) @@ -230,7 +253,9 @@ def stage_loop_body(ki, accs): rhs_smem_order = ( WGMMALayout.COL_MAJOR if rhs_transpose else WGMMALayout.ROW_MAJOR ) - accs = wgmma_impl.wgmma(impl_smem, accs, rhs_smem_order, a_slice, b_slice) + accs = wgmma_impl.wgmma( + impl_smem, accs, rhs_smem_order, a_slice, b_slice, swizzle=swizzle + ) with ctx.named_region("TMA start"): tma_ki = arith.addi(ki, c(stages - 1)) @@ -256,7 +281,7 @@ def stage_loop_body(ki, accs): with ctx.named_region("SMEM store"): acc_val = wgmma_impl.get_result(stage_loop_body.result) - acc_val.store_tiled(epilogue_smem, swizzle=128) + acc_val.astype(out_mlir_dtype).store_tiled(epilogue_smem, swizzle=out_swizzle) commit_shared() # Make sure the stores are visible to TMA. 
with ctx.named_region("GMEM store"): @@ -264,12 +289,19 @@ def stage_loop_body(ki, accs): src_ref=epilogue_smem, dst_ref=c_device, gmem_slice=(ds(m_start, tile_m), ds(n_start, tile_n)), - gmem_transform=mosaic_gpu.TileTransform(out_tiling), - swizzle=128, + gmem_transform=TileTransform(out_tiling), + swizzle=out_swizzle, ) ctx.await_async_copy(0) - return mosaic_gpu.as_gpu_kernel( + cluster_tile_n = min(cluster_n, grid_tile_n) + if cluster_n % cluster_tile_n: + raise ValueError( + f"{cluster_n=} must be divisible by {cluster_tile_n} (due to" + f" {grid_tile_n=})" + ) + cluster = (cluster_tile_n, cluster_m, cluster_n // cluster_tile_n) + return as_gpu_kernel( _main, grid, block, @@ -277,17 +309,17 @@ def stage_loop_body(ki, accs): jax.ShapeDtypeStruct((m, k), lhs_dtype), jax.ShapeDtypeStruct((n, k) if rhs_transpose else (k, n), rhs_dtype), ), - jax.ShapeDtypeStruct((m, n), jnp.float32), + jax.ShapeDtypeStruct((m, n), out_dtype), ( smem_shape, TMABarrier(num_barriers=stages), ClusterBarrier( - collective_dims=(gpu.Dimension.x, gpu.Dimension.y), + collective_dims=((gpu.Dimension.x, gpu.Dimension.z), gpu.Dimension.y), num_barriers=stages, - ) if math.prod(cluster) > 1 else None, + ) if cluster_m * cluster_n > 1 else None, ), profiler_spec, - cluster=(*cluster, 1), + cluster=cluster, ) @@ -300,33 +332,32 @@ def verify( tile_n=128, cluster_m=1, cluster_n=1, + grid_tile_n=1, + swizzle=128, profile=False, - lhs_dtype=jnp.float16, - rhs_dtype=jnp.float16, + in_dtype=jnp.float16, + out_dtype=jnp.float32, rhs_transpose=False, ): - if not rhs_transpose and jnp.dtype(lhs_dtype).itemsize != 2: - raise ValueError( - "Implicit transpose can only happen for 16bit types (or mixed precision" - " that is underpinned by 16bit operations)." - ) + lhs_dtype, rhs_dtype = in_dtype, in_dtype kx, ky = random.split(random.key(1234)) x = random.uniform(kx, (m, k), dtype=lhs_dtype) y = random.uniform(ky, (n, k) if rhs_transpose else (k, n), dtype=rhs_dtype) - impl = WGMMADefaultImpl - prof_spec = profiler.ProfilerSpec(4096) if profile else None f = build_kernel( m, n, k, - jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype), + jnp.dtype(lhs_dtype), jnp.dtype(rhs_dtype), jnp.dtype(out_dtype), stages=stages, tile_m=tile_m, tile_n=tile_n, - cluster=(cluster_m, cluster_n), + cluster_m=cluster_m, + cluster_n=cluster_n, rhs_transpose=rhs_transpose, - wgmma_impl=impl, + swizzle=swizzle, + grid_tile_n=grid_tile_n, + wgmma_impl=WGMMADefaultImpl, profiler_spec=prof_spec, ) z, runtime = profiler.measure(f, x, y) @@ -342,21 +373,78 @@ def verify( for v in (x, y) ) - ref_f = functools.partial( - jax.lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=jnp.float32, - ) + @jax.jit + def ref_f(x, y): + return jax.lax.dot_general( + x, + y, + dimension_numbers=dimension_numbers, + preferred_element_type=jnp.float32, + ).astype(out_dtype) ref, ref_runtime = profiler.measure(ref_f, x, y) - np.testing.assert_allclose(z, ref, atol=1e-3, rtol=1e-3) + np.testing.assert_allclose( + z.astype(jnp.float32), ref.astype(jnp.float32), atol=1e-3, rtol=1e-3 + ) return runtime, ref_runtime if __name__ == "__main__": - m, k, n = 4 * 33 * 128, 2048, 4 * 128 - runtime, ref_runtime = verify(m=m, k=k, n=n, cluster_m=1, cluster_n=4) + dtype = jnp.dtype(jnp.float16) + m, k, n = 16384, 2048, 16384 + + kx, ky = random.split(random.key(1234)) + x = random.uniform(kx, (m, k), dtype=dtype) + y = random.uniform(ky, (k, n), dtype=dtype) + + tile_m = tile_n = (64, 128) + cluster_m = cluster_n = (1, 2) + swizzle = (128,) # 64 can be a good choice 
for some shapes too! + stages = (2, 4, 5, 6) + grid_tile_n = (1, 4, 16) + configs = itertools.product(tile_m, tile_n, cluster_m, cluster_n, stages, swizzle, grid_tile_n) + names = ("tile_m", "tile_n", "cluster_m", "cluster_n", "stages", "swizzle", "grid_tile_n") + best_runtime = float("inf") + best_kwargs = {} + for config in configs: + kwargs = dict(zip(names, config)) + if kwargs["cluster_m"] * kwargs["cluster_n"] > 8: + continue + if m < kwargs["tile_m"] or n < kwargs["tile_n"]: + continue + if (m // kwargs["tile_m"]) % kwargs["cluster_m"]: + continue + if (n // kwargs["tile_n"]) % kwargs["cluster_n"]: + continue + if n % kwargs["grid_tile_n"]: + continue + # This is a heuristic, not a strict correctness check. You can relax it + # for a more complete search space. + if kwargs["tile_m"] == kwargs["tile_n"] == 64: + continue + try: + f = build_kernel( + m, n, k, dtype, dtype, dtype, wgmma_impl=WGMMADefaultImpl, **kwargs + ) + _, runtime = profiler.measure(f, x, y) + except ValueError as e: + if "Mosaic GPU kernel exceeds available shared memory" not in str(e): + raise + runtime = float("inf") + # Enable this to get more detailed information. + # else: + # print(" ".join(f"{k}={v}" for k, v in kwargs.items()), int(runtime * 1000)) + if runtime < best_runtime: + best_runtime = runtime + best_kwargs = kwargs + if not best_kwargs: + raise ValueError("No valid configuration found") + + runtime, ref_runtime = verify( + m=m, k=k, n=n, in_dtype=dtype, out_dtype=dtype, **best_kwargs + ) tflops = float(2 * k * m * n) / (runtime / 1e3) / 1e12 ref_tflops = float(2 * k * m * n) / (ref_runtime / 1e3) / 1e12 + print("Best parameters: ", " ".join(f"{k}={v}" for k, v in best_kwargs.items())) print(f"Kernel: {runtime * 1000:.1f} us = {tflops:.1f} TFLOPS") print(f"Reference: {ref_runtime * 1000:.1f} us = {ref_tflops:.1f} TFLOPS") diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 8b13a00bced9..ae6c40b9416d 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -15,6 +15,8 @@ """Utilities for code generator.""" import dataclasses +import functools +import math from typing import Callable import jax @@ -28,7 +30,7 @@ from jaxlib.mlir.dialects import vector import numpy as np -from . import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors @@ -98,10 +100,10 @@ def from_memref_type(cls, memref_ty: ir.Type): memref_type = ir.MemRefType(memref_ty) bw = mgpu.bytewidth(memref_type.element_type) assert 8 % bw == 0 and 8 // bw != 0, bw - if np.prod(memref_type.shape) % WARPGROUP_SIZE != 0: + if math.prod(memref_type.shape) % WARPGROUP_SIZE != 0: raise ValueError( "Ref must have a number of elements that is a multiple of" - f" {WARPGROUP_SIZE}" + f" {WARPGROUP_SIZE} (got {math.prod(memref_type.shape)})" ) max_vec_size = np.prod(memref_type.shape) // WARPGROUP_SIZE return cls( @@ -109,6 +111,17 @@ def from_memref_type(cls, memref_ty: ir.Type): ) def thread_vec_idxs(self): + index = ir.IndexType.get() + for v in self.linear_thread_vec_idxs(): + res = [] + for dim in reversed(self.shape): + dim = c(dim, index) + res.append(arith.remui(v, dim)) + v = arith.divui(v, dim) + res.reverse() + yield res + + def linear_thread_vec_idxs(self): """The indexes to be used for vector load/store WGStridedFragLayout. 
Yields: @@ -121,7 +134,7 @@ def thread_vec_idxs(self): tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index)) off = arith.muli(tidx, c(self.vec_size, tidx.type)) for i in range(reg_num): - yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))] + yield arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type)) FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout @@ -132,48 +145,71 @@ def thread_vec_idxs(self): @jax.tree_util.register_pytree_node_class +@dataclasses.dataclass(init=False, eq=False, frozen=True, slots=True) class FragmentedArray: - registers: np.ndarray # of ir.Value, see checks in init for shapes. + # An array of ir.Value, see checks in init for shapes. + registers: np.ndarray = dataclasses.field(repr=False) layout: FragmentedLayout - - def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout): - self.registers = _registers - self.layout = _layout + is_signed: bool | None + + def __init__( + self, + *, + _registers: np.ndarray, + _layout: FragmentedLayout, + _is_signed: bool | None, + ): + """Initializes a fragmented array. + + This is a low-level API. Prefer using classmethods to construct fragmented + arrays instead. + """ + # We need to use ``object.__setattr__`` here because of ``frozen=True``. + object.__setattr__(self, "registers", _registers) + object.__setattr__(self, "layout", _layout) + object.__setattr__(self, "is_signed", _is_signed) + + if (_is_signed is not None) != ir.IntegerType.isinstance(self.mlir_dtype): + raise TypeError( + "is_signed must only be non-None if the MLIR type is an integer" + f" type, got {_is_signed=} for {self.mlir_dtype}" + ) match self.layout: # Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout # Each element is a vector<2xdtype> case WGMMAFragLayout(): - if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1): - raise ValueError("Invalid register array shape") + if _registers.ndim != 4 or _registers.shape[2:] != (2, 1): + raise ValueError(f"Invalid register array shape: {_registers.shape}") # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout # Each element is a dtype scalar case WGMMARowFragLayout(): - if self.registers.ndim != 2 or self.registers.shape[-1] != 2: - raise ValueError("Invalid register array shape") + if _registers.ndim != 2 or _registers.shape[-1] != 2: + raise ValueError(f"Invalid register array shape: {_registers.shape}") # Registers are flat case WGStridedFragLayout(shape): - (reg_size,) = ir.VectorType(_registers.flat[0].type).shape - if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size: - raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type) + [reg_size] = ir.VectorType(_registers.flat[0].type).shape + if ( + math.prod(shape) + != math.prod(_registers.shape) * WARPGROUP_SIZE * reg_size + ): + raise ValueError( + "Invalid register array shape: math.prod({_registers.shape}) *" + " {WARPGROUP_SIZE} * {reg_size}, want: math.prod({shape})" + ) # Just a single register case WGSplatFragLayout(): if _registers.size != 1: - raise ValueError(f"WGStridedFragLayout requires a single value {_registers.shape} ({_registers.size})") + raise ValueError(f"Invalid register array shape: {_registers.shape}") case _: raise NotImplementedError - def __repr__(self): - return ( - f"FragmentedArray(layout={self.layout}, shape={self.shape})" - ) - @classmethod - def load_strided(cls, ref: ir.Value): + def load_strided(cls, ref: ir.Value, *, 
is_signed: bool | None = None): if not ir.MemRefType.isinstance(ref.type): raise TypeError(ref.type) @@ -181,11 +217,11 @@ def load_strided(cls, ref: ir.Value): ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) layout = WGStridedFragLayout.from_memref_type(ref_ty) vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type) - vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()] - return cls(_registers=np.array(vecs), _layout=layout) + vecs = [vector.load(vec_ty, ref_1d, [vec_idx]) for vec_idx in layout.linear_thread_vec_idxs()] + return cls(_registers=np.array(vecs), _layout=layout, _is_signed=is_signed) @classmethod - def splat(cls, value, shape, layout=None): + def splat(cls, value, shape, layout=None, *, is_signed: bool | None = None): layout = layout or WGSplatFragLayout(shape) match layout: case WGMMARowFragLayout(): @@ -215,6 +251,7 @@ def splat(cls, value, shape, layout=None): return cls( _registers=np.full(reg_shape, value, dtype=object), _layout=layout, + _is_signed=is_signed, ) @property @@ -240,19 +277,32 @@ def mlir_dtype(self): case WGMMARowFragLayout() | WGSplatFragLayout(): return reg_ty - def _pointwise(self, op, *other): + def _pointwise(self, op, *other, output_is_signed: bool | None = None): + is_signed = ( + output_is_signed if output_is_signed is not None else self.is_signed + ) + other_arrs = [] for o in other: if not isinstance(o, FragmentedArray): - if not isinstance(o, ir.Value): + if isinstance(o, (float, int)): + o = utils.c(o, self.mlir_dtype) + elif not isinstance(o, ir.Value): raise NotImplementedError(o) - o = FragmentedArray.splat(o, shape=self.shape, layout=self.layout) + o = FragmentedArray.splat( + o, shape=self.shape, layout=self.layout, is_signed=is_signed + ) if isinstance(o.layout, WGSplatFragLayout): if not o.layout.can_broadcast_to(self.shape): raise ValueError("Can't broadcast shape.") - o = FragmentedArray.splat(o.registers.flat[0], shape=self.shape, layout=self.layout) + o = FragmentedArray.splat( + o.registers.flat[0], + shape=self.shape, + layout=self.layout, + is_signed=is_signed, + ) else: if self.layout != o.layout: raise ValueError("Incompatible FragmentedArray layouts") @@ -264,7 +314,20 @@ def _pointwise(self, op, *other): for idx, reg in np.ndenumerate(self.registers): new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs)) - return FragmentedArray(_registers=new_regs, _layout=self.layout) + return FragmentedArray( + _registers=new_regs, _layout=self.layout, _is_signed=is_signed + ) + + def __pos__(self): + return self + + def __neg__(self): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.negf) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return 0 - self + else: + return NotImplemented def __add__(self, other): if ir.FloatType.isinstance(self.mlir_dtype): @@ -272,7 +335,7 @@ def __add__(self, other): elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.addi, other) else: - raise NotImplementedError(self.mlir_dtype) + return NotImplemented def __radd__(self, other): return self + other @@ -283,37 +346,124 @@ def __mul__(self, other): elif ir.IntegerType.isinstance(self.mlir_dtype): return self._pointwise(arith.muli, other) else: - raise NotImplementedError(self.mlir_dtype) + return NotImplemented def __rmul__(self, other): return self * other def __sub__(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError - return self._pointwise(arith.subf, other) + if 
ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.subf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(arith.subi, other) + else: + return NotImplemented def __rsub__(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError - return self._pointwise(lambda s, o: arith.subf(o, s), other) + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(lambda s, o: arith.subf(o, s), other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise(lambda s, o: arith.subi(o, s), other) + else: + return NotImplemented def __truediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError + return NotImplemented return self._pointwise(arith.divf, other) def __rtruediv__(self, other): if not ir.FloatType.isinstance(self.mlir_dtype): - raise NotImplementedError + return NotImplemented return self._pointwise(lambda s, o: arith.divf(o, s), other) - def max(self, other): - if not ir.FloatType.isinstance(self.mlir_dtype): + def __mod__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + if self.is_signed: + return self._pointwise(arith.remsi, other) + else: + return self._pointwise(arith.remui, other) + + def __rmod__(self, other): + if not ir.IntegerType.isinstance(self.mlir_dtype): + return NotImplemented + if self.is_signed: + return self._pointwise(lambda s, o: arith.remsi(o, s), other) + else: + return self._pointwise(lambda s, o: arith.remui(o, s), other) + + def __eq__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OEQ, + si_pred=arith.CmpIPredicate.eq, + ui_pred=arith.CmpIPredicate.eq, + ) + + def __ne__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.UNE, + si_pred=arith.CmpIPredicate.ne, + ui_pred=arith.CmpIPredicate.ne, + ) + + def __lt__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OLT, + si_pred=arith.CmpIPredicate.slt, + ui_pred=arith.CmpIPredicate.ult, + ) + + def __le__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OLE, + si_pred=arith.CmpIPredicate.sle, + ui_pred=arith.CmpIPredicate.ule, + ) + + def __gt__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OGT, + si_pred=arith.CmpIPredicate.sgt, + ui_pred=arith.CmpIPredicate.ugt, + ) + + def __ge__(self, other): + return self._compare( + other, + f_pred=arith.CmpFPredicate.OGE, + si_pred=arith.CmpIPredicate.sge, + ui_pred=arith.CmpIPredicate.uge, + ) + + def _compare(self, other, *, f_pred, si_pred, ui_pred): + if ir.FloatType.isinstance(self.mlir_dtype): + pred = functools.partial(arith.cmpf, f_pred) + elif ir.IntegerType.isinstance(self.mlir_dtype): + if ir.IntegerType(self.mlir_dtype).is_signed: + pred = functools.partial(arith.cmpi, si_pred) + else: + pred = functools.partial(arith.cmpi, ui_pred) + else: raise NotImplementedError - return self._pointwise(arith.maximumf, other) + return self._pointwise(pred, other, output_is_signed=False) - def exp(self, approx: bool = False): + def max(self, other): + if ir.FloatType.isinstance(self.mlir_dtype): + return self._pointwise(arith.maximumf, other) + elif ir.IntegerType.isinstance(self.mlir_dtype): + return self._pointwise( + arith.maxsi if self.is_signed else arith.maxui, other + ) + else: + return NotImplemented + + def exp(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx: @@ -327,7 
+477,7 @@ def fast_exp(x): return self._pointwise(self._lift_fast_unary(fast_exp)) return self._pointwise(mlir_math.exp) - def sin(self, approx: bool = False): + def sin(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): @@ -336,7 +486,7 @@ def sin(self, approx: bool = False): self._lift_fast_unary("sin.approx.f32") if approx else mlir_math.sin ) - def cos(self, approx: bool = False): + def cos(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): @@ -345,7 +495,7 @@ def cos(self, approx: bool = False): self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos ) - def rsqrt(self, approx: bool = False): + def rsqrt(self, *, approx: bool = False): if not ir.FloatType.isinstance(self.mlir_dtype): raise NotImplementedError if approx and self.mlir_dtype != ir.F32Type.get(): @@ -417,13 +567,62 @@ def __getitem__(self, idx): base_idx[0] : base_idx[0] + slice_shape[0], base_idx[1] : base_idx[1] + slice_shape[1], ] - return FragmentedArray(_registers=new_regs, _layout=self.layout) + return FragmentedArray( + _registers=new_regs, _layout=self.layout, _is_signed=self.is_signed + ) # TODO(apaszke): Support JAX dtypes here as well? - def astype(self, new_dtype: ir.Type): + def astype(self, new_dtype: ir.Type, *, is_signed: bool | None = None): + i8 = ir.IntegerType.get_signless(8) + i16 = ir.IntegerType.get_signless(16) + i32 = ir.IntegerType.get_signless(32) + bf16 = ir.BF16Type.get() + cur_dtype = self.mlir_dtype if cur_dtype == new_dtype: - return self + if self.is_signed == is_signed: + return self + return FragmentedArray( + _registers=self.registers, _layout=self.layout, _is_signed=is_signed + ) + reg_type = self.registers.flat[0].type + is_vector_reg = ir.VectorType.isinstance(reg_type) + reg_shape = tuple(ir.VectorType(reg_type).shape) if is_vector_reg else () + if cur_dtype == i8 and new_dtype == bf16 and reg_shape == (2,): + new_registers = np.empty_like(self.registers) + for idx, reg in np.ndenumerate(self.registers): + reg_16 = vector.bitcast(ir.VectorType.get((1,), i16), reg) + val_16 = llvm.extractelement(reg_16, c(0, i32)) + # We first embed the s8 into a bf16 with the exponent equal to + # bias + mantissa bits. Then, we zero the msb that didn't fit into the + # mantissa, zero out all bits other than msb, and subtract the last + # two values from each other. This takes advantage of the fact that the + # lsb of the exponent (msb of the second byte) is zero, which allows us + # to losslessly pack the msb there. When 1, it doubles the value of s2, + # making the result negative. + new_val_32 = llvm.inline_asm( + i32, + [val_16], + """ + { + .reg .b32 s<3>; + prmt.b32 s0, $1, 0x43, 0x4140; + and.b32 s1, s0, 0xff7fff7f; + and.b32 s2, s0, 0xff80ff80; + sub.bf16x2 $0, s1, s2; + } + """, + "=r,r", + ) + new_vec = llvm.mlir_undef(ir.VectorType.get((1,), i32)) + new_vec = llvm.insertelement(new_vec, new_val_32, c(0, i32)) + new_registers[idx] = vector.bitcast( + ir.VectorType.get((2,), new_dtype), new_vec + ) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) + # Generic path.
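# An illustrative, pure-Python check of the int8 -> bf16 bit trick above. The
# helpers are hypothetical and only for this example; the subtraction result
# is exactly representable in bf16, so computing it with an ordinary Python
# float matches what sub.bf16x2 produces on device.
import struct

def _bf16_bits_to_float(bits: int) -> float:
  # A bf16 pattern occupies the high 16 bits of an IEEE-754 float32.
  return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def _s8_byte_to_bf16_demo(byte: int) -> float:
  lo = 0x4300 | (byte & 0x7F)  # the 0xff7f mask: low 7 bits become the mantissa
  hi = 0x4300 | (byte & 0x80)  # the 0xff80 mask: only the byte's sign bit survives
  return _bf16_bits_to_float(lo) - _bf16_bits_to_float(hi)

for byte in range(256):
  expected = byte - 256 if byte >= 128 else byte  # value of the byte as an s8
  assert _s8_byte_to_bf16_demo(byte) == float(expected)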
from_float = ir.FloatType.isinstance(cur_dtype) to_float = ir.FloatType.isinstance(new_dtype) from_integer = ir.IntegerType.isinstance(cur_dtype) @@ -442,6 +641,8 @@ def astype(self, new_dtype: ir.Type): convert = arith.sitofp elif from_float and to_integer: convert = arith.fptosi + else: + raise NotImplementedError(f"Unsupported conversion {cur_dtype} -> {new_dtype}") new_registers = np.empty_like(self.registers) match self.layout: case WGMMAFragLayout(): @@ -454,15 +655,24 @@ def astype(self, new_dtype: ir.Type): raise NotImplementedError(f"Unsupported layout {self.layout}") for idx, reg in np.ndenumerate(self.registers): new_registers[idx] = convert(new_reg_ty, reg) - return FragmentedArray(_registers=new_registers, _layout=self.layout) + return FragmentedArray( + _registers=new_registers, _layout=self.layout, _is_signed=is_signed + ) def reduce_sum(self, scratch) -> ir.Value: + if ir.FloatType.isinstance(self.mlir_dtype): + op = arith.addf + elif ir.IntegerType.isinstance(self.mlir_dtype): + op = arith.addi + else: + raise NotImplementedError(self.mlir_dtype) + index = ir.IndexType.get() if not isinstance(self.layout, WGStridedFragLayout): raise NotImplementedError(f"Unsupported layout {self.layout}") result = c(0, self.mlir_dtype) for reg in self.registers: - result = arith.addf( + result = op( result, vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg), ) @@ -470,19 +680,12 @@ def reduce_sum(self, scratch) -> ir.Value: if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]: raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})") - if ir.FloatType.isinstance(self.mlir_dtype): - op = arith.addf - elif ir.IntegerType.isinstance(self.mlir_dtype): - op = arith.addi - else: - raise NotImplementedError(self.mlir_dtype) - warp_result = utils.warp_tree_reduce(result, op, 32) warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index)) memref.store(warp_result, scratch, [warp_id]) - utils.commit_shared() + utils.warpgroup_barrier() zero_index = c(0, index) - with mgpu.single_thread(): + with mgpu.single_thread(per_block=False): scratch_vec = vector.load( ir.VectorType.get((4,), self.mlir_dtype), scratch, @@ -492,7 +695,7 @@ def reduce_sum(self, scratch) -> ir.Value: self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec ) memref.store(scratch_sum, scratch, [zero_index]) - utils.commit_shared() + utils.warpgroup_barrier() return memref.load(scratch, [zero_index]) def reduce(self, op, axis): @@ -528,7 +731,9 @@ def reduce(self, op, axis): ) result = op(result, other_result) new_regs[row_tile, row_subtile] = result - return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT) + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_ROW_LAYOUT, _is_signed=self.is_signed + ) def broadcast(self, shape): if not isinstance(self.layout, WGSplatFragLayout): @@ -540,7 +745,11 @@ def broadcast(self, shape): if not self.layout.can_broadcast_to(shape): raise ValueError(f"Can't broadcast {self.shape} to {shape}") - return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape)) + return FragmentedArray( + _registers=self.registers, + _layout=WGSplatFragLayout(shape), + _is_signed=self.is_signed, + ) def reshape(self, shape): if self.shape == shape: @@ -552,7 +761,11 @@ def reshape(self, shape): if np.prod(shape) != np.prod(self.shape): raise ValueError(f"Can't reshape {self.shape} to {shape}") - return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape)) + return FragmentedArray( + 
_registers=self.registers, + _layout=WGSplatFragLayout(shape), + _is_signed=self.is_signed, + ) def broadcast_minor(self, n): if self.layout != WGMMA_ROW_LAYOUT: @@ -567,7 +780,28 @@ def broadcast_minor(self, n): new_regs[row_tile, :, row_subtile, :] = vector.splat( ir.VectorType.get((2,), dtype), reg ) - return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT) + return FragmentedArray( + _registers=new_regs, _layout=WGMMA_LAYOUT, _is_signed=self.is_signed + ) + + def select(self, on_true, on_false): + if ( + not ir.IntegerType.isinstance(self.mlir_dtype) + or ir.IntegerType(self.mlir_dtype).width != 1 + ): + raise NotImplementedError + return self._pointwise(arith.select, on_true, on_false) + + def foreach(self, fn: Callable[[ir.Value, tuple[ir.Value, ...]], None]): + """Call a function for each value and index.""" + if not isinstance(self.layout, WGStridedFragLayout): + raise NotImplementedError(self.layout) + index = ir.IndexType.get() + for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): + assert len(idx) == len(self.shape), (idx, self.shape) + for i in range(self.layout.vec_size): + i = c(i, index) + fn(vector.extractelement(reg, position=i), (*idx[:-1], arith.addi(idx[-1], i))) def store_untiled(self, ref: ir.Value): if not ir.MemRefType.isinstance(ref.type): @@ -576,19 +810,37 @@ def store_untiled(self, ref: ir.Value): match self.layout: case WGMMAFragLayout(): self._store_untiled_wgmma(ref) + case WGSplatFragLayout(): + self._store_untiled_splat(ref) case WGStridedFragLayout(): self._store_untiled_wg_strided(ref) case _: raise NotImplementedError(self.layout) + def _store_untiled_splat(self, ref: ir.Value): + vec_size = 8 // mgpu.bytewidth(self.mlir_dtype) + if np.prod(self.shape) < vec_size * WARPGROUP_SIZE: + vec_size = 1 + + if np.prod(self.shape) % WARPGROUP_SIZE * vec_size: + raise ValueError(self.shape, WARPGROUP_SIZE, vec_size) + + fa = FragmentedArray.splat( + self.registers.flat[0], + self.shape, + layout=WGStridedFragLayout(shape=self.shape, vec_size=vec_size), + is_signed=self.is_signed, + ) + fa.store_untiled(ref) + def _store_untiled_wg_strided(self, ref: ir.Value): ref_ty = ir.MemRefType(ref.type) ref_shape = tuple(ref_ty.shape) if ref_shape != self.shape: raise ValueError((ref_shape, self.shape)) smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape)) - for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat): - vector.store(reg, smem_1d, idx) + for idx, reg in zip(self.layout.linear_thread_vec_idxs(), self.registers.flat): + vector.store(reg, smem_1d, [idx]) def _store_untiled_wgmma(self, ref: ir.Value): """Stores accumulator to a 2D memref. Not optimized at the moment.""" @@ -627,13 +879,17 @@ def store_tiled(self, ref, swizzle: int | None): assert m % 64 == 0 # This is implied by the layout. cols_per_tile = swizzle // bw expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile] + if n < cols_per_tile: # We allow singular tiles shorter than swizzle. 
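The vector width picked by _store_untiled_splat is the widest 8-byte store the element type allows, with a fallback to scalar stores when the array is too small for every thread in the warpgroup to receive a full vector. A sketch of that choice, assuming WARPGROUP_SIZE is 128 threads:

WARPGROUP_SIZE = 128  # assumed: 4 warps of 32 threads

def splat_store_vec_size(num_elements: int, bytewidth: int) -> int:
  vec_size = 8 // bytewidth          # e.g. 4 lanes of bf16 per 8-byte store
  if num_elements < vec_size * WARPGROUP_SIZE:
    vec_size = 1                     # not enough elements for vectorized stores
  return vec_size

assert splat_store_vec_size(128 * 64, bytewidth=2) == 4
assert splat_store_vec_size(128, bytewidth=2) == 1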
+ expected_shape = [m // 64, 1, 64, cols_per_tile] if ir.MemRefType(ref.type).shape != expected_shape: raise ValueError(ref.type, (m, n)) for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle): vector.store(get(self.registers), ref, idxs) @classmethod - def load_tiled(cls, ref, swizzle: int | None): + def load_tiled( + cls, ref, swizzle: int | None, *, is_signed: bool | None = None + ): ref_ty = ir.MemRefType(ref.type) dtype = ref_ty.element_type bw = mgpu.bytewidth(dtype) @@ -649,16 +905,19 @@ def load_tiled(cls, ref, swizzle: int | None): ) for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle): update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs)) - return cls(_registers=registers, _layout=WGMMA_LAYOUT) + return cls(_registers=registers, _layout=WGMMA_LAYOUT, _is_signed=is_signed) @staticmethod def transfer_tiled(shape, dtype, swizzle: int | None): # TODO(apaszke): We could use ldmatrix/stmatrix for 16-bit types. bw = mgpu.bytewidth(dtype) m, n = shape - cols_per_tile = swizzle // bw - if n % cols_per_tile != 0: - raise NotImplementedError + assert m % 64 == 0 and n % 8 == 0 # Implied by the layout. + cols_per_tile = swizzle_elems = swizzle // bw + if n < swizzle_elems: + cols_per_tile = n + else: + assert n % swizzle_elems == 0, (n, swizzle_elems) if swizzle not in {32, 64, 128}: raise NotImplementedError("Only swizzled stores supported") @@ -693,6 +952,8 @@ def transfer_tiled(shape, dtype, swizzle: int | None): case _: raise AssertionError(swizzle) stagger_amount = swizzle // 64 + if (cols_per_tile // 8) % (stagger_amount * 2): + raise NotImplementedError else: # We rely on canonicalization to clean up the selects. i1 = ir.IntegerType.get_signless(1) @@ -731,10 +992,11 @@ def update_registers(regs, new, left_idx=left_idx, right_idx=right_idx): yield get_register, update_registers, idx def tree_flatten(self): - return list(self.registers.flat), (self.layout, self.registers.shape) + aux = self.layout, self.registers.shape, self.is_signed + return list(self.registers.flat), aux @classmethod def tree_unflatten(cls, aux, flat_registers): - layout, reg_shape = aux + layout, reg_shape, is_signed = aux registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape) - return cls(_registers=registers, _layout=layout) + return cls(_registers=registers, _layout=layout, _is_signed=is_signed) diff --git a/jax/experimental/mosaic/gpu/utils.py b/jax/experimental/mosaic/gpu/utils.py index 7733146d0153..a59ddbea5565 100644 --- a/jax/experimental/mosaic/gpu/utils.py +++ b/jax/experimental/mosaic/gpu/utils.py @@ -23,13 +23,14 @@ from typing import Any, Literal import jax +from jax import numpy as jnp +from jax.interpreters import mlir from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import gpu from jaxlib.mlir.dialects import llvm from jaxlib.mlir.dialects import memref -from jaxlib.mlir.dialects import nvgpu from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import scf from jaxlib.mlir.dialects import vector @@ -154,9 +155,8 @@ def fori(bound, carrys): flat_carrys, carry_treedef = jax.tree.flatten(carrys) def wrapper(f): - index = ir.IndexType.get() - c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0)) - c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1)) + c0 = arith.constant(bound.type, 0) + c1 = arith.constant(bound.type, 1) for_op = scf.ForOp(c0, bound, c1, flat_carrys) with ir.InsertionPoint(for_op.body): i = for_op.induction_variable @@ 
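The SMEM shape that store_tiled/load_tiled expect can be computed ahead of time from the swizzle and the element width; a sketch of that arithmetic, including the newly allowed case of a single tile narrower than the swizzle:

def expected_tiled_shape(m: int, n: int, swizzle: int, bytewidth: int):
  assert m % 64 == 0 and n % 8 == 0       # implied by the WGMMA layout
  cols_per_tile = swizzle // bytewidth    # elements per swizzled row segment
  if n < cols_per_tile:
    return [m // 64, 1, 64, cols_per_tile]  # one short tile, still swizzle-wide in SMEM
  assert n % cols_per_tile == 0
  return [m // 64, n // cols_per_tile, 64, cols_per_tile]

assert expected_tiled_shape(128, 128, swizzle=128, bytewidth=2) == [2, 2, 64, 64]
assert expected_tiled_shape(64, 32, swizzle=128, bytewidth=2) == [1, 1, 64, 64]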
-229,6 +229,15 @@ class ThreadSubset(enum.IntEnum): _ONCE_PER: ThreadSubset | None = None +def single_thread_predicate(per_block=True): + warp = warp_idx() + if not per_block: + warp = arith.remui(warp, c(4, warp.type)) + first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) + elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) + return arith.andi(first_warp, elected) + + @contextlib.contextmanager def single_thread(per_block=True): """Runs the context only from a single thread. @@ -244,16 +253,10 @@ def single_thread(per_block=True): yield return - warp = warp_idx() - if not per_block: - warp = arith.remui(warp, c(4, warp.type)) - first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type)) - elected = nvvm.elect_sync(ir.IntegerType.get_signless(1)) - should_run = arith.andi(first_warp, elected) - if_op = scf.IfOp(should_run) prev_scope = _ONCE_PER _ONCE_PER = scope try: + if_op = scf.IfOp(single_thread_predicate(per_block)) with ir.InsertionPoint(if_op.then_block): yield scf.YieldOp([]) @@ -309,6 +312,8 @@ class DynamicSlice: def memref_slice(ref: ir.Value, index) -> ir.Value: ref_ty = ir.MemRefType(ref.type) base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape) + # TODO(apaszke): Check that slice is within the memref (indices might be + # dynamic, but we can at least catch some OOB slices). memref_strides, offset = ref_ty.get_strides_and_offset() new_offset = offset @@ -503,12 +508,25 @@ def parse_indices( def commit_shared(): - gpu.barrier() + warpgroup_barrier() nvvm.fence_proxy( nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta ) +def warpgroup_barrier(): + # gpu.barrier() uses barrier number 0, and it would be unsafe to reuse it, + # so we shift the warpgroup index by 1. + i32 = ir.IntegerType.get_signless(32) + llvm.inline_asm( + ir.Type.parse("!llvm.void"), + [arith.addi(warpgroup_idx(sync=False), c(1, i32))], + f"bar.sync $0, {WARPGROUP_SIZE};", + "r", + has_side_effects=True, + ) + + @dataclasses.dataclass(frozen=True) class BarrierRef: base_address: ir.Value @@ -595,14 +613,15 @@ def arrive(self): i64 = ir.IntegerType.get_signless(64) nvvm.mbarrier_arrive_shared(i64, self.get_ptr()) - def arrive_expect_tx(self, bytes: int | ir.Value): + def arrive_expect_tx( + self, bytes: int | ir.Value, predicate: ir.Value | None = None + ): if isinstance(bytes, int): bytes = c(bytes, ir.IntegerType.get_signless(32)) elif ir.IndexType.isinstance(bytes.type): i32 = ir.IntegerType.get_signless(32) bytes = arith.index_cast(i32, bytes) - - nvvm.mbarrier_arrive_expect_tx(self.get_ptr(), bytes) + nvvm.mbarrier_arrive_expect_tx_shared(self.get_ptr(), bytes, predicate=predicate) def get_ptr(self): ptr = ir.Type.parse("!llvm.ptr<3>") @@ -622,21 +641,27 @@ class CollectiveBarrierRef: def initialize( address: ir.Value, num_barriers: int, - dims: Sequence[gpu.Dimension], + dims: Sequence[gpu.Dimension | Sequence[gpu.Dimension]], cluster_shape: tuple[int, int, int], ) -> "CollectiveBarrierRef": i32 = ir.IntegerType.get_signless(32) # With the exception of the current device, each pair of slices along # collective dims is disjoint. Since the current device is overcounted, # we must decrease the arrival count a little. 
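A model of which threads the refactored single_thread_predicate selects, assuming 32-thread warps, four warps per warpgroup, and elect_sync picking lane 0 (in reality it elects one arbitrary active lane per warp):

def runs(thread_id: int, per_block: bool) -> bool:
  warp = thread_id // 32
  if not per_block:
    warp %= 4                                # first warp of each warpgroup qualifies
  return warp == 0 and thread_id % 32 == 0   # stand-in for elect_sync

block = range(256)  # two warpgroups
assert [t for t in block if runs(t, per_block=True)] == [0]
assert [t for t in block if runs(t, per_block=False)] == [0, 128]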
- arrival_count = sum(cluster_shape[d] for d in dims) - len(dims) + 1 - if math.prod(cluster_shape[d] for d in dims) == 1: + dims_shape = [ + cluster_shape[d] + if isinstance(d, gpu.Dimension) + else math.prod(cluster_shape[dd] for dd in d) + for d in dims + ] + arrival_count = sum(dims_shape) - len(dims) + 1 + if arrival_count == 1: + assert all(s == 1 for s in dims_shape) cluster_mask = None - assert arrival_count == 1 else: cluster_mask = c(0, i32) - for d in dims: - if cluster_shape[d] == 1: + for d, size in zip(dims, dims_shape): + if size == 1: # Only the current device is in this mask, but it will also be # present in one of the non-trivial cluster dims. continue @@ -693,8 +718,11 @@ def arrive(self): has_side_effects=True, ) - def wait(self): - self.barrier.wait() + def wait(self, *args, **kwargs): + self.barrier.wait(*args, **kwargs) + + def wait_parity(self, *args, **kwargs): + self.barrier.wait_parity(*args, **kwargs) class Partition: @@ -887,8 +915,11 @@ def memref_ptr(memref_arg, memory_space=None): def cluster_collective_mask( - cluster_shape: tuple[int, int, int], collective: gpu.Dimension + cluster_shape: tuple[int, int, int], + collective: Sequence[gpu.Dimension] | gpu.Dimension, ): + if isinstance(collective, gpu.Dimension): + collective = (collective,) # We first compute the linearized index of the slice along the collective # dim that contains the current block. Then, the mask is a sequence of 1s # strided by the position of the collective dim, shifted left by the linear @@ -896,20 +927,36 @@ def cluster_collective_mask( # TODO(apaszke): Make sure this gets hoisted outside of any loops. # If not, we might need to do it manually. i32 = ir.IntegerType.get_signless(32) - stride = 1 mask_shift = c(0, i32) - collective_stride = None - for cluster_dim in gpu.Dimension: - if cluster_dim != collective: - if cluster_shape[cluster_dim] != 1: # Constant-fold multiply by 0. - dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim)) - mask_shift = arith.addi( - mask_shift, arith.muli(dim_idx, c(stride, i32)), - ) - else: - collective_stride = stride - stride *= cluster_shape[cluster_dim] + # NOTE: GPU dimensions are minor-to-major. + cluster_strides = get_contiguous_strides(cluster_shape[::-1])[::-1] + for stride, cluster_dim in zip(cluster_strides, gpu.Dimension): + if cluster_dim in collective: + continue + if cluster_shape[cluster_dim] != 1: # Constant-fold multiply by 0. + dim_idx = arith.index_castui(i32, gpu.cluster_block_id(cluster_dim)) + mask_shift = arith.addi( + mask_shift, arith.muli(dim_idx, c(stride, i32)), + ) mask_unshifted = 0 - for i in range(cluster_shape[collective]): - mask_unshifted |= 1 << (i * collective_stride) + collective_strides = [cluster_strides[d] for d in collective] + collective_shape = tuple(cluster_shape[d] for d in collective) + for idx in np.ndindex(collective_shape): + mask_unshifted |= 1 << sum(i * s for i, s in zip(idx, collective_strides)) return arith.shli(c(mask_unshifted, i32), mask_shift) + + +def dtype_to_ir_type(dtype: jax.typing.DTypeLike) -> ir.Type: + dtype = jnp.dtype(dtype) + if jnp.issubdtype(dtype, jnp.integer): + # All integer types in Mosaic GPU are signless. 
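The generalized arrival count lets an entry of dims be a group of cluster dimensions treated as a single collective axis. A worked check of the formula, with plain ints standing in for gpu.Dimension:

import math

def arrival_count(cluster_shape, dims):
  dims_shape = [
      cluster_shape[d] if isinstance(d, int)
      else math.prod(cluster_shape[dd] for dd in d)
      for d in dims
  ]
  return sum(dims_shape) - len(dims) + 1

assert arrival_count((2, 2, 1), dims=(0, 1)) == 3     # 2 + 2, current block counted once
assert arrival_count((4, 1, 1), dims=(0,)) == 4
assert arrival_count((2, 2, 1), dims=((0, 1),)) == 4  # fused dims count as one entry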
+ return ir.IntegerType.get_signless(dtype.itemsize * 8) + return mlir.dtype_to_ir_type(dtype) + + +def is_signed(dtype: jax.typing.DTypeLike) -> bool | None: + if jnp.issubdtype(dtype, jnp.bool_): + return False + elif jnp.issubdtype(dtype, jnp.integer): + return jnp.issubdtype(dtype, jnp.signedinteger) + return None diff --git a/jax/experimental/mosaic/gpu/wgmma.py b/jax/experimental/mosaic/gpu/wgmma.py index fc2fe892ac03..ba0f130364ff 100644 --- a/jax/experimental/mosaic/gpu/wgmma.py +++ b/jax/experimental/mosaic/gpu/wgmma.py @@ -21,13 +21,11 @@ import jax from jaxlib.mlir import ir from jaxlib.mlir.dialects import arith -from jaxlib.mlir.dialects import builtin from jaxlib.mlir.dialects import llvm -from jaxlib.mlir.dialects import nvvm from jaxlib.mlir.dialects import vector import numpy as np -from . import dsl as mgpu +import jax.experimental.mosaic.gpu as mgpu from . import utils # mypy: ignore-errors @@ -55,15 +53,19 @@ def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True): self.value = wgmma_fence(_value) @classmethod - def zero(cls, m, n, dtype=None): + def zero(cls, m, n, dtype=None, *, is_signed: bool | None = None): if m % 64 or n % 8: raise ValueError + if is_signed is False: + raise TypeError("PTX does not support unsigned WGMMA accumulators") f32 = ir.F32Type.get() if dtype is None: dtype = f32 zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0)) return cls( - _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT) + _value=mgpu.FragmentedArray.splat( + zero, (m, n), mgpu.WGMMA_LAYOUT, is_signed=is_signed + ) ) @classmethod @@ -156,14 +158,14 @@ def wgmma_m64( out_ty = ir.VectorType(acc.flat[0].type).element_type if not _supported_wgmma_types(out_ty, element_type): raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}") + if n % 8: + raise ValueError i32 = ir.IntegerType.get_signless(32) i64 = ir.IntegerType.get_signless(64) index = ir.IndexType.get() if b_k_stride % 16: raise ValueError - if n % (swizzle // bytewidth(element_type)): - raise ValueError # Only 16-bit types support transposes supports_transpose = bytewidth(element_type) == 2 if not supports_transpose and (a_transpose or b_transpose): @@ -326,7 +328,15 @@ def wgmma( kn_tile = swizzle // element_bytewidth groups_k, groups_n = b_ty.shape[:2] - if b_ty.shape[2:] != [kn_tile, kn_tile]: + k_group_size, n_group_size = ( + b_ty.shape[2:] if b_order == WGMMALayout.ROW_MAJOR else b_ty.shape[:1:-1] + ) + # Note that while this technically allows n to be smaller than kn_tile, + # the stride checks below will still enforce that the memory region is padded. + # It might be possible to relax that requirement, but I haven't tested it. + if n_group_size > kn_tile and n_group_size % kn_tile: + raise ValueError(n_group_size, kn_tile) + if k_group_size != kn_tile: raise ValueError(b_ty.shape) if a_in_regs: @@ -353,6 +363,12 @@ def wgmma( if a_order == WGMMALayout.COL_MAJOR and swizzle != 128: # Not sure what the layout is like, since the tiles aren't square. 
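The new is_signed helper is what callers use to populate the _is_signed flag threaded through FragmentedArray; restated as a self-contained snippet, bools and unsigned integers map to False, signed integers to True, and floating-point dtypes to None:

import jax.numpy as jnp

def is_signed(dtype):
  if jnp.issubdtype(dtype, jnp.bool_):
    return False
  elif jnp.issubdtype(dtype, jnp.integer):
    return jnp.issubdtype(dtype, jnp.signedinteger)
  return None

assert is_signed(jnp.int8) and not is_signed(jnp.uint32)
assert not is_signed(jnp.bool_)
assert is_signed(jnp.float32) is None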
raise NotImplementedError + expected_acc_shape = (groups_m * 64, groups_n * n_group_size) + if acc.value.shape != expected_acc_shape: + raise ValueError( + f"Accumulator shape mismatch: expected {expected_acc_shape}, got" + f" {acc.value.shape}" + ) row_major = WGMMALayout.ROW_MAJOR col_major = WGMMALayout.COL_MAJOR @@ -375,7 +391,7 @@ def wgmma( b_transpose=b_order == row_major, a_k_stride=(2 if a_order == row_major else 128) << 4, b_k_stride=(swizzle if b_order == row_major else 2) << 4, - n=(groups_n * kn_tile), + n=(groups_n * n_group_size), swizzle=swizzle, element_type=ir.FloatTF32Type.get() if ir.F32Type.isinstance(element_type) @@ -418,7 +434,9 @@ def wgmma( ) return WGMMAAccumulator( _value=mgpu.FragmentedArray( - _registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT + _registers=new_acc_regs, + _layout=mgpu.WGMMA_LAYOUT, + _is_signed=acc.value.is_signed, ), _sync=False, ) @@ -478,7 +496,7 @@ def wgmma_fence(array: mgpu.FragmentedArray): registers = np.asarray(regs, dtype=object).reshape(array.registers.shape) else: raise NotImplementedError(dtype) - return mgpu.FragmentedArray(_registers=registers, _layout=array.layout) + return mgpu.FragmentedArray(_registers=registers, _layout=array.layout, _is_signed=array.is_signed) def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]): diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 1ca601da3942..56003ea7af5d 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -32,7 +32,6 @@ from jax._src.interpreters import pxla from jax.interpreters import xla from jax._src import pjit as pjit_lib -from jax.experimental.pjit import pjit from jax.sharding import PartitionSpec as P from jax._src import distributed from jax._src.util import safe_zip @@ -101,7 +100,7 @@ def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = pjit(_identity_fn, out_shardings=reps)(inp) + out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. 
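A worked example of the relaxed B-operand and accumulator shape checks, assuming row-major B, swizzle=128 and 2-byte (bf16) elements, so that n_group_size equals kn_tile:

swizzle, element_bytewidth = 128, 2
kn_tile = swizzle // element_bytewidth            # 64
groups_m, groups_k, groups_n = 2, 4, 3
b_smem_shape = (groups_k, groups_n, kn_tile, kn_tile)      # (4, 3, 64, 64)
expected_acc_shape = (groups_m * 64, groups_n * kn_tile)   # what wgmma now verifies
assert expected_acc_shape == (128, 192)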
if jax.process_count() == 1: @@ -124,8 +123,8 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - with global_mesh: - out = pjit(_identity_fn, out_shardings=None)(global_arr) + out = jax.jit(_identity_fn, + out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) diff --git a/jax/experimental/pallas/__init__.py b/jax/experimental/pallas/__init__.py index 9a768ed53e75..bb733e794c5f 100644 --- a/jax/experimental/pallas/__init__.py +++ b/jax/experimental/pallas/__init__.py @@ -21,12 +21,14 @@ from jax._src.deprecations import register as _register_deprecation from jax._src.pallas.core import Blocked from jax._src.pallas.core import BlockSpec +from jax._src.pallas.core import CompilerParams from jax._src.pallas.core import CostEstimate +from jax._src.pallas.core import GridSpec from jax._src.pallas.core import IndexingMode from jax._src.pallas.core import no_block_spec from jax._src.pallas.core import Unblocked from jax._src.pallas.core import unblocked -from jax._src.pallas.core import GridSpec +from jax._src.pallas.core import MemorySpace from jax._src.pallas.pallas_call import pallas_call from jax._src.pallas.pallas_call import pallas_call_p from jax._src.pallas.primitives import atomic_add @@ -56,5 +58,8 @@ from jax._src.state.indexing import Slice from jax._src.state.primitives import broadcast_to +ANY = MemorySpace.ANY + + _register_deprecation("pallas-block-spec-order") del _register_deprecation diff --git a/jax/experimental/pallas/gpu.py b/jax/experimental/pallas/gpu.py index adade4e8a72c..4f38192e3a14 100644 --- a/jax/experimental/pallas/gpu.py +++ b/jax/experimental/pallas/gpu.py @@ -14,5 +14,7 @@ """Triton-specific Pallas APIs.""" +from jax._src.pallas.triton.core import TritonCompilerParams from jax._src.pallas.triton.primitives import approx_tanh +from jax._src.pallas.triton.primitives import debug_barrier from jax._src.pallas.triton.primitives import elementwise_inline_asm diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py index a0221ebf6f74..8e28be840d37 100644 --- a/jax/experimental/pallas/ops/gpu/attention.py +++ b/jax/experimental/pallas/ops/gpu/attention.py @@ -21,6 +21,7 @@ import jax from jax import lax from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu import jax.numpy as jnp import numpy as np @@ -41,7 +42,7 @@ def mha_forward_kernel( block_d: int, block_k: int, ): - seq_len = q_ref.shape[0] + seq_len = k_ref.shape[0] start_q = pl.program_id(0) # o is the buffer where we accumulate the output on sram. @@ -55,7 +56,7 @@ def mha_forward_kernel( # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. curr_q_slice = pl.dslice(start_q * block_q, block_q) - q = pl.load(q_ref, (curr_q_slice, pl.dslice(None))) + q = q_ref[...] 
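The allgather path now uses jax.jit with an explicitly replicated output sharding instead of pjit under an active mesh. A minimal sketch of that idiom (the axis name and the identity function are placeholders):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("x",))   # 1D mesh over all visible devices
replicated = NamedSharding(mesh, P())          # empty spec: fully replicated output
out = jax.jit(lambda a: a, out_shardings=replicated)(jnp.arange(8))
assert out.sharding.is_fully_replicated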
q_segment_ids = ( None if segment_ids_ref is None @@ -70,11 +71,6 @@ def body(start_k, carry): curr_k_slice = pl.dslice(start_k * block_k, block_k) k = pl.load(k_ref, (curr_k_slice, slice(None))) - kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (curr_k_slice,)) - ) qk = pl.dot(q, k.T) # [block_q, block_k] if sm_scale != 1.: qk *= sm_scale # [block_q, block_k] @@ -87,6 +83,7 @@ def body(start_k, carry): if causal or segment_ids_ref is not None: mask = None if segment_ids_ref is not None: + kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,)) mask = segment_mask(q_segment_ids, kv_segment_ids) if causal: span_q = start_q * block_q + jnp.arange(block_q) @@ -107,9 +104,7 @@ def body(start_k, carry): ) # Use m_next instead of m_curr to avoid a correction on l_curr l_curr = s_curr.sum(axis=-1) l_next = l_prev_corr + l_curr - l_next_rcp = 1. / l_next - s_curr = s_curr * l_next_rcp[:, None] - o_prev_corr = (l_prev_corr * l_next_rcp)[:, None] * o_prev + o_prev_corr = correction[:, None] * o_prev v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d))) o_curr = pl.dot(s_curr.astype(v.dtype), v) @@ -122,14 +117,16 @@ def body(start_k, carry): upper_bound = pl.cdiv(seq_len, block_k) o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i)) + # We keep an unscaled version of o during the scan over seq_len. Scaling it + # by the last l_i gives us the correct final output. See section 3.1.1 in the + # FlashAttention-2 paper: https://arxiv.org/pdf/2307.08691. + o /= l_i[:, None] + if residual_refs: - l_ref, m_ref = residual_refs - pl.store(l_ref, (curr_q_slice,), l_i) - pl.store(m_ref, (curr_q_slice,), m_i) + lse_ref = residual_refs[0] + lse_ref[...] = m_i + jnp.log(l_i) # Write output to dram. - o = o.astype(o_ref.dtype) - pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o) - + o_ref[...] 
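The deferred normalization mentioned in the comment above is the FlashAttention-2 trick: o stays unscaled while scanning over key blocks, and a single division by the final l recovers the exact result. A NumPy model of the scan (illustrative shapes) checked against a reference softmax(q @ k.T) @ v:

import numpy as np

rng = np.random.default_rng(0)
block_q, seq_len, block_k, head_dim = 4, 16, 4, 8
q = rng.standard_normal((block_q, head_dim)).astype(np.float32)
k = rng.standard_normal((seq_len, head_dim)).astype(np.float32)
v = rng.standard_normal((seq_len, head_dim)).astype(np.float32)

m = np.full(block_q, -np.inf, dtype=np.float32)      # running max
l = np.zeros(block_q, dtype=np.float32)              # running sum of exponentials
o = np.zeros((block_q, head_dim), dtype=np.float32)  # unscaled output accumulator
for start in range(0, seq_len, block_k):
  qk = q @ k[start:start + block_k].T
  m_next = np.maximum(m, qk.max(axis=-1))
  correction = np.exp(m - m_next)
  s = np.exp(qk - m_next[:, None])
  l = l * correction + s.sum(axis=-1)
  o = correction[:, None] * o + s @ v[start:start + block_k]
  m = m_next
o /= l[:, None]  # the single final rescale, done after the fori_loop in the kernel

p = np.exp(q @ k.T - (q @ k.T).max(axis=-1, keepdims=True))
ref = (p / p.sum(axis=-1, keepdims=True)) @ v
assert np.allclose(o, ref, atol=1e-5)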
= o.astype(o_ref.dtype) def segment_mask( q_segment_ids: jax.Array, @@ -198,7 +195,7 @@ def mha( in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) @@ -218,11 +215,10 @@ def mha( grid=grid_, in_specs=in_specs, out_specs=pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) - ), - compiler_params=dict( - triton=dict(num_warps=num_warps_, num_stages=num_stages) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages), out_shape=out_shape, debug=debug, interpret=interpret, @@ -262,15 +258,14 @@ def _mha_forward( sm_scale=sm_scale, causal=causal, block_q=block_q, block_k=block_k, block_d=head_dim) out_shape = [ - jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out - jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # l - dtype=jnp.float32), - jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m - dtype=jnp.float32) + jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out + jax.ShapeDtypeStruct( + shape=(batch_size, num_heads, seq_len), dtype=jnp.float32 # lse + ), ] in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), pl.BlockSpec( (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) @@ -284,16 +279,15 @@ def _mha_forward( if segment_ids is None else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0)) ) - out, l, m = pl.pallas_call( + out, lse = pl.pallas_call( kernel, grid=grid_, in_specs=in_specs, out_specs=[ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) ), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), ], compiler_params=dict( triton=dict(num_warps=num_warps_, num_stages=num_stages) @@ -303,57 +297,47 @@ def _mha_forward( interpret=interpret, name="mha_forward", )(q, k, v, segment_ids) - return out, (q, k, v, segment_ids, out, l, m) - + return out, (q, k, v, segment_ids, out, lse) -def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, - new_dout_ref, delta_ref, *, - block_q: int): - pid_m = pl.program_id(0) - off_m = pl.ds(pid_m * block_q, block_q) +def _preprocess_backward_kernel(out_ref, dout_ref, delta_ref): # load - o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32) - do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32) - denom = pl.load(l_ref, (off_m,)).astype(jnp.float32) + o = out_ref[...].astype(jnp.float32) + do = dout_ref[...].astype(jnp.float32) # compute - do = do / denom[:, None] delta = jnp.sum(o * do, axis=1) # write-back - pl.store(new_dout_ref, (off_m, slice(None)), - do.astype(new_dout_ref.dtype)) - pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) + delta_ref[...] 
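The reworked BlockSpecs tile the sequence axis by block_q instead of loading it whole, which is also why seq_len is now read from k_ref. A sketch of how a Blocked index map turns into element offsets (assuming the usual Pallas convention of block index times block size, with None dimensions treated as size 1 and squeezed inside the kernel):

block_shape = (None, 64, None, 128)        # (batch, q_seq, heads, head_dim)
index_map = lambda i, j, k: (j, i, k, 0)   # i: q block, j: batch, k: head

def block_start(program_ids):
  blocks = index_map(*program_ids)
  return tuple(b * (1 if s is None else s) for b, s in zip(blocks, block_shape))

# Program instance (i=2, j=1, k=3) reads q[1, 128:192, 3, :] as a (64, 128) tile.
assert block_start((2, 1, 3)) == (1, 128, 3, 0)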
= delta.astype(delta_ref.dtype) @jax.named_scope("preprocess_backward") -def _preprocess_backward(out, do, l, block_q: int, +def _preprocess_backward(out, do, lse, block_q: int, debug: bool, interpret: bool): batch_size, seq_len, num_heads, head_dim = out.shape - out_shape = [ - jax.ShapeDtypeStruct(do.shape, do.dtype), - jax.ShapeDtypeStruct(l.shape, l.dtype), - ] - do_scaled, delta = pl.pallas_call( - functools.partial(_preprocess_backward_kernel, block_q=block_q), + out_shape = jax.ShapeDtypeStruct(lse.shape, lse.dtype) + delta = pl.pallas_call( + _preprocess_backward_kernel, grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), - ], - out_specs=[ - pl.BlockSpec((None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda _, j, k: (j, k, 0)), + pl.BlockSpec( + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) + ), + pl.BlockSpec( + (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0) + ), ], - compiler_params=dict( - triton=dict(num_warps=4, num_stages=3) - ), + out_specs=pl.BlockSpec((None, None, block_q), lambda i, j, k: (j, k, i)), + compiler_params=dict(triton=dict(num_warps=4, num_stages=3)), out_shape=out_shape, debug=debug, interpret=interpret, - name="mha_preprocess_backward")(out, do, l) - return do_scaled, delta + name="mha_preprocess_backward", + )(out, do) + return delta +# This kernel computes dK_i, dV_i and dQ_i in parallel across the sequence +# length. +# Inspired by the triton tutorial: https://github.com/triton-lang/triton/blob/main/python/tutorials/06-fused-attention.py def mha_backward_kernel( # Inputs q_ref, @@ -362,10 +346,8 @@ def mha_backward_kernel( segment_ids_ref: jax.Array | None, out_ref, do_scaled_ref, - l_ref, - m_ref, + lse_ref, delta_ref, - _, # Outputs dq_ref, dk_ref, @@ -373,84 +355,141 @@ def mha_backward_kernel( *, sm_scale: float, causal: bool, - block_q: int, + block_q1: int, + block_k1: int, + block_q2: int, + block_k2: int, block_d: int, - block_k: int, ): - del out_ref, l_ref # Not needed + del out_ref # Not needed seq_len = q_ref.shape[0] - def outer_loop(start_k, _): - - dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) - k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None))) - span_k = start_k * block_k + jnp.arange(block_k) - kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),)) - ) - - def inner_loop(start_q, carry): - dv, dk = carry - q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - qk = pl.dot(q, k.T) - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - if sm_scale != 1.0: - qk *= sm_scale - - q_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),)) - ) - - if causal or segment_ids_ref is not None: - mask = None - if segment_ids_ref is not None: - mask = segment_mask(q_segment_ids, kv_segment_ids) - - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - causal_mask = span_q[:, None] >= span_k[None, :] - mask = ( - causal_mask - if mask is None - else jnp.logical_and(mask, causal_mask) - ) - qk = jnp.where(mask, qk, 
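The delta written by the preprocessing kernel is the rowwise sum of o * do; together with the lse residual it drives the ds = p * (dp - delta) update used by both scans below. A JAX check of that identity (sm_scale = 1 for brevity):

import jax
import jax.numpy as jnp

s = jax.random.normal(jax.random.key(0), (4, 6))   # logits, i.e. q @ k.T
v = jax.random.normal(jax.random.key(1), (6, 8))
do = jax.random.normal(jax.random.key(2), (4, 8))

p = jax.nn.softmax(s, axis=-1)
o = p @ v
delta = jnp.sum(o * do, axis=-1)             # what _preprocess_backward_kernel stores
ds_manual = p * (do @ v.T - delta[:, None])  # the backward kernel's update

_, vjp = jax.vjp(lambda s: jax.nn.softmax(s, axis=-1) @ v, s)
assert jnp.allclose(vjp(do)[0], ds_manual, atol=1e-5)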
DEFAULT_MASK_VALUE) - - m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) - p = jnp.exp(qk - m[:, None]) - do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - dv = dv + pl.dot(p.astype(do.dtype).T, do) - di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) - ds = p * dp - if sm_scale != 1.0: - ds = ds * sm_scale - dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) - dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), eviction_policy="evict_last") - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) - pl.store(dq_ref, (pl.ds(start_q * block_q, block_q), - slice(None)), dq, eviction_policy="evict_last") - return dv, dk - if causal: - lower_bound = lax.div(start_k * block_k, block_q) - else: - lower_bound = 0 - dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop, - (dv, dk)) - pl.store(dv_ref, (pl.ds(start_k * block_k, block_k), - slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (pl.ds(start_k * block_k, block_k), - slice(None)), dk.astype(dk_ref.dtype)) - lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None) + # Scan #1: dK and dV + # 1. Load a block of K and V of size (block_k1, head_dim) in SMEM. + # 2. Iterate through Q in chunks of (block_q1, head_dim) to accumulate + # dK and dV. + start_k = pl.program_id(2) + curr_k_slice = pl.dslice(start_k * block_k1, block_k1) + + dv = jnp.zeros([block_k1, block_d], dtype=jnp.float32) + dk = jnp.zeros([block_k1, block_d], dtype=jnp.float32) + + v = pl.load(v_ref, (curr_k_slice, slice(None))) + k = pl.load(k_ref, (curr_k_slice, slice(None))) + span_k = start_k * block_k1 + jnp.arange(block_k1) + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (curr_k_slice,)) + ) + + def inner_loop_dkdv(start_q, carry): + dv, dk = carry + curr_q_slice = pl.dslice(start_q * block_q1, block_q1) + + q = pl.load(q_ref, (curr_q_slice, slice(None))) + qk = pl.dot(q, k.T) + if sm_scale != 1.0: + qk *= sm_scale + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + q_segment_ids = pl.load(segment_ids_ref, (curr_q_slice,)) + mask = segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_q = start_q * block_q1 + jnp.arange(block_q1) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + + lse = pl.load(lse_ref, (curr_q_slice,)) + di = pl.load(delta_ref, (curr_q_slice,)) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + + p = jnp.exp(qk - lse[:, None]) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + dp = jnp.zeros((block_q1, block_k1), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + + return dv, dk + + lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0 + dv, dk = lax.fori_loop( + lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk) + ) + dv_ref[...] = dv.astype(dv_ref.dtype) + dk_ref[...] = dk.astype(dk_ref.dtype) + + del dv, dk + + # Scan #2: dQ + # 1. Load a block of Q of size (block_q2, head_dim) in SMEM. + # 2. Iterate through K and V in chunks of (block_k2, head_dim) to + # accumulate dQ. 
+ start_q = pl.program_id(2) + curr_q_slice = pl.ds(start_q * block_q2, block_q2) + span_q = start_q * block_q2 + jnp.arange(block_q2) + dq = jnp.zeros([block_q2, block_d], dtype=jnp.float32) + + q = pl.load(q_ref, (curr_q_slice, slice(None))) + q_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (curr_q_slice,)) + ) + lse = pl.load(lse_ref, (curr_q_slice,)) + do = pl.load(do_scaled_ref, (curr_q_slice, slice(None))) + di = pl.load(delta_ref, (curr_q_slice,)) + + def inner_loop_dq(start_k, dq): + curr_k_slice = pl.dslice(start_k * block_k2, block_k2) + k = pl.load(k_ref, (curr_k_slice, slice(None))) + v = pl.load(v_ref, (curr_k_slice, slice(None))) + + qk = pl.dot(q, k.T) + if sm_scale != 1.0: + qk *= sm_scale + + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + kv_segment_ids = pl.load(segment_ids_ref, (curr_k_slice,)) + mask = segment_mask(q_segment_ids, kv_segment_ids) + + if causal: + span_k = start_k * block_k2 + jnp.arange(block_k2) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + + p = jnp.exp(qk - lse[:, None]) + dp = jnp.zeros((block_q2, block_k2), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if sm_scale != 1.0: + ds = ds * sm_scale + + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + + return dq + + if causal: + upper_bound = lax.div((start_q + 1) * block_q2, block_k2) + else: + upper_bound = pl.cdiv(seq_len, block_k2) + + dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) + dq_ref[...] = dq.astype(dq_ref.dtype) def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, @@ -458,7 +497,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid - q, k, v, segment_ids, out, l, m = res + q, k, v, segment_ids, out, lse = res if backward_pass_impl == "xla": return jax.vjp( @@ -472,76 +511,72 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, batch_size, seq_len, num_heads, head_dim = q.shape block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) - do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) - # We accumulate into dq so we need to initialize it to zeros. 
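The causal loop bounds (lower_bound in the dK/dV scan, upper_bound in the dQ scan) skip only blocks that the causal mask would zero out entirely. A quick check of both bounds, assuming block_q == block_k as in the caller below:

def overlaps_causal(q_idx, k_idx, block_q, block_k):
  # True if some (row, col) pair in this (q block, k block) tile survives row >= col.
  return (q_idx + 1) * block_q - 1 >= k_idx * block_k

block, num_blocks = 4, 8
for start_k in range(num_blocks):      # dK/dV scan over q blocks
  lower_bound = (start_k * block) // block
  assert not any(overlaps_causal(q, start_k, block, block) for q in range(lower_bound))
for start_q in range(num_blocks):      # dQ scan over k blocks
  upper_bound = ((start_q + 1) * block) // block
  assert not any(
      overlaps_causal(start_q, k, block, block) for k in range(upper_bound, num_blocks))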
- dq = jnp.zeros(q.shape, jnp.float32) + delta = _preprocess_backward(out, do, lse, block_q, debug, interpret) out_shapes = [ - jax.ShapeDtypeStruct(dq.shape, dq.dtype), - jax.ShapeDtypeStruct(k.shape, k.dtype), - jax.ShapeDtypeStruct(v.shape, v.dtype), + jax.ShapeDtypeStruct(q.shape, q.dtype), + jax.ShapeDtypeStruct(k.shape, k.dtype), + jax.ShapeDtypeStruct(v.shape, v.dtype), ] in_specs = [ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) - ), - pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), - pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), - pl.BlockSpec((None, None, seq_len), lambda j, k: (j, k, 0)), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0) ), + pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), + pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)), ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] - input_output_aliases = {8: 0} else: - in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda j, k: (j, 0))) - input_output_aliases = {9: 0} - grid = (batch_size, num_heads) - # TODO(sharadmv): figure out why num_warps=8 doesn't work! 
+ in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0))) + + grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k)) num_warps = 8 dq, dk, dv = pl.pallas_call( functools.partial( mha_backward_kernel, - block_q=block_q, - block_d=head_dim, - block_k=block_k, sm_scale=sm_scale, causal=causal, + block_q1=block_q, + block_k1=block_k, + block_q2=block_q, + block_k2=block_k, + block_d=head_dim, ), - grid=grid, out_shape=out_shapes, in_specs=in_specs, + grid=grid, out_specs=[ pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, block_q, None, head_dim), + lambda i, j, k: (i, k, j, 0), # dq ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, block_k, None, head_dim), + lambda i, j, k: (i, k, j, 0), # dk ), pl.BlockSpec( - (None, seq_len, None, head_dim), lambda j, k: (j, 0, k, 0) + (None, block_k, None, head_dim), + lambda i, j, k: (i, k, j, 0), # dv ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), - input_output_aliases=input_output_aliases, - )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq) + compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=2)), + )(q, k, v, segment_ids, out, do, lse, delta) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv, None diff --git a/jax/experimental/pallas/ops/gpu/decode_attention.py b/jax/experimental/pallas/ops/gpu/decode_attention.py index 9be724a1f42c..dde80d4603cc 100644 --- a/jax/experimental/pallas/ops/gpu/decode_attention.py +++ b/jax/experimental/pallas/ops/gpu/decode_attention.py @@ -21,6 +21,7 @@ import jax from jax import lax from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu import jax.numpy as jnp @@ -153,8 +154,8 @@ def attn_unbatched( pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # l pl.BlockSpec((None, block_h), lambda i, j: (j, i)), # m ], - compiler_params=dict( - triton=dict(num_warps=num_warps_, num_stages=num_stages) + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps_, num_stages=num_stages ), out_shape=[ jax.ShapeDtypeStruct(shape=(k_splits, *q.shape), dtype=q.dtype), # o diff --git a/jax/experimental/pallas/ops/gpu/layer_norm.py b/jax/experimental/pallas/ops/gpu/layer_norm.py index 0c39a9bf6e0d..e531395079ba 100644 --- a/jax/experimental/pallas/ops/gpu/layer_norm.py +++ b/jax/experimental/pallas/ops/gpu/layer_norm.py @@ -24,6 +24,7 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu def layer_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays @@ -282,9 +283,8 @@ def layer_norm( out_shape = jax.ShapeDtypeStruct(shape=(n,), dtype=x.dtype) method = pl.pallas_call( kernel, - compiler_params=dict( - triton=dict(num_warps=num_warps, num_stages=num_stages) - ), + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps, num_stages=num_stages), grid=(), out_shape=out_shape, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/rms_norm.py b/jax/experimental/pallas/ops/gpu/rms_norm.py index e1dfa3c5b9b7..3e373b895b8d 100644 --- a/jax/experimental/pallas/ops/gpu/rms_norm.py +++ b/jax/experimental/pallas/ops/gpu/rms_norm.py @@ -26,6 +26,7 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu def 
rms_norm_forward_kernel( x_ref, weight_ref, bias_ref, # Input arrays @@ -83,7 +84,7 @@ def rms_norm_forward( ] method = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps), grid=(), out_shape=out_shape, debug=False, diff --git a/jax/experimental/pallas/ops/gpu/softmax.py b/jax/experimental/pallas/ops/gpu/softmax.py index 3671331b8df8..33b416d165d7 100644 --- a/jax/experimental/pallas/ops/gpu/softmax.py +++ b/jax/experimental/pallas/ops/gpu/softmax.py @@ -18,6 +18,7 @@ import jax import jax.numpy as jnp from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu def _vmappable_softmax_kernel( @@ -79,7 +80,8 @@ def softmax( kernel = functools.partial(_vmappable_softmax_kernel, block_row=block_row) f = pl.pallas_call( kernel, - compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), + compiler_params=plgpu.TritonCompilerParams( + num_warps=num_warps, num_stages=1), grid=(), out_shape=out_shape, debug=debug, diff --git a/jax/experimental/pallas/ops/tpu/all_gather.py b/jax/experimental/pallas/ops/tpu/all_gather.py index e121db894122..8fb975504e26 100644 --- a/jax/experimental/pallas/ops/tpu/all_gather.py +++ b/jax/experimental/pallas/ops/tpu/all_gather.py @@ -136,7 +136,7 @@ def ag_local(x_shard): out = pl.pallas_call( functools.partial(ag_kernel, axis_name=axis_name, mesh=mesh), out_shape=out_shape, - compiler_params=dict(mosaic=dict(collective_id=0)), + compiler_params=pltpu.TPUCompilerParams(collective_id=0), grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, scratch_shapes=( diff --git a/jax/experimental/pallas/ops/tpu/flash_attention.py b/jax/experimental/pallas/ops/tpu/flash_attention.py index f0332a87b508..82bcde8153ef 100644 --- a/jax/experimental/pallas/ops/tpu/flash_attention.py +++ b/jax/experimental/pallas/ops/tpu/flash_attention.py @@ -17,6 +17,7 @@ import dataclasses import functools +import math from typing import Any, NamedTuple import jax @@ -565,6 +566,40 @@ def _flash_attention_kernel_single_batch_single_step( ).astype(o_tile_ref.dtype) +def _bytes(x: jax.Array | jax.ShapeDtypeStruct) -> int: + return math.prod(x.shape) * x.dtype.itemsize + + +def _fwd_cost_estimate( + q: jax.Array, + k: jax.Array, + v: jax.Array, + ab: jax.Array | None, + segment_ids: SegmentIds | None, + *, + causal: bool, + sm_scale: jax.Array | None, + kernel_inputs_specs, + kernel_outputs_specs, +) -> pl.CostEstimate | None: + full_cost = ( + mha_reference.lower( + q, k, v, ab, segment_ids, causal=causal, sm_scale=sm_scale + ) + .compile() + .cost_analysis() + ) + if not full_cost: + return None + input_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_inputs_specs)) + output_bytes = sum(_bytes(x) for x in jax.tree.leaves(kernel_outputs_specs)) + return pl.CostEstimate( + flops=full_cost[0]["flops"], + transcendentals=full_cost[0]["transcendentals"], + bytes_accessed=input_bytes + output_bytes, + ) + + def _flash_attention_impl( q, k, @@ -745,16 +780,25 @@ def kv_segment_ids_index_map( ), out_shape=out_shape, debug=debug, - compiler_params=dict( - mosaic=dict( - dimension_semantics=( - "parallel", - "parallel", - "parallel", - "arbitrary", - ) + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=( + "parallel", + "parallel", + "parallel", + "arbitrary", ) ), + cost_estimate=_fwd_cost_estimate( + q, + k, + v, + ab, + segment_ids, + causal=causal, + sm_scale=sm_scale, + kernel_inputs_specs=(q, k, v, ab, q_segment_ids, kv_segment_ids), + 
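Several call sites in this patch migrate from dict-based compiler_params to the typed dataclasses now exported from the Pallas GPU and TPU namespaces; the shape of that change, in brief:

from jax.experimental.pallas import gpu as plgpu
from jax.experimental.pallas import tpu as pltpu

# Before: compiler_params=dict(triton=dict(num_warps=4, num_stages=3))
triton_params = plgpu.TritonCompilerParams(num_warps=4, num_stages=3)

# Before: compiler_params=dict(mosaic=dict(dimension_semantics=("parallel", "arbitrary")))
tpu_params = pltpu.TPUCompilerParams(dimension_semantics=("parallel", "arbitrary"))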
kernel_outputs_specs=out_shape, + ), )(q, k, v, ab, q_segment_ids, kv_segment_ids) if save_residuals: l, m = (v[..., 0] for v in aux[-2:]) @@ -1105,15 +1149,13 @@ def dkv_index_map(batch_index, head_index, kv_seq_index, _): ), out_shape=out_shapes, debug=debug, - compiler_params=dict( - mosaic=dict( + compiler_params=pltpu.TPUCompilerParams( dimension_semantics=( "parallel", "parallel", "parallel", "arbitrary", ) - ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) assert dk.shape == k.shape @@ -1450,15 +1492,13 @@ def kv_segment_ids_index_map( ), out_shape=out_shapes, debug=debug, - compiler_params=dict( - mosaic=dict( + compiler_params=pltpu.TPUCompilerParams( dimension_semantics=( "parallel", "parallel", "parallel", "arbitrary", ) - ) ), )(q, k, v, ab, q_segment_ids, kv_segment_ids, l, m, do, di) diff --git a/jax/experimental/pallas/ops/tpu/matmul.py b/jax/experimental/pallas/ops/tpu/matmul.py new file mode 100644 index 000000000000..4ff82acbb5dd --- /dev/null +++ b/jax/experimental/pallas/ops/tpu/matmul.py @@ -0,0 +1,84 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Example matmul TPU kernel. + +See discussion in https://jax.readthedocs.io/en/latest/pallas/tpu/matmul.html. +""" + +import functools + +import jax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp + + +def matmul_kernel(x_tile_ref, y_tile_ref, o_tile_ref, acc_ref): + @pl.when(pl.program_id(2) == 0) + def init(): + acc_ref[...] = jnp.zeros_like(acc_ref) + + acc_ref[...] = acc_ref[...] + jnp.dot( + x_tile_ref[...], + y_tile_ref[...], + preferred_element_type=acc_ref.dtype, + ) + # It is possible to make this conditional but in general this bundle packs + # quite well for a simple matmul kernel + o_tile_ref[...] 
= acc_ref[...].astype(o_tile_ref.dtype) + + +@functools.partial( + jax.jit, static_argnames=["block_shape", "block_k", "debug", "out_dtype"] +) +def matmul( + x: jax.Array, + y: jax.Array, + *, + block_shape, + block_k: int = 256, + out_dtype: jnp.dtype | None = None, + debug: bool = False, +) -> jax.Array: + if out_dtype is None: + if x.dtype != y.dtype: + # TODO(tlongeri): Maybe we could use a deduction similar to jnp.dot + raise TypeError( + f"Cannot deduce output dtype for different input dtypes: {x.dtype}," + f" {y.dtype}" + ) + out_dtype = x.dtype + acc_dtype = jnp.float32 + if x.dtype in [jnp.int8, jnp.int4, jnp.uint8, jnp.uint4]: + acc_dtype = jnp.int32 + + l, r = block_shape + return pl.pallas_call( + matmul_kernel, + out_shape=jax.ShapeDtypeStruct((x.shape[0], y.shape[1]), out_dtype), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec((l, block_k), lambda i, _, k: (i, k)), + pl.BlockSpec((block_k, r), lambda _, j, k: (k, j)), + ], + out_specs=pl.BlockSpec((l, r), lambda i, j, k: (i, j)), + grid=(x.shape[0] // l, y.shape[1] // r, x.shape[1] // block_k), + scratch_shapes=[pltpu.VMEM((l, r), acc_dtype)], + ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "parallel", "arbitrary")), + debug=debug, + )(x, y) diff --git a/jax/experimental/pallas/ops/tpu/megablox/gmm.py b/jax/experimental/pallas/ops/tpu/megablox/gmm.py index 320851422abf..5c2f938597e7 100644 --- a/jax/experimental/pallas/ops/tpu/megablox/gmm.py +++ b/jax/experimental/pallas/ops/tpu/megablox/gmm.py @@ -538,11 +538,8 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=dict( - mosaic=dict( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - ) - ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, cost_estimate=cost_estimate, ) @@ -780,13 +777,10 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], ), input_output_aliases=input_output_aliases, - compiler_params=dict( - mosaic=dict( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - cost_estimate=cost_estimate, - ) - ), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary")), interpret=interpret, + cost_estimate=cost_estimate, ) out = call_gmm( diff --git a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py index 82fa5f7427bd..eb1e11df17da 100644 --- a/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/paged_attention/paged_attention_kernel.py @@ -14,7 +14,9 @@ """PagedAttention TPU kernel.""" +from collections.abc import Sequence import functools +from typing import Literal import jax from jax import lax @@ -516,6 +518,7 @@ def paged_attention( ) q_dtype_for_kernel_launch = q.dtype + dimension_semantics: Sequence[Literal["parallel", "arbitrary"]] if inline_seq_dim: kernel = paged_flash_attention_kernel_inline_seq_dim grid = ( @@ -525,7 +528,7 @@ def paged_attention( if megacore_mode == "kv_head" else num_kv_heads, ) - dimension_sematics = ("parallel", "arbitrary", "arbitrary") + dimension_semantics = ("parallel", "arbitrary", "arbitrary") else: kernel = 
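For reference, the new pallas/ops/tpu/matmul.py example above can be invoked as follows (a sketch only: the shapes must divide the block sizes evenly, and running it requires a TPU backend):

import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.matmul import matmul

x = jnp.ones((1024, 512), jnp.bfloat16)
y = jnp.ones((512, 1024), jnp.bfloat16)
out = matmul(x, y, block_shape=(256, 256), block_k=256)  # grid (4, 4, 2)
assert out.shape == (1024, 1024)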
paged_flash_attention_kernel grid = ( @@ -536,7 +539,7 @@ def paged_attention( else num_kv_heads, pages_per_sequence // pages_per_compute_block, ) # type: ignore - dimension_sematics = ("parallel", "arbitrary", "arbitrary", "arbitrary") # type: ignore + dimension_semantics = ("parallel", "arbitrary", "arbitrary", "arbitrary") if k_scales_pages is not None and v_scales_pages is not None: in_specs = [ @@ -640,7 +643,8 @@ def paged_attention( grid=grid, scratch_shapes=scratch_shapes, ), - compiler_params=dict(mosaic=dict(dimension_semantics=dimension_sematics)), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=dimension_semantics), out_shape=[ jax.ShapeDtypeStruct(q.shape, q_dtype_for_kernel_launch), jax.ShapeDtypeStruct((*q.shape[:-1], 1), jnp.float32), diff --git a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py index 4ae761d78953..536c32e574b2 100644 --- a/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py +++ b/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py @@ -1071,11 +1071,6 @@ def logsumexp_index_map(h, i, *_): out_shapes += [None] out_specs += [None] - mosaic_params = dict( - dimension_semantics=("parallel", "arbitrary", "arbitrary"), - flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, - ) - kernel_name = get_kernel_name( dataclasses.asdict(block_sizes), is_mqa=is_mqa, @@ -1112,7 +1107,9 @@ def logsumexp_index_map(h, i, *_): out_specs=out_specs, grid=grid, ), - compiler_params=dict(mosaic=mosaic_params), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("parallel", "arbitrary", "arbitrary"), + ), out_shape=out_shapes, name=kernel_name, interpret=interpret, @@ -1545,11 +1542,6 @@ def logsumexp_index_map(h, i, *_): ) num_scalar_prefetch = 3 - mosaic_params = dict( - dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), - flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, - ) - kernel_name = get_kernel_name( dict( block_q_dq=bq, @@ -1573,7 +1565,9 @@ def logsumexp_index_map(h, i, *_): grid=grid, ), out_shape=out_shapes, - compiler_params=dict(mosaic=mosaic_params), + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), + ), name=kernel_name, interpret=interpret, )( @@ -2088,16 +2082,6 @@ def logsumexp_index_map( ) num_scalar_prefetch = 3 - # We set all dimensions to arbitrary because: - # 1) for kv_seq_len, the splash attention prefetch schedule assumes no - # megacore - # 2) for heads, we are reducing over heads - # 3) for q_seq_len, we are reducing over it to compute dkv - mosaic_params = dict( - dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), - flags={"XLA_TPU_FORCE_LP_LLO_SCHEDULER": True}, - ) - kernel_name = get_kernel_name( dict( block_q_dkv=bq, @@ -2122,7 +2106,14 @@ def logsumexp_index_map( grid=grid, ), out_shape=out_shapes, - compiler_params=dict(mosaic=mosaic_params), + # We set all dimensions to arbitrary because: + # 1) for kv_seq_len, the splash attention prefetch schedule assumes no + # megacore + # 2) for heads, we are reducing over heads + # 3) for q_seq_len, we are reducing over it to compute dkv + compiler_params=pltpu.TPUCompilerParams( + dimension_semantics=("arbitrary", "arbitrary", "arbitrary"), + ), name=kernel_name, interpret=interpret, )( diff --git a/jax/experimental/pallas/tpu.py b/jax/experimental/pallas/tpu.py index 79d773379f9b..8a1a223ae36e 100644 --- a/jax/experimental/pallas/tpu.py +++ 
b/jax/experimental/pallas/tpu.py @@ -21,6 +21,7 @@ from jax._src.pallas.mosaic.core import semaphore from jax._src.pallas.mosaic.core import SemaphoreType from jax._src.pallas.mosaic.core import TPUMemorySpace +from jax._src.pallas.mosaic.core import TPUCompilerParams from jax._src.pallas.mosaic.lowering import LoweringException from jax._src.pallas.mosaic.pipeline import ARBITRARY from jax._src.pallas.mosaic.pipeline import BufferedRef @@ -67,3 +68,4 @@ CMEM = TPUMemorySpace.CMEM SMEM = TPUMemorySpace.SMEM VMEM = TPUMemorySpace.VMEM +SEMAPHORE = TPUMemorySpace.SEMAPHORE diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py index e16fa2814e9d..0dace1977dc0 100644 --- a/jax/experimental/shard_map.py +++ b/jax/experimental/shard_map.py @@ -26,7 +26,7 @@ import jax import jax.numpy as jnp -from jax.sharding import NamedSharding, PartitionSpec, Mesh +from jax.sharding import NamedSharding, PartitionSpec from jax._src import ad_checkpoint from jax._src import ad_util from jax._src import callback @@ -46,10 +46,13 @@ from jax._src import traceback_util from jax._src import util from jax._src.core import Tracer +from jax._src.mesh import AbstractMesh, Mesh from jax._src.api import _shared_code_pmap, _prepare_pmap from jax._src.lax import (lax, parallel as lax_parallel, slicing, windowed_reductions, convolution, fft, linalg, special, control_flow, ann) +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo, sdy from jax._src.util import (HashableFunction, HashablePartial, unzip2, unzip3, as_hashable_function, memoize, partition_list, merge_lists, split_list, subs_list2) @@ -79,8 +82,9 @@ @traceback_util.api_boundary -def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, - check_rep: bool = True, auto: frozenset[AxisName] = frozenset()): +def shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, + out_specs: Specs, check_rep: bool = True, + auto: frozenset[AxisName] = frozenset()): """Map a function over shards of data. 
Note: @@ -134,14 +138,15 @@ def shard_map(f: Callable, mesh: Mesh, in_specs: Specs, out_specs: Specs, """ return _shard_map(f, mesh, in_specs, out_specs, check_rep, auto) -def _shard_map(f: Callable, mesh: Mesh, in_specs: Specs, +def _shard_map(f: Callable, mesh: Mesh | AbstractMesh, in_specs: Specs, out_specs: Specs | Callable[[], Specs], check_rep: bool, auto: frozenset[AxisName]): if not callable(f): raise TypeError("shard_map requires a callable for its first argument, " f"but got {f} of type {type(f)}.") - if not isinstance(mesh, Mesh): - raise TypeError("shard_map requires a `jax.sharding.Mesh` instance for its " + if not isinstance(mesh, (Mesh, AbstractMesh)): + raise TypeError("shard_map requires a `jax.sharding.Mesh` or a " + "`jax.sharding.AbstractMesh` instance for its " f"second argument, but got {mesh} of type {type(mesh)}.") if not auto.issubset(mesh.axis_names): raise ValueError(f"shard_map requires auto={auto} to be a subset of " @@ -163,7 +168,7 @@ def wrapped(*args): raise e('shard_map in_specs') from None dyn_argnums, in_specs_flat = unzip2((i, s) for i, s in enumerate(in_specs_flat) if s is not None) - fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat) + fun, args_flat = argnums_partial(fun, dyn_argnums, args_flat, False) _check_specs_vs_args(f, mesh, in_tree, in_specs, dyn_argnums, in_specs_flat, args_flat) in_names_flat = tuple(map(_canonicalize_spec, in_specs_flat)) @@ -636,13 +641,75 @@ def _rule_missing(prim: core.Primitive, *_, **__): raise NotImplementedError( f"No replication rule for {prim}. As a workaround, pass the " "`check_rep=False` argument to `shard_map`. To get this fixed, open an " - "issue at https://github.com/google/jax/issues") + "issue at https://github.com/jax-ml/jax/issues") # Lowering +def _shardy_shard_map_sharding( + ctx: mlir.LoweringRuleContext, mesh, names, aval_in + ) -> ir.Attribute: + axes = {name: i for i, ns in names.items() for name in ns} + ns = _make_scoped_manual_sharding(ctx, mesh, axes) + if dtypes.issubdtype(aval_in.dtype, dtypes.extended): + ns = sharding_impls.physical_sharding(aval_in, ns) + aval_in = core.physical_aval(aval_in) + return ns._to_sdy_sharding(aval_in.ndim).build() + + +def _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto): + in_avals_ = [v.aval for v in jaxpr.invars] + if isinstance(ctx.module_context.axis_context, sharding_impls.SPMDAxisContext): + # Nested `ManualComputationOp`s cannot refer to axes that are already + # manual. So figure out what axes are free thus far and get the new axis + # context. + free_axis = frozenset(mesh.axis_names) - ctx.module_context.axis_context.manual_axes + new_axis_context = sharding_impls.SPMDAxisContext(mesh, free_axis - auto) + else: + new_axis_context = sharding_impls.SPMDAxisContext( + mesh, frozenset(mesh.axis_names) - auto) + sub_ctx = ctx.module_context.replace(axis_context=new_axis_context) + args = (*ctx.dim_var_values, *in_nodes) + + manual_axes = sub_ctx.axis_context.manual_axes + mesh_shape = mesh.shape + manual_axes_size = np.prod([mesh_shape[a] for a in manual_axes]) + if manual_axes_size == 1: + # No need for a `ManualComputationOp` if all manual axes are size 1. 
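The early exit above rests on a simple observation: if every manual mesh axis has size 1 there is nothing to partition, so the jaxpr can be inlined instead of being wrapped in a `ManualComputationOp`. A self-contained sketch of that size check, with illustrative mesh shapes and a hypothetical helper name (not part of the patch):

    import numpy as np

    def needs_manual_computation(mesh_shape: dict[str, int],
                                 manual_axes: frozenset[str]) -> bool:
      # Product of the manual axis sizes; the empty product counts as 1.
      return np.prod([mesh_shape[a] for a in manual_axes]) > 1

    assert not needs_manual_computation({"x": 1, "y": 4}, frozenset({"x"}))
    assert needs_manual_computation({"x": 2, "y": 4}, frozenset({"x", "y"}))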
+ out_nodes, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *args, + dim_var_values=ctx.dim_var_values) + return out_nodes + + in_shardings = sdy.TensorShardingPerValueAttr.get(map( + partial(_shardy_shard_map_sharding, ctx, mesh), + in_names, ctx.avals_in)) + out_shardings = sdy.TensorShardingPerValueAttr.get(map( + partial(_shardy_shard_map_sharding, ctx, mesh), + out_names, ctx.avals_out)) + output_types = map(mlir.aval_to_ir_type, ctx.avals_out) + manual_computation_op = sdy.ManualComputationOp( + output_types, args, in_shardings, out_shardings, + sdy.ManualAxesAttr.get( + ir.ArrayAttr.get([ir.StringAttr.get(i) for i in manual_axes]))) + block = ir.Block.create_at_start( + manual_computation_op.body, map(mlir.aval_to_ir_type, in_avals_)) + with ir.InsertionPoint(block), core.extend_axis_env_nd( + tuple(mesh.shape.items())): + out_nodes_, _ = mlir.jaxpr_subcomp( + sub_ctx, jaxpr, ctx.name_stack, mlir.TokenSet(), (), *block.arguments, + dim_var_values=ctx.dim_var_values) + sdy.ReturnOp([ir.Value(x) for x in out_nodes_]) + + return manual_computation_op.results + + def _shard_map_lowering(ctx, *in_nodes, jaxpr, mesh, in_names, out_names, check_rep, rewrite, auto): del check_rep, rewrite + if config.use_shardy_partitioner.value: + return _shard_map_lowering_shardy( + ctx, in_nodes, jaxpr, mesh, in_names, out_names, auto) in_avals_ = [v.aval for v in jaxpr.invars] out_avals_ = [x.aval for x in jaxpr.outvars] in_nodes_ = map(partial(_xla_shard, ctx, mesh, auto), in_names, ctx.avals_in, @@ -711,10 +778,29 @@ def _pspec_mhlo_attrs(names: AxisNames, aval: core.AbstractValue) -> str: # Eager evaluation +def get_mesh_from_args(args_flat, mesh): + for a in args_flat: + if hasattr(a, 'sharding') and isinstance(a.sharding, NamedSharding): + if a.sharding.mesh.shape_tuple != mesh.shape_tuple: + aval = shaped_abstractify(a) + raise ValueError( + f"Mesh shape of the input {a.sharding.mesh.shape_tuple} does not" + " match the mesh shape passed to shard_map " + f" {mesh.shape_tuple} for shape {aval.str_short()}") + mesh = a.sharding.mesh + if isinstance(mesh, AbstractMesh): + raise ValueError( + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument.") + assert isinstance(mesh, Mesh) + return mesh + def _shard_map_impl(trace, prim, fun, args, *, mesh, in_names, out_names_thunk, check_rep, rewrite, auto): if auto: raise NotImplementedError del prim, auto + if isinstance(mesh, AbstractMesh): + mesh = get_mesh_from_args(args, mesh) args = map(partial(_unmatch_spec, mesh), in_names, args) in_rep = map(partial(_in_names_to_rep, mesh), in_names) with core.new_base_main(ShardMapTrace, mesh=mesh, check=check_rep) as main: @@ -823,20 +909,20 @@ def process_call(self, call_primitive, fun, tracers, params): f"Eager evaluation of `{call_primitive}` inside a `shard_map` isn't " "yet supported. Put a `jax.jit` around the `shard_map`-decorated " "function, and open a feature request at " - "https://github.com/google/jax/issues !") + "https://github.com/jax-ml/jax/issues !") def process_map(self, map_primitive, fun, tracers, params): raise NotImplementedError( "Eager evaluation of `pmap` inside a `shard_map` isn't yet supported." 
"Put a `jax.jit` around the `shard_map`-decorated function, and open " - "a feature request at https://github.com/google/jax/issues !") + "a feature request at https://github.com/jax-ml/jax/issues !") def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): # Since ShardMapTrace is only used as a base main, we can drop the jvp. if symbolic_zeros: msg = ("custom_jvp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, jvp, symbolic_zeros in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) @@ -854,7 +940,7 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, if symbolic_zeros: msg = ("custom_vjp symbolic_zeros support with shard_map is not " "implemented; please open an issue at " - "https://github.com/google/jax/issues") + "https://github.com/jax-ml/jax/issues") raise NotImplementedError(msg) del prim, fwd, bwd, out_trees, symbolic_zeros in_vals, in_rep = unzip2((t.val, t.rep) for t in tracers) @@ -1020,7 +1106,7 @@ def _standard_check(prim, mesh, *in_rep, **__): if in_rep_ and not in_rep_[:-1] == in_rep_[1:]: raise Exception(f"Primitive {prim} requires argument replication types " f"to match, but got {in_rep}. Please open an issue at " - "https://github.com/google/jax/issues and as a temporary " + "https://github.com/jax-ml/jax/issues and as a temporary " "workaround pass the check_rep=False argument to shard_map") return in_rep_[0] if in_rep_ else None @@ -1028,6 +1114,10 @@ def register_standard_collective(prim): register_check(prim)(partial(_standard_collective_check, prim)) register_rewrite(prim)(partial(_standard_collective_rewrite, prim)) +def register_reduction_collective(prim): + register_check(prim)(partial(_reduction_collective_check, prim)) + register_rewrite(prim)(partial(_reduction_collective_rewrite, prim)) + def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): # The standard collective check is varying -> varying over axis_name. del mesh, params @@ -1035,7 +1125,7 @@ def _standard_collective_check(prim, mesh, x_rep, *, axis_name, **params): raise Exception(f"Collective {prim} must be applied to a device-varying " f"replication type, but got {x_rep} for collective acting " f"over axis name {axis_name}. Please open an issue at " - "https://github.com/google/jax/issues and as a temporary " + "https://github.com/jax-ml/jax/issues and as a temporary " "workaround pass the check_rep=False argument to shard_map") return x_rep @@ -1049,6 +1139,28 @@ def _standard_collective_rewrite(prim, mesh, in_rep, x, axis_name, **params): out_val = prim.bind(x, axis_name=axis_name, **params) return [out_val], [x_rep - axis_name_set] +def _reduction_collective_check(prim, mesh, x_rep, *, axes, **params): + # The reduction collective check is varying -> replicated over axes. + del mesh, params + axes = (axes,) if not isinstance(axes, tuple) else axes + if x_rep is None or any(a in x_rep for a in axes): + raise Exception(f"Collective {prim} must be applied to a device-varying " + f"replication type, but got {x_rep} for collective acting " + f"over axis name {axes}. 
Please open an issue at " + "https://github.com/jax-ml/jax/issues and as a temporary " + "workaround pass the check_rep=False argument to shard_map") + return x_rep | set(axes) + +def _reduction_collective_rewrite(prim, mesh, in_rep, x, axes, **params): + # The standard collective rewrite may insert a pbroadcast on the input. + axes = (axes,) if not isinstance(axes, tuple) else axes + x_rep, = in_rep + axes_set = set(axes) + if pbroadcast_axes := axes_set & x_rep: + x = pbroadcast(x, tuple(pbroadcast_axes)) + out_val, = prim.bind(x, axes=axes, **params) + return [out_val], [x_rep | axes_set] + for o in it.chain(lax.__dict__.values(), slicing.__dict__.values(), windowed_reductions.__dict__.values(), @@ -1092,7 +1204,7 @@ def _psum2_check(mesh, *in_rep, axes, axis_index_groups): raise Exception("Collective psum must be applied to a device-varying " f"replication type, but got {in_rep} for collective acting " f"over axis name {axes}. Please open an issue at " - "https://github.com/google/jax/issues, and as a temporary " + "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r | set(axes) for r in in_rep] @@ -1107,7 +1219,7 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): "non-device-varying " f"replication type, but got {in_rep} for collective acting " f"over axis name {axes}. Please open an issue at " - "https://github.com/google/jax/issues, and as a temporary " + "https://github.com/jax-ml/jax/issues, and as a temporary " "workaround pass the check_rep=False argument to shard_map") in_rep = tuple(set(mesh.axis_names) if r is None else r for r in in_rep) return [r - set(axes) for r in in_rep] @@ -1118,6 +1230,8 @@ def _pbroadcast_check(mesh, *in_rep, axes, axis_index_groups): register_standard_collective(lax_parallel.all_to_all_p) register_standard_collective(lax_parallel.ppermute_p) register_standard_collective(lax_parallel.reduce_scatter_p) +register_reduction_collective(lax_parallel.pmin_p) +register_reduction_collective(lax_parallel.pmax_p) @register_check(lax_parallel.axis_index_p) @@ -1194,7 +1308,7 @@ def _scan_check(mesh, *in_rep, jaxpr, num_consts, num_carry, **_): if not carry_rep_in == carry_rep_out: raise Exception("Scan carry input and output got mismatched replication " f"types {carry_rep_in} and {carry_rep_out}. Please open an " - "issue at https://github.com/google/jax/issues, and as a " + "issue at https://github.com/jax-ml/jax/issues, and as a " "temporary workaround pass the check_rep=False argument to " "shard_map") return out_rep @@ -1245,7 +1359,7 @@ def _custom_vjp_call_jaxpr_rewrite( mesh, in_rep, *args, fun_jaxpr, fwd_jaxpr_thunk, bwd, num_consts, out_trees, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and as" + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and as" " a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1281,7 +1395,7 @@ def _linear_solve_check(mesh, *in_rep, const_lengths, jaxprs): assert in_rep if not in_rep_[:-1] == in_rep_[1:]: msg = ("shard_map check_rep rewrite failed. 
Please open an issue at " - "https://github.com/google/jax/issues and as a workaround pass the " + "https://github.com/jax-ml/jax/issues and as a workaround pass the " "check_rep=False argument to shard_map") raise Exception(msg) return [in_rep_[0]] * len(jaxprs.solve.out_avals) @@ -1383,7 +1497,7 @@ def new_out_names_thunk(): f_jvp, out_tree = ad.traceable(f_jvp, in_tree) result = shard_map_p.bind(f_jvp, *args, **params) primal_out, tangent_out = tree_unflatten(out_tree(), result) - tangent_out = [ad.Zero(core.get_aval(p).at_least_vspace()) if t is None else t + tangent_out = [ad.Zero(core.get_aval(p).to_tangent_aval()) if t is None else t for p, t in zip(primal_out, tangent_out)] return [ad.JVPTracer(trace, p, t) for p, t in zip(primal_out, tangent_out)] ad.JVPTrace.process_shard_map = _shard_map_jvp @@ -1856,7 +1970,7 @@ def post_process_call(self, call_primitive, out_tracers, params): def process_custom_jvp_call(self, prim, fun, jvp, tracers, *, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and " + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1877,7 +1991,7 @@ def post_process_custom_jvp_call(self, out_tracers, jvp_was_run): def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees, symbolic_zeros): if symbolic_zeros: - msg = ("Please open an issue at https://github.com/google/jax/issues and " + msg = ("Please open an issue at https://github.com/jax-ml/jax/issues and " "as a temporary workaround pass the check_rep=False argument to " "shard_map") raise NotImplementedError(msg) @@ -1906,18 +2020,25 @@ def _efficient_transpose_rewrite(fun, mesh, in_names, out_names_thunk): fun, out_reps_src = _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps) return _match_rep(fun, mesh, out_reps_src, out_reps_dst) +def _efficient_transpose_rewrite_nomatch(fun, mesh, in_reps): + return _efficient_transpose_outer(_efficient_transpose_inner(fun), mesh, in_reps) + @lu.transformation_with_aux -def _efficient_transpose_rewrite_nomatch(mesh, in_reps, *args): +def _efficient_transpose_outer(mesh, in_reps, *args): lvl = core.dynamic_level() with core.new_main(RewriteTrace, dynamic=True, mesh=mesh, dyna=lvl) as main: - t = main.with_cur_sublevel() - in_tracers = map(partial(RewriteTracer, t), in_reps, args) - ans = yield in_tracers, {} - out_tracers = map(t.full_raise, ans) - out_vals, out_reps = unzip2((t.val, t.rep) for t in out_tracers) - del main, t, in_tracers, out_tracers, ans + out_vals, out_reps = yield (main, mesh, in_reps, args), {} + del main yield out_vals, out_reps +@lu.transformation +def _efficient_transpose_inner(main, mesh, in_reps, args): + t = main.with_cur_sublevel() + in_tracers = map(partial(RewriteTracer, t), in_reps, args) + ans = yield in_tracers, {} + out_tracers = map(t.full_raise, ans) + yield unzip2((t.val, t.rep) for t in out_tracers) + @lu.transformation def _match_rep(mesh, out_reps_src_, out_reps_dst_, *args): outs = yield args, {} @@ -1982,3 +2103,13 @@ def _match_replication(src, dst, x): if src - dst: x = pbroadcast(x, tuple(n for n in src if n not in dst)) return x + +# TODO(parkers,mattjj): change implementation when we have sharding-in-types. 
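With `pmin_p` and `pmax_p` registered above as reduction collectives, `lax.pmin`/`lax.pmax` inside `shard_map` now pass the `check_rep=True` replication check, and their results are treated as replicated over the reduced axis. A minimal usage sketch; the single-device mesh and shapes are illustrative, not part of the patch:

    import numpy as np
    import jax
    import jax.numpy as jnp
    from jax.experimental.shard_map import shard_map
    from jax.sharding import Mesh, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:1]), ("i",))
    # pmax reduces over "i", so the output may be declared replicated (P()).
    f = shard_map(lambda x: jax.lax.pmax(x, "i"), mesh,
                  in_specs=P("i"), out_specs=P())
    y = jax.jit(f)(jnp.arange(4.0))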
+def get_replication(x: jax.Array) -> set[AxisName]: + """For a jax.Array, return what axes it is known to be replicated along.""" + + if isinstance(x, RewriteTracer): + return x.rep + if isinstance(x, batching.BatchTracer): + return get_replication(x.val) + raise ValueError("get_replication not defined on %s" % repr(type(x))) diff --git a/jax/experimental/sparse/__init__.py b/jax/experimental/sparse/__init__.py index 8ab8cd88721d..f388cd527cf9 100644 --- a/jax/experimental/sparse/__init__.py +++ b/jax/experimental/sparse/__init__.py @@ -189,7 +189,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.experimental.sparse.ad import ( jacfwd as jacfwd, diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 4cbe52383751..b20ed8da0326 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -105,7 +105,7 @@ def _bcoo_set_nse(mat: BCOO, nse: int) -> BCOO: unique_indices=mat.unique_indices) # TODO(jakevdp) this can be problematic when used with autodiff; see -# https://github.com/google/jax/issues/10163. Should this be a primitive? +# https://github.com/jax-ml/jax/issues/10163. Should this be a primitive? # Alternatively, maybe roll this into bcoo_sum_duplicates as an optional argument. def bcoo_eliminate_zeros(mat: BCOO, nse: int | None = None) -> BCOO: data, indices, shape = mat.data, mat.indices, mat.shape @@ -332,11 +332,11 @@ def _bcoo_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype data, indices = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _bcoo_extract(indices, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices)) return primals_out, tangents_out @@ -571,7 +571,7 @@ def _bcoo_transpose_jvp(primals, tangents, *, permutation: Sequence[int], spinfo data_dot, _ = tangents primals_out = _bcoo_transpose(data, indices, permutation=permutation, spinfo=spinfo) data_dot_out, _ = _bcoo_transpose(data_dot, indices, permutation=permutation, spinfo=spinfo) - return primals_out, (data_dot_out, ad.Zero.from_value(indices)) + return primals_out, (data_dot_out, ad.Zero.from_primal_value(indices)) def _bcoo_transpose_transpose(ct, data, indices, *, permutation: Sequence[int], spinfo: SparseInfo): data_ct, indices_ct = ct @@ -609,7 +609,8 @@ def _bcoo_transpose_batch_rule(batched_args, batch_dims, *, permutation: Sequenc bcoo_dot_general_p = core.Primitive('bcoo_dot_general') def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: DotDimensionNumbers, - precision: None = None, preferred_element_type: None = None) -> BCOO | Array: + precision: None = None, preferred_element_type: None = None, + algorithm: None = None, transpose_algorithm: None = None) -> BCOO | Array: """A general contraction operation. Args: @@ -620,6 +621,8 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: (lhs_batch_dims, rhs_batch_dims))`. precision: unused preferred_element_type: unused + algorithm: unused + transpose_algorithm: unused Returns: An ndarray or BCOO-format sparse array containing the result. If both inputs @@ -627,7 +630,7 @@ def bcoo_dot_general(lhs: BCOO | Array, rhs: BCOO | Array, *, dimension_numbers: the result will be dense, of type ndarray. """ # TODO(jakevdp) make use of these? 
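The TODO above refers to the `precision`, `preferred_element_type`, `algorithm` and `transpose_algorithm` arguments, which the docstring marks as unused; only `dimension_numbers` affects the result. A small usage sketch of the public entry point, with illustrative shapes:

    import jax.numpy as jnp
    from jax.experimental import sparse

    M = sparse.BCOO.fromdense(jnp.eye(4))
    # Contract dim 1 of the sparse lhs with dim 0 of the dense rhs; no batch dims.
    y = sparse.bcoo_dot_general(M, jnp.ones((4, 2)),
                                dimension_numbers=(((1,), (0,)), ((), ())))
    assert y.shape == (4, 2)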
- del precision # unused + del precision, algorithm, transpose_algorithm # unused if isinstance(lhs, BCOO) and isinstance(rhs, BCOO): shape = _dot_general_validated_shape(lhs.shape, rhs.shape, dimension_numbers) @@ -738,12 +741,11 @@ def result(out_array, lhs_data, lhs_indices, rhs): @bcoo_dot_general_p.def_abstract_eval def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_numbers, preferred_element_type, lhs_spinfo: SparseInfo): - out_aval = jax.eval_shape( - partial(lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=preferred_element_type), - jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype), - jax.ShapeDtypeStruct(rhs.shape, rhs.dtype)) + out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape( + jax.ShapeDtypeStruct(lhs_spinfo.shape, lhs_data.dtype), + jax.ShapeDtypeStruct(rhs.shape, rhs.dtype), + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type) (lhs_contracting, _), (lhs_batch, _) = dimension_numbers n_batch, n_sparse, _, _ = _validate_bcoo(lhs_data, lhs_indices, lhs_spinfo.shape) @@ -1054,7 +1056,9 @@ def _bcoo_dot_general_sampled_transpose(ct, A, B, indices, *, dimension_numbers) indices, ct = _bcoo_extract_transpose(ct, indices, mat, assume_unique=True) kwds = {'dimension_numbers': dimension_numbers, 'precision': None, - 'preferred_element_type': None} + 'preferred_element_type': None, + 'algorithm': None, + 'transpose_algorithm': None} A, B = ad.get_primitive_transpose(lax.dot_general_p)(ct, A, B, **kwds) return A, B, indices @@ -1141,7 +1145,7 @@ def _bcoo_spdot_general_unbatched(lhs_data, lhs_indices, rhs_data, rhs_indices, out_indices = out_indices.at[:, :, lhs_j.shape[-1]:].set(rhs_j[None, :]) out_indices = out_indices.reshape(len(out_data), out_indices.shape[-1]) # Note: we do not eliminate zeros here, because it can cause issues with autodiff. - # See https://github.com/google/jax/issues/10163. + # See https://github.com/jax-ml/jax/issues/10163. 
return _bcoo_sum_duplicates(out_data, out_indices, spinfo=SparseInfo(shape=out_shape), nse=out_nse) @bcoo_spdot_general_p.def_impl @@ -1187,12 +1191,11 @@ def _bcoo_spdot_general_abstract_eval(lhs_data, lhs_indices, rhs_data, rhs_indic dimension_numbers, preferred_element_type): lhs_shape = lhs_spinfo.shape rhs_shape = rhs_spinfo.shape - out_aval = jax.eval_shape( - partial(lax.dot_general, - dimension_numbers=dimension_numbers, - preferred_element_type=preferred_element_type), - jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype), - jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype)) + out_aval = jax.jit(lax.dot_general, static_argnames=("dimension_numbers", "preferred_element_type")).eval_shape( + jax.ShapeDtypeStruct(lhs_shape, lhs_data.dtype), + jax.ShapeDtypeStruct(rhs_shape, rhs_data.dtype), + dimension_numbers=dimension_numbers, + preferred_element_type=preferred_element_type) lhs = _validate_bcoo(lhs_data, lhs_indices, lhs_shape) rhs = _validate_bcoo(rhs_data, rhs_indices, rhs_shape) @@ -1279,7 +1282,7 @@ def _bcoo_spdot_general_jvp(primals, tangents, **kwds): data_dot_out += _bcoo_spdot_general(lhs_data_dot, lhs_indices, rhs_data, rhs_indices, **kwds)[0] if type(rhs_data_dot) is not ad.Zero: data_dot_out += _bcoo_spdot_general(lhs_data, lhs_indices, rhs_data_dot, rhs_indices, **kwds)[0] - return primals_out, [data_dot_out, ad.Zero.from_value(primals_out[1])] + return primals_out, [data_dot_out, ad.Zero.from_primal_value(primals_out[1])] # TODO(JVP): transpose rule batching.primitive_batchers[bcoo_spdot_general_p] = _bcoo_spdot_general_batch_rule @@ -1360,8 +1363,8 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo): permute = nfold_vmap(lambda d, p: d[p], props.n_batch) data_out = permute(data, perm) - indices_dot_out = ad.Zero.from_value(indices) - data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) + indices_dot_out = ad.Zero.from_primal_value(indices) + data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) return (data_out, indices_out), (data_dot_out, indices_dot_out) _bcoo_sort_indices_hlo = mlir.lower_fun( @@ -1539,15 +1542,15 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): nse, *data.shape[props.n_batch + 1:]), dtype=data.dtype) data_dot_out = data_out # This check is because scatter-add on zero-sized arrays has poorly defined - # semantics; see https://github.com/google/jax/issues/13656. + # semantics; see https://github.com/jax-ml/jax/issues/13656. if data_out.size: permute = lambda x, i, y: x.at[i].add(y, mode='drop') else: permute = lambda x, i, y: x permute = nfold_vmap(permute, props.n_batch) data_out = permute(data_out, mapping, data) - indices_dot_out = ad.Zero.from_value(indices_out) - data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) + indices_dot_out = ad.Zero.from_primal_value(indices_out) + data_dot_out = ad.Zero.from_primal_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) return (data_out, indices_out), (data_dot_out, indices_dot_out) _bcoo_sum_duplicates_hlo = mlir.lower_fun( @@ -1773,9 +1776,9 @@ def bcoo_concatenate(operands: Sequence[BCOO], *, dimension: int) -> BCOO: raise ValueError("bcoo_concatenate: expected operands to be a sequence of BCOO arrays. " f"Got {operands}") # Validate inputs using lax.concatenate abstract evaluation. 
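The recurring change in these sparse hunks swaps `jax.eval_shape(partial(fn, **static), ...)` for the AOT form `jax.jit(fn, static_argnames=...).eval_shape(...)`; both return the same shape/dtype information. A standalone sketch of the equivalence using `lax.rev` with illustrative shapes:

    import functools
    import jax
    import jax.numpy as jnp
    from jax import lax

    x = jax.ShapeDtypeStruct((3, 4), jnp.float32)
    old = jax.eval_shape(functools.partial(lax.rev, dimensions=(0,)), x)
    new = jax.jit(lax.rev, static_argnames=("dimensions",)).eval_shape(
        x, dimensions=(0,))
    assert old.shape == new.shape and old.dtype == new.dtype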
- out_aval = jax.eval_shape( - functools.partial(lax.concatenate, dimension=dimension), - [core.ShapedArray(op.shape, op.dtype) for op in operands]) + out_aval = jax.jit(lax.concatenate, static_argnames=("dimension",)).eval_shape( + [core.ShapedArray(op.shape, op.dtype) for op in operands], + dimension=dimension) if len({op.n_dense for op in operands}) > 1: raise ValueError("bcoo_concatenate requires inputs to have matching nse dimensions.") @@ -1891,8 +1894,9 @@ def bcoo_reshape(mat: BCOO, *, new_sizes: Sequence[int], dimensions: Sequence[in def bcoo_rev(operand, dimensions): """Sparse implementation of {func}`jax.lax.rev`""" # Check validity of dimensions via original implementation. - _ = jax.eval_shape(partial(lax.rev, dimensions=dimensions), - jax.ShapeDtypeStruct(operand.shape, operand.dtype)) + _ = jax.jit(lax.rev, static_argnames=("dimensions",)).eval_shape( + jax.ShapeDtypeStruct(operand.shape, operand.dtype), + dimensions=dimensions) batch_dims = [d for d in dimensions if d < operand.n_batch] sparse_dims = [d for d in dimensions if operand.n_batch <= d < operand.n_batch + operand.n_sparse] dense_dims = [d for d in dimensions if d >= operand.n_batch + operand.n_sparse] @@ -2035,15 +2039,16 @@ def bcoo_dynamic_slice(mat: BCOO, start_indices: Sequence[Any], slice_sizes: Seq Returns: out: BCOO array containing the slice. """ + slice_sizes = tuple(operator.index(i) for i in slice_sizes) # Use abstract eval to validate inputs. - jax.eval_shape(partial(lax.dynamic_slice, slice_sizes=slice_sizes), - jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices) + jax.jit(lax.dynamic_slice, static_argnames=("slice_sizes",)).eval_shape( + jax.ShapeDtypeStruct(mat.shape, mat.dtype), start_indices, + slice_sizes=slice_sizes) if not isinstance(mat, BCOO): raise TypeError(f"bcoo_slice: input should be BCOO array, got type(mat)={type(mat)}") start_indices = tuple(jnp.asarray(i) for i in start_indices) assert all(jnp.issubdtype(i.dtype, np.integer) for i in start_indices) assert all(i.shape == () for i in start_indices) - slice_sizes = tuple(operator.index(i) for i in slice_sizes) if len(start_indices) != len(slice_sizes) != mat.ndim: raise ValueError(f"bcoo_dynamic_slice: indices must have size mat.ndim={mat.ndim}") if not all(0 <= slice_size <= axis_size for slice_size, axis_size in zip(slice_sizes, mat.shape)): @@ -2303,9 +2308,13 @@ def bcoo_gather(operand: BCOO, start_indices: Array, mode=mode, fill_value=fill_value) # Abstract eval lax.gather to validate arguments & determine output shape. 
- out_aval = jax.eval_shape(partial(lax.gather, **kwds), - jax.ShapeDtypeStruct(operand.shape, operand.dtype), - jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype)) + static_argnames = ("dimension_numbers", "slice_sizes", "unique_indices", + "indices_are_sorted", "mode", "fill_value",) + out_aval = jax.jit(lax.gather, static_argnames=static_argnames).eval_shape( + jax.ShapeDtypeStruct(operand.shape, operand.dtype), + jax.ShapeDtypeStruct(start_indices.shape, start_indices.dtype), + **kwds) + offset_dims = dimension_numbers.offset_dims collapsed_slice_dims = dimension_numbers.collapsed_slice_dims start_index_map = dimension_numbers.start_index_map diff --git a/jax/experimental/sparse/bcsr.py b/jax/experimental/sparse/bcsr.py index 7f3ebb43c0ec..1b877aec9c75 100644 --- a/jax/experimental/sparse/bcsr.py +++ b/jax/experimental/sparse/bcsr.py @@ -272,11 +272,11 @@ def _bcsr_fromdense_jvp(primals, tangents, *, nse, n_batch, n_dense, index_dtype data, indices, indptr = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = bcsr_extract(indices, indptr, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr)) return primals_out, tangents_out @@ -463,7 +463,9 @@ def _bcsr_extract_batching_rule(batched_args, batch_dims): def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, dimension_numbers: DotDimensionNumbers, precision: None = None, - preferred_element_type: None = None) -> Array: + preferred_element_type: None = None, + algorithm: None = None, + transpose_algorithm: None = None) -> Array: """A general contraction operation. Args: @@ -474,13 +476,15 @@ def bcsr_dot_general(lhs: BCSR | Array, rhs: Array, *, (lhs_batch_dims, rhs_batch_dims))`. precision: unused preferred_element_type: unused + algorithm: unused + transpose_algorithm: unused Returns: An ndarray or BCSR-format sparse array containing the result. If both inputs are sparse, the result will be sparse, of type BCSR. If either input is dense, the result will be dense, of type ndarray. 
""" - del precision # unused + del precision, algorithm, transpose_algorithm # unused if isinstance(rhs, (np.ndarray, jax.Array)): if isinstance(lhs, (np.ndarray, jax.Array)): return lax.dot_general(lhs, rhs, dimension_numbers=dimension_numbers, diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 8863478df4d3..c65bc87235d6 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -348,11 +348,11 @@ def _coo_fromdense_jvp(primals, tangents, *, nse, index_dtype): data, row, col = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _coo_extract(row, col, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(row), ad.Zero.from_value(col)) + tangents_out = (data_dot, ad.Zero.from_primal_value(row), ad.Zero.from_primal_value(col)) return primals_out, tangents_out diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index c1178943c02a..89d08f109d68 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -380,11 +380,11 @@ def _csr_fromdense_jvp(primals, tangents, *, nse, index_dtype): data, indices, indptr = primals_out if type(Mdot) is ad.Zero: - data_dot = ad.Zero.from_value(data) + data_dot = ad.Zero.from_primal_value(data) else: data_dot = _csr_extract(indices, indptr, Mdot) - tangents_out = (data_dot, ad.Zero.from_value(indices), ad.Zero.from_value(indptr)) + tangents_out = (data_dot, ad.Zero.from_primal_value(indices), ad.Zero.from_primal_value(indptr)) return primals_out, tangents_out diff --git a/jax/experimental/sparse/linalg.py b/jax/experimental/sparse/linalg.py index b0ac1fa5d380..184eb9741dd2 100644 --- a/jax/experimental/sparse/linalg.py +++ b/jax/experimental/sparse/linalg.py @@ -602,7 +602,9 @@ def spsolve(data, indices, indptr, b, tol=1e-6, reorder=1): """A sparse direct solver using QR factorization. Accepts a sparse matrix in CSR format `data, indices, indptr` arrays. - Currently only the CUDA GPU backend is implemented. + Currently only the CUDA GPU backend is implemented, the CPU backend will fall + back to `scipy.sparse.linalg.spsolve`. Neither the CPU nor the GPU + implementation support batching with `vmap`. Args: data : An array containing the non-zero entries of the CSR matrix. diff --git a/jax/experimental/sparse/util.py b/jax/experimental/sparse/util.py index 9aa9e42f2a60..c79dee09cec2 100644 --- a/jax/experimental/sparse/util.py +++ b/jax/experimental/sparse/util.py @@ -113,4 +113,5 @@ def _dot_general_validated_shape( rhs = core.ShapedArray(rhs_shape, np.float32) return _dot_general_shape_rule( lhs, rhs, dimension_numbers=dimension_numbers, - precision=None, preferred_element_type=None) + precision=None, preferred_element_type=None, algorithm=None, + transpose_algorithm=None) diff --git a/jax/experimental/xla_metadata.py b/jax/experimental/xla_metadata.py new file mode 100644 index 000000000000..fb15e4743d2b --- /dev/null +++ b/jax/experimental/xla_metadata.py @@ -0,0 +1,17 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from jax._src.xla_metadata import ( + set_xla_metadata as set_xla_metadata, +) diff --git a/jax/extend/BUILD b/jax/extend/BUILD index babe0c8b10d2..59958c1da389 100644 --- a/jax/extend/BUILD +++ b/jax/extend/BUILD @@ -80,3 +80,9 @@ pytype_strict_library( srcs = ["ffi.py"], deps = ["//jax"], ) + +pytype_strict_library( + name = "ifrt_programs", + srcs = ["ifrt_programs.py"], + deps = ["//jax/_src/lib"], +) diff --git a/jax/extend/backend.py b/jax/extend/backend.py index 66fd149d7c8e..b1e471133482 100644 --- a/jax/extend/backend.py +++ b/jax/extend/backend.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.api import ( clear_backends as clear_backends, diff --git a/jax/extend/core/__init__.py b/jax/extend/core/__init__.py index 2732b1984c1d..9f1632fb37a9 100644 --- a/jax/extend/core/__init__.py +++ b/jax/extend/core/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.abstract_arrays import ( array_types as array_types diff --git a/jax/extend/core/primitives.py b/jax/extend/core/primitives.py index e37287180eee..feb70b5171be 100644 --- a/jax/extend/core/primitives.py +++ b/jax/extend/core/primitives.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ad_util import stop_gradient_p as stop_gradient_p diff --git a/jax/extend/ffi.py b/jax/extend/ffi.py index 3a26030c1687..b2d480adc7eb 100644 --- a/jax/extend/ffi.py +++ b/jax/extend/ffi.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.extend.ffi import ( ffi_call as ffi_call, diff --git a/jax/extend/ifrt_programs.py b/jax/extend/ifrt_programs.py new file mode 100644 index 000000000000..715dfd43592c --- /dev/null +++ b/jax/extend/ifrt_programs.py @@ -0,0 +1,22 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Note: import as is required for names to be exported. +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 + +from jax._src.lib import xla_extension as _xe + +ifrt_programs = _xe.ifrt_programs + +del _xe diff --git a/jax/extend/linear_util.py b/jax/extend/linear_util.py index 1706f8c8c30b..74c52dddbae8 100644 --- a/jax/extend/linear_util.py +++ b/jax/extend/linear_util.py @@ -13,7 +13,7 @@ # limitations under the License.
# Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.linear_util import ( StoreException as StoreException, diff --git a/jax/extend/random.py b/jax/extend/random.py index a055c75751bd..d6e0cfaab0e4 100644 --- a/jax/extend/random.py +++ b/jax/extend/random.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.extend.random import ( define_prng_impl as define_prng_impl, diff --git a/jax/extend/source_info_util.py b/jax/extend/source_info_util.py index f74df2cab5e1..f031dabef48d 100644 --- a/jax/extend/source_info_util.py +++ b/jax/extend/source_info_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.source_info_util import ( NameStack as NameStack, diff --git a/jax/image/__init__.py b/jax/image/__init__.py index c7ee8ffa9c64..993395f503fd 100644 --- a/jax/image/__init__.py +++ b/jax/image/__init__.py @@ -21,7 +21,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.image.scale import ( resize as resize, diff --git a/jax/interpreters/ad.py b/jax/interpreters/ad.py index 6663df3ac473..28816afb01e3 100644 --- a/jax/interpreters/ad.py +++ b/jax/interpreters/ad.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from __future__ import annotations @@ -59,9 +59,7 @@ primitive_jvps as primitive_jvps, primitive_transposes as primitive_transposes, rearrange_binders as rearrange_binders, - recast_to_float0 as recast_to_float0, reducing_transposes as reducing_transposes, - replace_float0s as replace_float0s, standard_jvp as standard_jvp, standard_jvp2 as standard_jvp2, traceable as traceable, diff --git a/jax/interpreters/batching.py b/jax/interpreters/batching.py index 98fad903cc4f..607fc6fa596d 100644 --- a/jax/interpreters/batching.py +++ b/jax/interpreters/batching.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.interpreters.batching import ( Array as Array, diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 706f5a2fe253..3c63948bee63 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -91,7 +91,6 @@ trace_to_subjaxpr_dynamic as trace_to_subjaxpr_dynamic, trace_to_subjaxpr_dynamic2 as trace_to_subjaxpr_dynamic2, trace_to_subjaxpr_nounits as trace_to_subjaxpr_nounits, - trace_to_subjaxpr_nounits_dyn as trace_to_subjaxpr_nounits_dyn, trace_to_subjaxpr_nounits_fwd as trace_to_subjaxpr_nounits_fwd, tracers_to_jaxpr as tracers_to_jaxpr, trivial_ctx as trivial_ctx, diff --git a/jax/lax/__init__.py b/jax/lax/__init__.py index bac005b81650..bb72abb2ec32 100644 --- a/jax/lax/__init__.py +++ b/jax/lax/__init__.py @@ -13,12 +13,15 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.lax.lax import ( DotDimensionNumbers as DotDimensionNumbers, Precision as Precision, PrecisionLike as PrecisionLike, + DotAlgorithm as DotAlgorithm, + DotAlgorithmLike as DotAlgorithmLike, + DotTransposeAlgorithmLike as DotTransposeAlgorithmLike, RandomAlgorithm as RandomAlgorithm, RoundingMethod as RoundingMethod, abs as abs, @@ -142,6 +145,8 @@ nextafter as nextafter, nextafter_p as nextafter_p, not_p as not_p, + optimization_barrier as optimization_barrier, + optimization_barrier_p as optimization_barrier_p, or_p as or_p, outfeed as outfeed, outfeed_p as outfeed_p, diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 7422e9fcc56d..a51625eb072e 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -25,7 +25,6 @@ ArrayImpl = _xc.ArrayImpl Client = _xc.Client CompileOptions = _xc.CompileOptions -Device = _xc.Device DeviceAssignment = _xc.DeviceAssignment FftType = _xc.FftType Frame = _xc.Frame @@ -37,26 +36,41 @@ Traceback = _xc.Traceback XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation -XlaRuntimeError = _xc.XlaRuntimeError _deprecations = { - # Added Aug 5 2024 - "_xla" : ( - "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", - _xc._xla - ), - "bfloat16" : ( - "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", - _xc.bfloat16 - ), + # Added Aug 5 2024 + "_xla": ( + "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", + _xc._xla, + ), + "bfloat16": ( + "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", + _xc.bfloat16, + ), + # Added Sep 26 2024 + "Device" : ( + "jax.lib.xla_client.Device is deprecated; use jax.Device instead.", + _xc.Device + ), + "XlaRuntimeError": ( + ( + "jax.lib.xla_client.XlaRuntimeError is deprecated; use" + " jax.errors.JaxRuntimeError." 
+ ), + _xc.XlaRuntimeError, + ), } import typing as _typing + if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + Device = _xc.Device + XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing diff --git a/jax/nn/__init__.py b/jax/nn/__init__.py index 230aacb7654a..496d03261384 100644 --- a/jax/nn/__init__.py +++ b/jax/nn/__init__.py @@ -15,7 +15,7 @@ """Common functions for neural network libraries.""" # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.numpy import tanh as tanh from jax.nn import initializers as initializers diff --git a/jax/nn/initializers.py b/jax/nn/initializers.py index 6c73356ce1a1..019f3e179215 100644 --- a/jax/nn/initializers.py +++ b/jax/nn/initializers.py @@ -18,7 +18,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.nn.initializers import ( constant as constant, diff --git a/jax/numpy/__init__.py b/jax/numpy/__init__.py index 88e1840ef1c0..20c37c55902c 100644 --- a/jax/numpy/__init__.py +++ b/jax/numpy/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.numpy import fft as fft from jax.numpy import linalg as linalg @@ -212,7 +212,6 @@ rollaxis as rollaxis, rot90 as rot90, round as round, - round_ as round_, save as save, savez as savez, searchsorted as searchsorted, @@ -466,6 +465,11 @@ _deprecations = { + # Deprecated 03 Sept 2024 + "round_": ( + "jnp.round_ is deprecated; use jnp.round instead.", + round + ), # Deprecated 18 Sept 2023 and removed 06 Feb 2024 "trapz": ( "jnp.trapz is deprecated; use jnp.trapezoid instead.", @@ -473,6 +477,11 @@ ), } -from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) -del _deprecation_getattr +import typing +if typing.TYPE_CHECKING: + round_ = round +else: + from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) + del _deprecation_getattr +del typing diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index dfea8a8ddd74..c23f659bd3f9 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -3,7 +3,7 @@ from __future__ import annotations import builtins from collections.abc import Callable, Sequence -from typing import Any, Literal, NamedTuple, TypeVar, Union, overload +from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload from jax._src import core as _core from jax._src import dtypes as _dtypes @@ -28,6 +28,34 @@ _Device = Device ComplexWarning: type +class BinaryUfunc(Protocol): + @property + def nin(self) -> int: ... + @property + def nout(self) -> int: ... + @property + def nargs(self) -> int: ... + @property + def identity(self) -> builtins.bool | int | float: ... + def __call__(self, x: ArrayLike, y: ArrayLike, /) -> Array: ... 
+ def reduce(self, arr: ArrayLike, /, *, + axis: int | None = 0, + dtype: DTypeLike | None = None, + keepdims: builtins.bool = False, + initial: ArrayLike | None = None, + where: ArrayLike | None = None) -> Array: ... + def accumulate(self, a: ArrayLike, /, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def at(self, a: ArrayLike, indices: Any, b: ArrayLike | None = None, /, *, + inplace: builtins.bool = True) -> Array: ... + def reduceat(self, a: ArrayLike, indices: Any, *, + axis: int = 0, + dtype: DTypeLike | None = None, + out: None = None) -> Array: ... + def outer(self, a: ArrayLike, b: ArrayLike, /) -> Array: ... + __array_api_version__: str def __array_namespace_info__() -> ArrayNamespaceInfo: ... @@ -36,7 +64,7 @@ def abs(x: ArrayLike, /) -> Array: ... def absolute(x: ArrayLike, /) -> Array: ... def acos(x: ArrayLike, /) -> Array: ... def acosh(x: ArrayLike, /) -> Array: ... -def add(x: ArrayLike, y: ArrayLike, /) -> Array: ... +add: BinaryUfunc def amax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., where: ArrayLike | None = ...) -> Array: ... @@ -162,14 +190,14 @@ def bartlett(M: int) -> Array: ... bfloat16: Any def bincount(x: ArrayLike, weights: ArrayLike | None = ..., minlength: int = ..., *, length: int | None = ...) -> Array: ... -def bitwise_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_and: BinaryUfunc def bitwise_count(x: ArrayLike, /) -> Array: ... def bitwise_invert(x: ArrayLike, /) -> Array: ... def bitwise_left_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... def bitwise_not(x: ArrayLike, /) -> Array: ... -def bitwise_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_or: BinaryUfunc def bitwise_right_shift(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def bitwise_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +bitwise_xor: BinaryUfunc def blackman(M: int) -> Array: ... def block(arrays: ArrayLike | Sequence[ArrayLike] | Sequence[Sequence[ArrayLike]]) -> Array: ... bool: Any @@ -244,14 +272,14 @@ def cross( axis: int | None = ..., ) -> Array: ... csingle: Any -def cumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def cumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... cumproduct = cumprod -def cumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def cumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def cumulative_sum(x: ArrayLike, /, *, axis: int | None = ..., dtype: DTypeLike | None = ..., - include_initial: bool = ...) -> Array: ... + include_initial: builtins.bool = ...) -> Array: ... def deg2rad(x: ArrayLike, /) -> Array: ... degrees = rad2deg @@ -272,7 +300,8 @@ def diagonal( def diff(a: ArrayLike, n: int = ..., axis: int = ..., prepend: ArrayLike | None = ..., append: ArrayLike | None = ...) -> Array: ... -def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ...) -> Array: ... +def digitize(x: ArrayLike, bins: ArrayLike, right: builtins.bool = ..., *, + method: str | None = ...) -> Array: ... divide = true_divide def divmod(x: ArrayLike, y: ArrayLike, /) -> tuple[Array, Array]: ... def dot( @@ -495,7 +524,8 @@ def interp(x: ArrayLike, xp: ArrayLike, fp: ArrayLike, right: ArrayLike | str | None = ..., period: ArrayLike | None = ...) -> Array: ... def intersect1d(ar1: ArrayLike, ar2: ArrayLike, assume_unique: builtins.bool = ..., - return_indices: builtins.bool = ...) 
-> Array | tuple[Array, Array, Array]: ... + return_indices: builtins.bool = ..., *, size: int | None = ..., + fill_value: ArrayLike | None = ...) -> Array | tuple[Array, Array, Array]: ... def invert(x: ArrayLike, /) -> Array: ... def isclose(a: ArrayLike, b: ArrayLike, rtol: ArrayLike = ..., atol: ArrayLike = ..., equal_nan: builtins.bool = ...) -> Array: ... @@ -556,10 +586,10 @@ def log1p(x: ArrayLike, /) -> Array: ... def log2(x: ArrayLike, /) -> Array: ... def logaddexp(x: ArrayLike, y: ArrayLike, /) -> Array: ... def logaddexp2(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_and(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_and: BinaryUfunc def logical_not(x: ArrayLike, /) -> Array: ... -def logical_or(x: ArrayLike, y: ArrayLike, /) -> Array: ... -def logical_xor(x: ArrayLike, y: ArrayLike, /) -> Array: ... +logical_or: BinaryUfunc +logical_xor: BinaryUfunc def logspace(start: ArrayLike, stop: ArrayLike, num: int = ..., endpoint: builtins.bool = ..., base: ArrayLike = ..., dtype: DTypeLike | None = ..., axis: int = ...) -> Array: ... @@ -587,7 +617,7 @@ def mod(x: ArrayLike, y: ArrayLike, /) -> Array: ... def modf(x: ArrayLike, /, out=None) -> tuple[Array, Array]: ... def moveaxis(a: ArrayLike, source: int | Sequence[int], destination: int | Sequence[int]) -> Array: ... -def multiply(x: ArrayLike, y: ArrayLike, /) -> Array: ... +multiply: BinaryUfunc nan: float def nan_to_num(x: ArrayLike, copy: builtins.bool = ..., nan: ArrayLike = ..., posinf: ArrayLike | None = ..., @@ -604,9 +634,9 @@ def nanargmin( out: None = ..., keepdims: builtins.bool | None = ..., ) -> Array: ... -def nancumprod(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nancumprod(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... -def nancumsum(a: ArrayLike, axis: _Axis = ..., dtype: DTypeLike = ..., +def nancumsum(a: ArrayLike, axis: int | None = ..., dtype: DTypeLike = ..., out: None = ...) -> Array: ... def nanmax(a: ArrayLike, axis: _Axis = ..., out: None = ..., keepdims: builtins.bool = ..., initial: ArrayLike | None = ..., diff --git a/jax/numpy/fft.py b/jax/numpy/fft.py index 24a271487d5e..c268c2d65597 100644 --- a/jax/numpy/fft.py +++ b/jax/numpy/fft.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.numpy.fft import ( ifft as ifft, diff --git a/jax/numpy/linalg.py b/jax/numpy/linalg.py index c342fde0ae6e..05b5ff6db289 100644 --- a/jax/numpy/linalg.py +++ b/jax/numpy/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.numpy.linalg import ( cholesky as cholesky, diff --git a/jax/ops/__init__.py b/jax/ops/__init__.py index c61a44fd1357..5e1f3d682589 100644 --- a/jax/ops/__init__.py +++ b/jax/ops/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
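The `BinaryUfunc` protocol added to the stubs above captures that `jnp.add`, `jnp.multiply`, and several `bitwise_*`/`logical_*` functions are ufunc objects exposing `reduce`, `accumulate`, `at`, `reduceat` and `outer`, rather than plain two-argument functions. A short sketch:

    import jax.numpy as jnp

    x = jnp.arange(5)
    total = jnp.add.reduce(x)         # same result as jnp.sum(x)
    running = jnp.add.accumulate(x)   # same result as jnp.cumsum(x)
    table = jnp.multiply.outer(x, x)  # (5, 5) multiplication table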
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.ops.scatter import ( segment_sum as segment_sum, diff --git a/jax/profiler.py b/jax/profiler.py index 01ea6e2222cc..77157dc02a13 100644 --- a/jax/profiler.py +++ b/jax/profiler.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.profiler import ( StepTraceAnnotation as StepTraceAnnotation, diff --git a/jax/random.py b/jax/random.py index 5c2eaf81f2bc..29a625389811 100644 --- a/jax/random.py +++ b/jax/random.py @@ -103,7 +103,7 @@ **TLDR**: JAX PRNG = `Threefry counter PRNG `_ + a functional array-oriented `splitting model `_ -See `docs/jep/263-prng.md `_ +See `docs/jep/263-prng.md `_ for more details. To summarize, among other requirements, the JAX PRNG aims to: @@ -201,7 +201,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.random import ( PRNGKey as PRNGKey, diff --git a/jax/scipy/__init__.py b/jax/scipy/__init__.py index c0746910dd3f..cf44b6e179c0 100644 --- a/jax/scipy/__init__.py +++ b/jax/scipy/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from typing import TYPE_CHECKING diff --git a/jax/scipy/cluster/__init__.py b/jax/scipy/cluster/__init__.py index 5a01ea0ee493..ea35467f6353 100644 --- a/jax/scipy/cluster/__init__.py +++ b/jax/scipy/cluster/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.cluster import vq as vq diff --git a/jax/scipy/cluster/vq.py b/jax/scipy/cluster/vq.py index 3a46ce52f468..eeeabb7224bc 100644 --- a/jax/scipy/cluster/vq.py +++ b/jax/scipy/cluster/vq.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.cluster.vq import vq as vq diff --git a/jax/scipy/fft.py b/jax/scipy/fft.py index b8005b72f349..d3c2de09935a 100644 --- a/jax/scipy/fft.py +++ b/jax/scipy/fft.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.fft import ( dct as dct, diff --git a/jax/scipy/integrate.py b/jax/scipy/integrate.py index b19aa054ca00..3335f12fd381 100644 --- a/jax/scipy/integrate.py +++ b/jax/scipy/integrate.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.integrate import ( trapezoid as trapezoid diff --git a/jax/scipy/linalg.py b/jax/scipy/linalg.py index 059f927ec46c..64bc0544000b 100644 --- a/jax/scipy/linalg.py +++ b/jax/scipy/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.linalg import ( block_diag as block_diag, diff --git a/jax/scipy/ndimage.py b/jax/scipy/ndimage.py index 2f63e236654c..81d7e3ef27d8 100644 --- a/jax/scipy/ndimage.py +++ b/jax/scipy/ndimage.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.ndimage import ( map_coordinates as map_coordinates, diff --git a/jax/scipy/optimize/__init__.py b/jax/scipy/optimize/__init__.py index 8a2248733145..f1c7167c33f4 100644 --- a/jax/scipy/optimize/__init__.py +++ b/jax/scipy/optimize/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.optimize.minimize import ( minimize as minimize, diff --git a/jax/scipy/signal.py b/jax/scipy/signal.py index 7e39da3f95b1..c46b2fce3572 100644 --- a/jax/scipy/signal.py +++ b/jax/scipy/signal.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.signal import ( fftconvolve as fftconvolve, diff --git a/jax/scipy/sparse/__init__.py b/jax/scipy/sparse/__init__.py index f2e305e829c8..2968a26b4415 100644 --- a/jax/scipy/sparse/__init__.py +++ b/jax/scipy/sparse/__init__.py @@ -13,6 +13,6 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.sparse import linalg as linalg diff --git a/jax/scipy/sparse/linalg.py b/jax/scipy/sparse/linalg.py index d475ddff81f7..d22e5ec43977 100644 --- a/jax/scipy/sparse/linalg.py +++ b/jax/scipy/sparse/linalg.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.sparse.linalg import ( cg as cg, diff --git a/jax/scipy/spatial/transform.py b/jax/scipy/spatial/transform.py index 4b532d5f3d50..63e8dd3736b2 100644 --- a/jax/scipy/spatial/transform.py +++ b/jax/scipy/spatial/transform.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.spatial.transform import ( Rotation as Rotation, diff --git a/jax/scipy/special.py b/jax/scipy/special.py index e244c3705af3..431617d362ea 100644 --- a/jax/scipy/special.py +++ b/jax/scipy/special.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.special import ( bernoulli as bernoulli, @@ -61,3 +61,7 @@ xlogy as xlogy, zeta as zeta, ) + +from jax._src.third_party.scipy.special import ( + fresnel as fresnel, +) diff --git a/jax/scipy/stats/__init__.py b/jax/scipy/stats/__init__.py index 7aa73f7b5218..7719945f23df 100644 --- a/jax/scipy/stats/__init__.py +++ b/jax/scipy/stats/__init__.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax.scipy.stats import bernoulli as bernoulli from jax.scipy.stats import beta as beta diff --git a/jax/scipy/stats/bernoulli.py b/jax/scipy/stats/bernoulli.py index 46c1e4825d11..1623f71130c1 100644 --- a/jax/scipy/stats/bernoulli.py +++ b/jax/scipy/stats/bernoulli.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.bernoulli import ( logpmf as logpmf, diff --git a/jax/scipy/stats/beta.py b/jax/scipy/stats/beta.py index 5c57dda6bb56..2a4e7f12f7a5 100644 --- a/jax/scipy/stats/beta.py +++ b/jax/scipy/stats/beta.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.beta import ( cdf as cdf, diff --git a/jax/scipy/stats/betabinom.py b/jax/scipy/stats/betabinom.py index 48f955d9eaf3..f8adf68f4b2e 100644 --- a/jax/scipy/stats/betabinom.py +++ b/jax/scipy/stats/betabinom.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.betabinom import ( logpmf as logpmf, diff --git a/jax/scipy/stats/cauchy.py b/jax/scipy/stats/cauchy.py index 4ff79f5f9888..34c9972d09bd 100644 --- a/jax/scipy/stats/cauchy.py +++ b/jax/scipy/stats/cauchy.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.cauchy import ( cdf as cdf, diff --git a/jax/scipy/stats/chi2.py b/jax/scipy/stats/chi2.py index e17a2e331958..47fcb76db28d 100644 --- a/jax/scipy/stats/chi2.py +++ b/jax/scipy/stats/chi2.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.chi2 import ( cdf as cdf, diff --git a/jax/scipy/stats/dirichlet.py b/jax/scipy/stats/dirichlet.py index 9368defc8f58..22e9b3cc11cc 100644 --- a/jax/scipy/stats/dirichlet.py +++ b/jax/scipy/stats/dirichlet.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.dirichlet import ( logpdf as logpdf, diff --git a/jax/scipy/stats/expon.py b/jax/scipy/stats/expon.py index 1ec50ac3f604..8f5c0a0680ce 100644 --- a/jax/scipy/stats/expon.py +++ b/jax/scipy/stats/expon.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.expon import ( logpdf as logpdf, diff --git a/jax/scipy/stats/gamma.py b/jax/scipy/stats/gamma.py index 8efecafed3bd..531a1e300ca9 100644 --- a/jax/scipy/stats/gamma.py +++ b/jax/scipy/stats/gamma.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.gamma import ( cdf as cdf, diff --git a/jax/scipy/stats/gennorm.py b/jax/scipy/stats/gennorm.py index c903ff606c25..c760575fa7a6 100644 --- a/jax/scipy/stats/gennorm.py +++ b/jax/scipy/stats/gennorm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.gennorm import ( cdf as cdf, diff --git a/jax/scipy/stats/geom.py b/jax/scipy/stats/geom.py index 75f917fc27c7..eb12dbb5a183 100644 --- a/jax/scipy/stats/geom.py +++ b/jax/scipy/stats/geom.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.geom import ( logpmf as logpmf, diff --git a/jax/scipy/stats/laplace.py b/jax/scipy/stats/laplace.py index 3abe62020398..8f182804daf0 100644 --- a/jax/scipy/stats/laplace.py +++ b/jax/scipy/stats/laplace.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.laplace import ( cdf as cdf, diff --git a/jax/scipy/stats/logistic.py b/jax/scipy/stats/logistic.py index c25a06856ff7..7cdb26fb1d20 100644 --- a/jax/scipy/stats/logistic.py +++ b/jax/scipy/stats/logistic.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.logistic import ( cdf as cdf, diff --git a/jax/scipy/stats/multinomial.py b/jax/scipy/stats/multinomial.py index 723d1a645726..392ca405581e 100644 --- a/jax/scipy/stats/multinomial.py +++ b/jax/scipy/stats/multinomial.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.multinomial import ( logpmf as logpmf, diff --git a/jax/scipy/stats/multivariate_normal.py b/jax/scipy/stats/multivariate_normal.py index 95ad355c75f1..94c4cc50a18c 100644 --- a/jax/scipy/stats/multivariate_normal.py +++ b/jax/scipy/stats/multivariate_normal.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.multivariate_normal import ( logpdf as logpdf, diff --git a/jax/scipy/stats/norm.py b/jax/scipy/stats/norm.py index f47765adfc68..563e40ce06cd 100644 --- a/jax/scipy/stats/norm.py +++ b/jax/scipy/stats/norm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.norm import ( cdf as cdf, diff --git a/jax/scipy/stats/pareto.py b/jax/scipy/stats/pareto.py index bf27ea205948..5e46fd5d0bc7 100644 --- a/jax/scipy/stats/pareto.py +++ b/jax/scipy/stats/pareto.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.pareto import ( logpdf as logpdf, diff --git a/jax/scipy/stats/poisson.py b/jax/scipy/stats/poisson.py index 2e857bc15a3b..5fcde905f89b 100644 --- a/jax/scipy/stats/poisson.py +++ b/jax/scipy/stats/poisson.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.poisson import ( logpmf as logpmf, diff --git a/jax/scipy/stats/t.py b/jax/scipy/stats/t.py index d92fcab97bf7..694bcb0b0dfc 100644 --- a/jax/scipy/stats/t.py +++ b/jax/scipy/stats/t.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.t import ( logpdf as logpdf, diff --git a/jax/scipy/stats/truncnorm.py b/jax/scipy/stats/truncnorm.py index 28d5533b02da..cb8e8958d735 100644 --- a/jax/scipy/stats/truncnorm.py +++ b/jax/scipy/stats/truncnorm.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.truncnorm import ( cdf as cdf, diff --git a/jax/scipy/stats/uniform.py b/jax/scipy/stats/uniform.py index d0a06c673b3c..fa754125f556 100644 --- a/jax/scipy/stats/uniform.py +++ b/jax/scipy/stats/uniform.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.uniform import ( logpdf as logpdf, diff --git a/jax/scipy/stats/vonmises.py b/jax/scipy/stats/vonmises.py index 8de7fba47096..6572e43f63c6 100644 --- a/jax/scipy/stats/vonmises.py +++ b/jax/scipy/stats/vonmises.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.vonmises import ( logpdf as logpdf, diff --git a/jax/scipy/stats/wrapcauchy.py b/jax/scipy/stats/wrapcauchy.py index 6e2420c5ae7b..eb1768f0c959 100644 --- a/jax/scipy/stats/wrapcauchy.py +++ b/jax/scipy/stats/wrapcauchy.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.scipy.stats.wrapcauchy import ( logpdf as logpdf, diff --git a/jax/sharding.py b/jax/sharding.py index fe221f90af67..26c542292e87 100644 --- a/jax/sharding.py +++ b/jax/sharding.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.sharding import Sharding as Sharding from jax._src.sharding_impls import ( @@ -28,6 +28,7 @@ PartitionSpec as PartitionSpec, ) from jax._src.interpreters.pxla import Mesh as Mesh +from jax._src.mesh import AbstractMesh _deprecations = { # Added Jun 4, 2024. diff --git a/jax/stages.py b/jax/stages.py index 6ffc3144c3bc..3e7e461c385b 100644 --- a/jax/stages.py +++ b/jax/stages.py @@ -22,7 +22,7 @@ """ # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.stages import ( Compiled as Compiled, diff --git a/jax/test_util.py b/jax/test_util.py index 5d4f5ed0aa77..176f4521b281 100644 --- a/jax/test_util.py +++ b/jax/test_util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.public_test_util import ( check_grads as check_grads, diff --git a/jax/tools/build_defs.bzl b/jax/tools/build_defs.bzl index 1540afe42a6a..06f5e69833c5 100644 --- a/jax/tools/build_defs.bzl +++ b/jax/tools/build_defs.bzl @@ -146,9 +146,9 @@ EOF ) if format == "TF": - jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir_with_tensorflow" + jax_to_ir_rule = "//jax/tools:jax_to_ir_with_tensorflow" else: - jax_to_ir_rule = "//third_party/py/jax/tools:jax_to_ir" + jax_to_ir_rule = "//jax/tools:jax_to_ir" py_binary( name = runner, diff --git a/jax/tools/pgo_nsys_converter.py b/jax/tools/pgo_nsys_converter.py index 5e87220be606..fc06cb360509 100644 --- a/jax/tools/pgo_nsys_converter.py +++ b/jax/tools/pgo_nsys_converter.py @@ -28,8 +28,7 @@ parser = argparse.ArgumentParser(description='Tool to convert NVIDIA Nsys Profiles to the .pbtxt format') parser.add_argument("--profile_path", type=str, help="path to nsys profile") - parser.add_argument("--post_process", help="post process pbtxt to get minimum cost value for each instruction", action="store_true") - parser.add_argument("--pgle_output_path", type=str, help="output directory", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") + parser.add_argument("--pgle_output_path", type=str, help="output file", default="/opt/paxml/workspace/lhs_pbtxt/temp.pbtxt") args = parser.parse_args() @@ -38,7 +37,14 @@ profile_folder = os.path.join(os.path.split(args.profile_path)[0], '') assert isinstance(nsys_path, str) - stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", "nvtxkernsum", f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] + + # Older versions of nsys use `nvtxsum` for the report name so determine which is available. + query_reports_command = [nsys_path, "stats", "--help-reports"] + reports_list = subprocess.run(query_reports_command, capture_output=True, text=True).stdout + report_name = "nvtx_sum" if "nvtx_sum" in reports_list else "nvtxsum" + + assert isinstance(nsys_path, str) + stats_command = [nsys_path, "stats", "--force-overwrite", "true", "--force-export", "true", "--report", report_name, f"{args.profile_path}", "-o", f"{args.pgle_output_path}"] print(f""" ******Starting stats command****** @@ -49,10 +55,10 @@ thunk_re = re.compile("hlo_op=(.*)#") with open(f"{args.pgle_output_path}", 'w', newline='') as protofile: - with open(f"{pgle_folder}{pgle_filename}.pbtxt_nvtxkernsum.csv", newline='') as csvfile: + with open(f"{pgle_folder}{pgle_filename}.pbtxt_{report_name}.csv", newline='') as csvfile: reader = csv.DictReader(csvfile) for row in reader: - name = row['NVTX Range'] + name = row['Range'] time_ns = float(row['Avg (ns)']) m = thunk_re.search(name) if m is not None: diff --git a/jax/tree_util.py b/jax/tree_util.py index b4854c7dfbf1..956d79b9b4ef 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -36,7 +36,7 @@ """ # Note: import as is required for names to be exported. 
-# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.tree_util import ( DictKey as DictKey, diff --git a/jax/typing.py b/jax/typing.py index c75e0567e002..89efa1f2ca66 100644 --- a/jax/typing.py +++ b/jax/typing.py @@ -24,7 +24,7 @@ - :obj:`jax.typing.ArrayLike`: annotation for any value that is safe to implicitly cast to a JAX array; this includes :class:`jax.Array`, :class:`numpy.ndarray`, as well as Python builtin numeric values (e.g. :class:`int`, :class:`float`, etc.) and numpy scalar values - (e.g. :class:`numpy.int32`, :class:`numpy.flota64`, etc.) + (e.g. :class:`numpy.int32`, :class:`numpy.float64`, etc.) - :obj:`jax.typing.DTypeLike`: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. `'float32'`, `'int32'`), scalar types (e.g. `float`, `np.float32`), dtypes (e.g. `np.dtype('float32')`), or objects with a dtype attribute diff --git a/jax/util.py b/jax/util.py index c1259e9c5f56..8071f77dffe2 100644 --- a/jax/util.py +++ b/jax/util.py @@ -13,7 +13,7 @@ # limitations under the License. # Note: import as is required for names to be exported. -# See PEP 484 & https://github.com/google/jax/issues/7570 +# See PEP 484 & https://github.com/jax-ml/jax/issues/7570 from jax._src.util import ( HashableFunction as HashableFunction, diff --git a/jax/version.py b/jax/version.py index cc690e02cb46..6c64d75b9733 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.32" +_version = "0.4.34" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None @@ -115,7 +115,7 @@ def run(self): # missing or outdated. Because _write_version(...) modifies the copy of # this file in the build tree, re-building from the same JAX directory # would not automatically re-copy a clean version, and _write_version - # would fail without this deletion. See google/jax#18252. + # would fail without this deletion. See jax-ml/jax#18252. 
if os.path.isfile(this_file_in_build_dir): os.unlink(this_file_in_build_dir) super().run() @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files): __version__ = _get_version_string() -_minimum_jaxlib_version = "0.4.31" +_minimum_jaxlib_version = "0.4.33" def _version_as_tuple(version_str): return tuple(int(i) for i in version_str.split(".") if i.isdigit()) diff --git a/jax_plugins/cuda/BUILD.bazel b/jax_plugins/cuda/BUILD.bazel index fea9723e189b..79aebcd86826 100644 --- a/jax_plugins/cuda/BUILD.bazel +++ b/jax_plugins/cuda/BUILD.bazel @@ -14,7 +14,6 @@ licenses(["notice"]) -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_windows", @@ -35,22 +34,15 @@ exports_files([ "setup.py", ]) -symlink_files( - name = "pjrt_c_api_gpu_plugin", - srcs = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), - dst = ".", - flatten = True, -) - py_library_providing_imports_info( name = "cuda_plugin", srcs = [ "__init__.py", ], - data = [":pjrt_c_api_gpu_plugin"], + data = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), lib_rule = pytype_library, ) diff --git a/jax_plugins/cuda/__init__.py b/jax_plugins/cuda/__init__.py index ff5a1561dbbc..9867c07b1176 100644 --- a/jax_plugins/cuda/__init__.py +++ b/jax_plugins/cuda/__init__.py @@ -48,6 +48,13 @@ def _get_library_path(): local_path = os.path.join( os.path.dirname(__file__), 'pjrt_c_api_gpu_plugin.so' ) + if not os.path.exists(local_path): + runfiles_dir = os.getenv('RUNFILES_DIR', None) + if runfiles_dir: + local_path = os.path.join( + runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so' + ) + if os.path.exists(local_path): logger.debug( 'Native library %s does not exist. This most likely indicates an issue' diff --git a/jax_plugins/cuda/plugin_setup.py b/jax_plugins/cuda/plugin_setup.py index 468c0c48709f..8e99907d7078 100644 --- a/jax_plugins/cuda/plugin_setup.py +++ b/jax_plugins/cuda/plugin_setup.py @@ -66,13 +66,13 @@ def has_ext_modules(self): # dependency via, for example, cuSOLVER. NVIDIA's cuSOLVER packages # do not have a version constraint on their dependencies, so the # package doesn't get upgraded even though not doing that can cause - # problems (https://github.com/google/jax/issues/18027#issuecomment-1756305196) + # problems (https://github.com/jax-ml/jax/issues/18027#issuecomment-1756305196) # Until NVIDIA add version constraints, add a version constraint # here. 
"nvidia-nvjitlink-cu12>=12.1.105", ], }, - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jax_plugins/cuda/setup.py b/jax_plugins/cuda/setup.py index 96ce577fc643..1ce555978dac 100644 --- a/jax_plugins/cuda/setup.py +++ b/jax_plugins/cuda/setup.py @@ -48,7 +48,7 @@ def load_version_module(pkg_path): author_email="jax-dev@google.com", packages=packages, install_requires=[], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jax_plugins/rocm/BUILD.bazel b/jax_plugins/rocm/BUILD.bazel index 08a61c786262..6e265bcd18cf 100644 --- a/jax_plugins/rocm/BUILD.bazel +++ b/jax_plugins/rocm/BUILD.bazel @@ -14,7 +14,6 @@ licenses(["notice"]) -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "if_windows", @@ -35,21 +34,14 @@ exports_files([ "setup.py", ]) -symlink_files( - name = "pjrt_c_api_gpu_plugin", - srcs = if_windows( - ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], - ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], - ), - dst = ".", - flatten = True, -) - py_library_providing_imports_info( name = "rocm_plugin", srcs = [ "__init__.py", ], - data = [":pjrt_c_api_gpu_plugin"], + data = if_windows( + ["@xla//xla/pjrt/c/pjrt_c_api_gpu_plugin.pyd"], + ["@xla//xla/pjrt/c:pjrt_c_api_gpu_plugin.so"], + ), lib_rule = pytype_library, ) diff --git a/jax_plugins/rocm/__init__.py b/jax_plugins/rocm/__init__.py index 4535f1b3bbc8..3dbcaf4491e0 100644 --- a/jax_plugins/rocm/__init__.py +++ b/jax_plugins/rocm/__init__.py @@ -15,6 +15,7 @@ import functools import importlib import logging +import os import pathlib import platform @@ -47,6 +48,13 @@ def _get_library_path(): local_path = ( base_path / 'pjrt_c_api_gpu_plugin.so' ) + if not local_path.exists(): + runfiles_dir = os.getenv('RUNFILES_DIR', None) + if runfiles_dir: + local_path = pathlib.Path( + os.path.join(runfiles_dir, 'xla/xla/pjrt/c/pjrt_c_api_gpu_plugin.so') + ) + if local_path.exists(): logger.debug( 'Native library %s does not exist. 
This most likely indicates an issue' diff --git a/jax_plugins/rocm/plugin_setup.py b/jax_plugins/rocm/plugin_setup.py index 9ccf3bf44339..a84a6b34ea48 100644 --- a/jax_plugins/rocm/plugin_setup.py +++ b/jax_plugins/rocm/plugin_setup.py @@ -51,7 +51,7 @@ def has_ext_modules(self): packages=[package_name], python_requires=">=3.9", install_requires=[f"jax-rocm{rocm_version}-pjrt=={__version__}"], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jax_plugins/rocm/setup.py b/jax_plugins/rocm/setup.py index 8782676ce9a2..d131e732c91a 100644 --- a/jax_plugins/rocm/setup.py +++ b/jax_plugins/rocm/setup.py @@ -48,7 +48,7 @@ def load_version_module(pkg_path): author_email="Ruturaj.Vaidya@amd.com", packages=packages, install_requires=[], - url="https://github.com/google/jax", + url="https://github.com/jax-ml/jax", license="Apache-2.0", classifiers=[ "Development Status :: 3 - Alpha", diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 77b46d6d51aa..ab60b3fadd37 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -14,19 +14,19 @@ # JAX is Autograd and XLA -load("//jaxlib:symlink_files.bzl", "symlink_files") load( "//jaxlib:jax.bzl", "py_library_providing_imports_info", "pybind_extension", "pytype_library", ) +load("//jaxlib:symlink_files.bzl", "symlink_files") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) # This makes xla_extension module accessible from jax._src.lib. @@ -129,13 +129,13 @@ cc_library( hdrs = ["ffi_helpers.h"], features = ["-use_header_modules"], deps = [ - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", ], ) @@ -149,10 +149,10 @@ cc_library( features = ["-use_header_modules"], deps = [ ":kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/base", "@nanobind", + "@xla//xla/ffi/api:c_api", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -201,10 +201,10 @@ pybind_extension( srcs = ["utils.cc"], module_name = "utils", deps = [ - "@xla//third_party/python_runtime:headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:inlined_vector", "@nanobind", + "@xla//third_party/python_runtime:headers", ], ) @@ -238,6 +238,9 @@ pybind_extension( module_name = "rocm_plugin_extension", deps = [ "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status", + "@local_config_rocm//rocm:rocm_headers", + "@nanobind", "@xla//third_party/python_runtime:headers", "@xla//xla:status", "@xla//xla:util", @@ -248,9 +251,6 @@ pybind_extension( "@xla//xla/pjrt/c:pjrt_c_api_helpers", "@xla//xla/python:py_client_gpu", "@xla//xla/tsl/python/lib/core:numpy", - "@com_google_absl//absl/status", - "@local_config_rocm//rocm:rocm_headers", - "@nanobind", ], ) diff --git a/jaxlib/README.md b/jaxlib/README.md index 74e1e5b36ae3..cee5f246d96b 100644 --- a/jaxlib/README.md +++ b/jaxlib/README.md @@ -4,4 +4,4 @@ jaxlib is the support library for JAX. While JAX itself is a pure Python package jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. 
For more information, including installation and build instructions, refer to main -JAX README: https://github.com/google/jax/. +JAX README: https://github.com/jax-ml/jax/. diff --git a/jaxlib/cpu/BUILD b/jaxlib/cpu/BUILD index d97a11e4f61c..d3d15c4fc939 100644 --- a/jaxlib/cpu/BUILD +++ b/jaxlib/cpu/BUILD @@ -23,7 +23,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) # LAPACK @@ -36,12 +36,13 @@ cc_library( features = ["-use_header_modules"], deps = [ "//jaxlib:ffi_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base:dynamic_annotations", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -70,8 +71,8 @@ pybind_extension( deps = [ ":lapack_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/ffi/api:ffi", "@nanobind", + "@xla//xla/ffi/api:ffi", ], ) diff --git a/jaxlib/cpu/_lapack/__init__.pyi b/jaxlib/cpu/_lapack/__init__.pyi index 35c46fceeb9f..4275d8e48813 100644 --- a/jaxlib/cpu/_lapack/__init__.pyi +++ b/jaxlib/cpu/_lapack/__init__.pyi @@ -13,7 +13,6 @@ # limitations under the License. from . import eig as eig -from . import svd as svd def initialize() -> None: ... @@ -50,21 +49,7 @@ def zgesdd_work_size(m: int, n: int, job_opt_compute_uv: bool, job_opt_full_matr # FFI Kernel LAPACK Workspace Size Queries -def cgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... -def dgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... -def gesdd_iwork_size_ffi(m: int, n: int) -> int: ... -def gesdd_rwork_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... -def heevd_rwork_size_ffi(n: int) -> int: ... -def heevd_work_size_ffi(n: int) -> int: ... -def lapack_cgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_cungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_dgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_dorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_sgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_sorgqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def lapack_zgeqrf_workspace_ffi(m: int, n: int) -> int: ... def lapack_zungqr_workspace_ffi(m: int, n: int, k: int) -> int: ... -def sgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... -def syevd_iwork_size_ffi(n: int) -> int: ... -def syevd_work_size_ffi(n: int) -> int: ... -def zgesdd_work_size_ffi(m: int, n: int, mode: svd.ComputationMode) -> int: ... 
diff --git a/jaxlib/cpu/cpu_kernels.cc b/jaxlib/cpu/cpu_kernels.cc index 93717ea9b492..c2e122c048a4 100644 --- a/jaxlib/cpu/cpu_kernels.cc +++ b/jaxlib/cpu/cpu_kernels.cc @@ -149,6 +149,10 @@ JAX_CPU_REGISTER_HANDLER(lapack_sgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_dgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_cgeev_ffi); JAX_CPU_REGISTER_HANDLER(lapack_zgeev_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_sgehrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_dgehrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_cgehrd_ffi); +JAX_CPU_REGISTER_HANDLER(lapack_zgehrd_ffi); #undef JAX_CPU_REGISTER_HANDLER diff --git a/jaxlib/cpu/lapack.cc b/jaxlib/cpu/lapack.cc index 3e59a4a024d6..8fc480951b1e 100644 --- a/jaxlib/cpu/lapack.cc +++ b/jaxlib/cpu/lapack.cc @@ -37,29 +37,6 @@ svd::ComputationMode GetSvdComputationMode(bool job_opt_compute_uv, return svd::ComputationMode::kComputeFullUVt; } -template -int64_t GesddGetWorkspaceSize(lapack_int m, lapack_int n, - bool job_opt_compute_uv, - bool job_opt_full_matrices) { - svd::ComputationMode mode = - GetSvdComputationMode(job_opt_compute_uv, job_opt_full_matrices); - return svd::SVDType::GetWorkspaceSize(m, n, mode); -}; - -lapack_int GesddGetRealWorkspaceSize(lapack_int m, lapack_int n, - bool job_opt_compute_uv) { - svd::ComputationMode mode = GetSvdComputationMode(job_opt_compute_uv, true); - return svd::GetRealWorkspaceSize(m, n, mode); -} - -// Due to enforced kComputeEigenvectors, this assumes a larger workspace size. -// Could be improved to more accurately estimate the expected size based on the -// eig::ComputationMode value. -template -inline constexpr auto BoundWithEigvecs = +[](lapack_int n) { - return f(n, eig::ComputationMode::kComputeEigenvectors); -}; - void GetLapackKernelsFromScipy() { static bool initialized = false; // Protected by GIL if (initialized) return; @@ -165,6 +142,10 @@ void GetLapackKernelsFromScipy() { AssignKernelFn>(lapack_ptr("dgehrd")); AssignKernelFn>>(lapack_ptr("cgehrd")); AssignKernelFn>>(lapack_ptr("zgehrd")); + AssignKernelFn>(lapack_ptr("sgehrd")); + AssignKernelFn>(lapack_ptr("dgehrd")); + AssignKernelFn>(lapack_ptr("cgehrd")); + AssignKernelFn>(lapack_ptr("zgehrd")); AssignKernelFn>(lapack_ptr("ssytrd")); AssignKernelFn>(lapack_ptr("dsytrd")); @@ -276,6 +257,10 @@ nb::dict Registrations() { dict["lapack_dgeev_ffi"] = EncapsulateFunction(lapack_dgeev_ffi); dict["lapack_cgeev_ffi"] = EncapsulateFunction(lapack_cgeev_ffi); dict["lapack_zgeev_ffi"] = EncapsulateFunction(lapack_zgeev_ffi); + dict["lapack_sgehrd_ffi"] = EncapsulateFunction(lapack_sgehrd_ffi); + dict["lapack_dgehrd_ffi"] = EncapsulateFunction(lapack_dgehrd_ffi); + dict["lapack_cgehrd_ffi"] = EncapsulateFunction(lapack_cgehrd_ffi); + dict["lapack_zgehrd_ffi"] = EncapsulateFunction(lapack_zgehrd_ffi); return dict; } @@ -351,18 +336,6 @@ NB_MODULE(_lapack, m) { m.def("lapack_zhetrd_workspace", &Sytrd>::Workspace, nb::arg("lda"), nb::arg("n")); // FFI Kernel LAPACK Workspace Size Queries - m.def("lapack_sgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("lapack_dgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("lapack_cgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("lapack_zgeqrf_workspace_ffi", - &QrFactorization::GetWorkspaceSize, nb::arg("m"), - nb::arg("n")); m.def("lapack_sorgqr_workspace_ffi", &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("k")); @@ -375,26 +348,6 @@ NB_MODULE(_lapack, m) { 
m.def("lapack_zungqr_workspace_ffi", &OrthogonalQr::GetWorkspaceSize, nb::arg("m"), nb::arg("n"), nb::arg("k")); - m.def("gesdd_iwork_size_ffi", &svd::GetIntWorkspaceSize, nb::arg("m"), - nb::arg("n")); - m.def("sgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("dgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("gesdd_rwork_size_ffi", &svd::GetRealWorkspaceSize, nb::arg("m"), - nb::arg("n"), nb::arg("mode")); - m.def("cgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("zgesdd_work_size_ffi", &svd::SVDType::GetWorkspaceSize, - nb::arg("m"), nb::arg("n"), nb::arg("mode")); - m.def("syevd_work_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("syevd_iwork_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("heevd_work_size_ffi", BoundWithEigvecs, - nb::arg("n")); - m.def("heevd_rwork_size_ffi", BoundWithEigvecs, - nb::arg("n")); } } // namespace diff --git a/jaxlib/cpu/lapack_kernels.cc b/jaxlib/cpu/lapack_kernels.cc index c3a32c481a8b..7d58395228d1 100644 --- a/jaxlib/cpu/lapack_kernels.cc +++ b/jaxlib/cpu/lapack_kernels.cc @@ -29,6 +29,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "absl/base/dynamic_annotations.h" +#include "absl/status/statusor.h" #include "absl/types/span.h" #include "jaxlib/ffi_helpers.h" #include "xla/ffi/api/c_api.h" @@ -79,9 +80,8 @@ inline T CastNoOverflow(int64_t value, const std::string& source = __FILE__) { template void CopyIfDiffBuffer(ffi::Buffer x, ffi::ResultBuffer x_out) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); if (x.typed_data() != x_out->typed_data()) { - const auto x_size = batch_count * x_rows * x_cols; + const auto x_size = x.element_count(); std::copy_n(x.typed_data(), x_size, x_out->typed_data()); } } @@ -149,8 +149,8 @@ ffi::Error TriMatrixEquationSolver::Kernel( MatrixParams::UpLo uplo, MatrixParams::Transpose trans_x, MatrixParams::Diag diag) { CopyIfDiffBuffer(y, y_out); - - auto [batch_count, y_rows, y_cols] = SplitBatch2D(y.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, y_rows, y_cols]), + SplitBatch2D(y.dimensions())); auto* y_out_data = y_out->typed_data(); lapack_int x_leading_dim_v = side == MatrixParams::Side::kLeft ? 
y_rows : y_cols; @@ -225,8 +225,8 @@ ffi::Error LuDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer ipiv, ffi::ResultBuffer info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions())); - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* ipiv_data = ipiv->typed_data(); auto* info_data = info->typed_data(); @@ -306,19 +306,20 @@ template struct Geqrf>; // FFI Kernel template -ffi::Error QrFactorization::Kernel( - ffi::Buffer x, ffi::ResultBuffer x_out, - ffi::ResultBuffer tau, ffi::ResultBuffer info, - ffi::ResultBuffer work) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); +ffi::Error QrFactorization::Kernel(ffi::Buffer x, + ffi::ResultBuffer x_out, + ffi::ResultBuffer tau) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* tau_data = tau->typed_data(); - auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); + lapack_int info; + const int64_t work_size = GetWorkspaceSize(x_rows, x_cols); + auto work_data = AllocateScratchMemory(work_size); CopyIfDiffBuffer(x, x_out); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, + MaybeCastNoOverflow(work_size)); FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); auto x_leading_dim_v = x_rows_v; @@ -326,11 +327,10 @@ ffi::Error QrFactorization::Kernel( const int64_t x_out_step{x_rows * x_cols}; const int64_t tau_step{std::min(x_rows, x_cols)}; for (int64_t i = 0; i < batch_count; ++i) { - fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, work_data, - &workspace_dim_v, info_data); + fn(&x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, tau_data, + work_data.get(), &workspace_dim_v, &info); x_out_data += x_out_step; tau_data += tau_step; - ++info_data; } return ffi::Error::Success(); } @@ -408,33 +408,34 @@ template struct Orgqr>; template ffi::Error OrthogonalQr::Kernel(ffi::Buffer x, ffi::Buffer tau, - ffi::ResultBuffer x_out, - ffi::ResultBuffer info, - ffi::ResultBuffer work) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + ffi::ResultBuffer x_out) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* tau_data = tau.typed_data(); auto* x_out_data = x_out->typed_data(); - auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); + lapack_int info; CopyIfDiffBuffer(x, x_out); - FFI_ASSIGN_OR_RETURN(auto tau_size_v, MaybeCastNoOverflow( - tau.dimensions().back())); + // Prepare LAPACK workspaces. 
+ int64_t work_size = GetWorkspaceSize(x_rows, x_cols, tau.dimensions().back()); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + auto work_data = AllocateScratchMemory(work_size); + FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto tau_size_v, MaybeCastNoOverflow( + tau.dimensions().back())); auto x_leading_dim_v = x_rows_v; const int64_t x_out_step{x_rows * x_cols}; const int64_t tau_step{tau_size_v}; for (int64_t i = 0; i < batch_count; ++i) { fn(&x_rows_v, &x_cols_v, &tau_size_v, x_out_data, &x_leading_dim_v, - tau_data, work_data, &workspace_dim_v, info_data); + tau_data, work_data.get(), &work_size_v, &info); x_out_data += x_out_step; tau_data += tau_step; - ++info_data; } return ffi::Error::Success(); } @@ -499,8 +500,8 @@ template ffi::Error CholeskyFactorization::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(x.dimensions())); - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* info_data = info->typed_data(); @@ -686,40 +687,48 @@ template struct ComplexGesdd>; namespace internal { -template -using RealBufferForComplexOrNull = - std::conditional_t(), - ffi::ResultBuffer, std::nullptr_t>; - template static ffi::Error SvdKernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer singular_values, ffi::ResultBuffer u, ffi::ResultBuffer vt, - ffi::ResultBuffer info, - ffi::ResultBuffer iwork, ffi::ResultBuffer work, - svd::ComputationMode mode, RealBufferForComplexOrNull rwork) { + ffi::ResultBuffer info, svd::ComputationMode mode) { if (mode == svd::ComputationMode::kComputeVtOverwriteXPartialU) [[unlikely]] { return ffi::Error( XLA_FFI_Error_Code_UNIMPLEMENTED, "Current implementation does not support this computation mode"); } - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* singular_values_data = singular_values->typed_data(); auto* u_data = u->typed_data(); auto* vt_data = vt->typed_data(); auto* info_data = info->typed_data(); - auto* iwork_data = iwork->typed_data(); - auto* work_data = work->typed_data(); + + // Prepare LAPACK workspaces. 
+ FFI_ASSIGN_OR_RETURN( + const auto work_size, + svd::SVDType::GetWorkspaceSize(x_rows, x_cols, mode)); + FFI_ASSIGN_OR_RETURN(const auto iwork_size, + svd::GetIntWorkspaceSize(x_rows, x_cols)); + auto work_data = AllocateScratchMemory(work_size); + auto iwork_data = AllocateScratchMemory(iwork_size); + using RealType = typename svd::SVDType::RealType; + std::unique_ptr rwork; + if constexpr (ffi::IsComplexType()) { + FFI_ASSIGN_OR_RETURN(const auto rwork_size, + svd::GetRealWorkspaceSize(x_rows, x_cols, mode)); + rwork = AllocateScratchMemory(rwork_size); + } CopyIfDiffBuffer(x, x_out); FFI_ASSIGN_OR_RETURN(auto x_rows_v, MaybeCastNoOverflow(x_rows)); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); auto mode_v = static_cast(mode); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); + FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, + MaybeCastNoOverflow(work_size)); auto x_leading_dim_v = x_rows_v; auto u_leading_dim_v = x_rows_v; @@ -738,14 +747,14 @@ static ffi::Error SvdKernel( svd::SVDType::fn(&mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, singular_values_data, u_data, &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data, &workspace_dim_v, rwork->typed_data(), - iwork_data, info_data); + work_data.get(), &workspace_dim_v, rwork.get(), + iwork_data.get(), info_data); } else { svd::SVDType::fn(&mode_v, &x_rows_v, &x_cols_v, x_out_data, &x_leading_dim_v, singular_values_data, u_data, &u_leading_dim_v, vt_data, &vt_leading_dim_v, - work_data, &workspace_dim_v, iwork_data, - info_data); + work_data.get(), &workspace_dim_v, + iwork_data.get(), info_data); } x_out_data += x_out_step; singular_values_data += singular_values_step; @@ -767,7 +776,6 @@ static int64_t SvdGetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, auto x_leading_dim_v = x_rows; auto u_leading_dim_v = x_rows; auto vt_leading_dim_v = mode == svd::ComputationMode::kComputeFullUVt - ? 
x_cols : std::min(x_rows, x_cols); if constexpr (ffi::IsComplexType()) { @@ -791,10 +799,9 @@ ffi::Error SingularValueDecomposition::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer singular_values, ffi::ResultBuffer u, ffi::ResultBuffer vt, ffi::ResultBuffer info, - ffi::ResultBuffer iwork, ffi::ResultBuffer work, svd::ComputationMode mode) { return internal::SvdKernel(x, x_out, singular_values, u, vt, info, - iwork, work, mode, nullptr); + mode); } template @@ -802,39 +809,38 @@ ffi::Error SingularValueDecompositionComplex::Kernel( ffi::Buffer x, ffi::ResultBuffer x_out, ffi::ResultBuffer singular_values, ffi::ResultBuffer u, ffi::ResultBuffer vt, - ffi::ResultBuffer info, - ffi::ResultBuffer rwork, - ffi::ResultBuffer iwork, ffi::ResultBuffer work, - svd::ComputationMode mode) { + ffi::ResultBuffer info, svd::ComputationMode mode) { return internal::SvdKernel(x, x_out, singular_values, u, vt, info, - iwork, work, mode, rwork); + mode); } template -int64_t SingularValueDecomposition::GetWorkspaceSize( +absl::StatusOr SingularValueDecomposition::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdGetWorkspaceSize(x_rows, x_cols, mode); } template -int64_t SingularValueDecompositionComplex::GetWorkspaceSize( +absl::StatusOr +SingularValueDecompositionComplex::GetWorkspaceSize( lapack_int x_rows, lapack_int x_cols, svd::ComputationMode mode) { return internal::SvdGetWorkspaceSize(x_rows, x_cols, mode); } -lapack_int svd::GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, - svd::ComputationMode mode) { +absl::StatusOr svd::GetRealWorkspaceSize( + int64_t x_rows, int64_t x_cols, svd::ComputationMode mode) { const auto min_dim = std::min(x_rows, x_cols); if (!ComputesUV(mode)) { - return CastNoOverflow(7 * min_dim); + return MaybeCastNoOverflow(7 * min_dim); } const auto max_dim = std::max(x_rows, x_cols); - return CastNoOverflow( + return MaybeCastNoOverflow( std::max(5 * min_dim * min_dim + 5 * min_dim, 2 * max_dim * min_dim + 2 * min_dim * min_dim + min_dim)); } -lapack_int svd::GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols) { +absl::StatusOr svd::GetIntWorkspaceSize(int64_t x_rows, + int64_t x_cols) { return CastNoOverflow(8 * std::min(x_rows, x_cols)); } @@ -948,21 +954,24 @@ template struct ComplexHeevd>; // FFI Kernel -lapack_int eig::GetWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr eig::GetWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(2 * x_cols + 1); + return MaybeCastNoOverflow(2 * x_cols + 1); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(1 + 6 * x_cols + 2 * x_cols * x_cols); + return MaybeCastNoOverflow(1 + 6 * x_cols + + 2 * x_cols * x_cols); } } -lapack_int eig::GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr eig::GetIntWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: return 1; case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(3 + 5 * x_cols); + return MaybeCastNoOverflow(3 + 5 * x_cols); } } @@ -970,33 +979,39 @@ template ffi::Error EigenvalueDecompositionSymmetric::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, - ffi::ResultBuffer info, ffi::ResultBuffer work, - ffi::ResultBuffer iwork, eig::ComputationMode mode) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + ffi::ResultBuffer 
info, eig::ComputationMode mode) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); - auto* iwork_data = iwork->typed_data(); CopyIfDiffBuffer(x, x_out); auto mode_v = static_cast(mode); auto uplo_v = static_cast(uplo); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( - iwork->dimensions().back())); FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, MaybeCastNoOverflow(x_cols)); + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN(lapack_int work_size_v, + eig::GetWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v, + eig::GetIntWorkspaceSize(x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size_v); + auto iwork_data = AllocateScratchMemory(iwork_size_v); const int64_t x_out_step{x_cols * x_cols}; const int64_t eigenvalues_step{x_cols}; for (int64_t i = 0; i < batch_count; ++i) { fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, - eigenvalues_data, work_data, &workspace_dim_v, iwork_data, - &iworkspace_dim_v, info_data); + eigenvalues_data, work_data.get(), &work_size_v, iwork_data.get(), + &iwork_size_v, info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_out_data, + sizeof(*x_out_data) * x_cols * x_cols); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigenvalues_data, + sizeof(*eigenvalues_data) * x_cols); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; @@ -1006,21 +1021,24 @@ ffi::Error EigenvalueDecompositionSymmetric::Kernel( namespace eig { -lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr GetComplexWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(x_cols + 1); + return MaybeCastNoOverflow(x_cols + 1); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(2 * x_cols + x_cols * x_cols); + return MaybeCastNoOverflow(2 * x_cols + x_cols * x_cols); } } -lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode) { +absl::StatusOr GetRealWorkspaceSize(int64_t x_cols, + ComputationMode mode) { switch (mode) { case ComputationMode::kNoEigenvectors: - return CastNoOverflow(std::max(x_cols, int64_t{1})); + return MaybeCastNoOverflow(std::max(x_cols, int64_t{1})); case ComputationMode::kComputeEigenvectors: - return CastNoOverflow(1 + 5 * x_cols + 2 * x_cols * x_cols); + return MaybeCastNoOverflow(1 + 5 * x_cols + + 2 * x_cols * x_cols); } } @@ -1031,36 +1049,37 @@ ffi::Error EigenvalueDecompositionHermitian::Kernel( ffi::Buffer x, MatrixParams::UpLo uplo, ffi::ResultBuffer x_out, ffi::ResultBuffer eigenvalues, - ffi::ResultBuffer info, ffi::ResultBuffer work, - ffi::ResultBuffer rwork, - ffi::ResultBuffer iwork, eig::ComputationMode mode) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + ffi::ResultBuffer info, eig::ComputationMode mode) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); auto* x_out_data = x_out->typed_data(); auto* eigenvalues_data = eigenvalues->typed_data(); auto* info_data = info->typed_data(); - auto* work_data = work->typed_data(); - 
auto* iwork_data = iwork->typed_data(); CopyIfDiffBuffer(x, x_out); auto mode_v = static_cast(mode); auto uplo_v = static_cast(uplo); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - FFI_ASSIGN_OR_RETURN(auto workspace_dim_v, MaybeCastNoOverflow( - work->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto rworkspace_dim_v, MaybeCastNoOverflow( - rwork->dimensions().back())); - FFI_ASSIGN_OR_RETURN(auto iworkspace_dim_v, MaybeCastNoOverflow( - iwork->dimensions().back())); FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, MaybeCastNoOverflow(x_cols)); + // Prepare LAPACK workspaces. + FFI_ASSIGN_OR_RETURN(lapack_int work_size_v, + eig::GetComplexWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int rwork_size_v, + eig::GetRealWorkspaceSize(x_cols, mode)); + FFI_ASSIGN_OR_RETURN(lapack_int iwork_size_v, + eig::GetIntWorkspaceSize(x_cols, mode)); + auto work_data = AllocateScratchMemory(work_size_v); + auto iwork_data = AllocateScratchMemory(iwork_size_v); + auto rwork_data = AllocateScratchMemory(rwork_size_v); const int64_t x_out_step{x_cols * x_cols}; const int64_t eigenvalues_step{x_cols}; for (int64_t i = 0; i < batch_count; ++i) { fn(&mode_v, &uplo_v, &x_cols_v, x_out_data, &x_leading_dim_v, - eigenvalues_data, work_data, &workspace_dim_v, rwork->typed_data(), - &rworkspace_dim_v, iwork_data, &iworkspace_dim_v, info_data); + eigenvalues_data, work_data.get(), &work_size_v, rwork_data.get(), + &rwork_size_v, iwork_data.get(), &iwork_size_v, info_data); x_out_data += x_out_step; eigenvalues_data += eigenvalues_step; ++info_data; @@ -1257,15 +1276,11 @@ ffi::Error EigenvalueDecomposition::Kernel( ffi::ResultBuffer eigvals_imag, ffi::ResultBuffer eigvecs_left, ffi::ResultBuffer eigvecs_right, - ffi::ResultBuffer info, ffi::ResultBuffer x_work, - ffi::ResultBuffer work_eigvecs_left, - ffi::ResultBuffer work_eigvecs_right) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); - auto* x_work_data = x_work->typed_data(); - auto* work_eigvecs_left_data = work_eigvecs_left->typed_data(); - auto* work_eigvecs_right_data = work_eigvecs_right->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); auto* eigvecs_right_data = eigvecs_right->typed_data(); auto* eigvals_real_data = eigvals_real->typed_data(); @@ -1275,43 +1290,45 @@ ffi::Error EigenvalueDecomposition::Kernel( auto compute_left_v = static_cast(compute_left); auto compute_right_v = static_cast(compute_right); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - + // Prepare LAPACK workspaces. int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); FFI_ASSIGN_OR_RETURN(auto work_size_v, MaybeCastNoOverflow(work_size)); - // TODO(phawkins): preallocate workspace using XLA. 
- auto work = std::make_unique(work_size); - auto* work_data = work.get(); + auto work_data = AllocateScratchMemory(work_size); + const int64_t x_size{x_cols * x_cols}; + auto x_copy = AllocateScratchMemory(x_size); + auto work_eigvecs_left = AllocateScratchMemory(x_size); + auto work_eigvecs_right = AllocateScratchMemory(x_size); const auto is_finite = [](ValueType* data, int64_t size) { return absl::c_all_of(absl::MakeSpan(data, size), [](ValueType value) { return std::isfinite(value); }); }; - const int64_t x_size{x_cols * x_cols}; [[maybe_unused]] const auto x_size_bytes = static_cast(x_size) * sizeof(ValueType); [[maybe_unused]] const auto x_cols_bytes = static_cast(x_cols) * sizeof(ValueType); for (int64_t i = 0; i < batch_count; ++i) { - std::copy_n(x_data, x_size, x_work_data); - if (is_finite(x_work_data, x_size)) { - fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, - eigvals_real_data, eigvals_imag_data, work_eigvecs_left_data, - &x_cols_v, work_eigvecs_right_data, &x_cols_v, work_data, &work_size_v, - info_data); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + std::copy_n(x_data, x_size, x_copy.get()); + if (is_finite(x_copy.get(), x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, + eigvals_real_data, eigvals_imag_data, work_eigvecs_left.get(), + &x_cols_v, work_eigvecs_right.get(), &x_cols_v, work_data.get(), + &work_size_v, info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_real_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_imag_data, x_cols_bytes); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left_data, x_size_bytes); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right_data, + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_left.get(), + x_size_bytes); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(work_eigvecs_right.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(info_data, sizeof(lapack_int)); if (info_data[0] == 0) { - UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left_data, + UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_left.get(), eigvecs_left_data); - UnpackEigenvectors(x_cols_v, eigvals_imag_data, work_eigvecs_right_data, - eigvecs_right_data); + UnpackEigenvectors(x_cols_v, eigvals_imag_data, + work_eigvecs_right.get(), eigvecs_right_data); } } else { info_data[0] = -4; @@ -1332,11 +1349,10 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( eig::ComputationMode compute_right, ffi::ResultBuffer eigvals, ffi::ResultBuffer eigvecs_left, ffi::ResultBuffer eigvecs_right, - ffi::ResultBuffer info, ffi::ResultBuffer x_work, - ffi::ResultBuffer rwork) { - auto [batch_count, x_rows, x_cols] = SplitBatch2D(x.dimensions()); + ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); const auto* x_data = x.typed_data(); - auto* x_work_data = x_work->typed_data(); auto* eigvecs_left_data = eigvecs_left->typed_data(); auto* eigvecs_right_data = eigvecs_right->typed_data(); auto* eigvals_data = eigvals->typed_data(); @@ -1345,13 +1361,14 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( auto compute_left_v = static_cast(compute_left); auto compute_right_v = static_cast(compute_right); FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); - + // Prepare LAPACK workspaces. 
int64_t work_size = GetWorkspaceSize(x_cols_v, compute_left, compute_right); FFI_ASSIGN_OR_RETURN(auto work_size_v, MaybeCastNoOverflow(work_size)); - // TODO(phawkins): preallocate workspace using XLA. - auto work = std::make_unique(work_size); - auto* work_data = work.get(); + auto work_data = AllocateScratchMemory(work_size); + const int64_t x_size{x_cols * x_cols}; + auto x_copy = AllocateScratchMemory(x_size); + auto rwork_data = AllocateScratchMemory(2 * x_cols); const auto is_finite = [](ValueType* data, int64_t size) { return absl::c_all_of(absl::MakeSpan(data, size), [](const auto& z) { @@ -1359,18 +1376,17 @@ ffi::Error EigenvalueDecompositionComplex::Kernel( }); }; - const int64_t x_size{x_cols * x_cols}; [[maybe_unused]] const auto x_size_bytes = static_cast(x_size) * sizeof(ValueType); [[maybe_unused]] const auto x_cols_bytes = static_cast(x_cols) * sizeof(ValueType); for (int64_t i = 0; i < batch_count; ++i) { - std::copy_n(x_data, x_size, x_work_data); - if (is_finite(x_work_data, x_size)) { - fn(&compute_left_v, &compute_right_v, &x_cols_v, x_work_data, &x_cols_v, + std::copy_n(x_data, x_size, x_copy.get()); + if (is_finite(x_copy.get(), x_size)) { + fn(&compute_left_v, &compute_right_v, &x_cols_v, x_copy.get(), &x_cols_v, eigvals_data, eigvecs_left_data, &x_cols_v, eigvecs_right_data, - &x_cols_v, work_data, &work_size_v, rwork->typed_data(), info_data); - ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_work_data, x_size_bytes); + &x_cols_v, work_data.get(), &work_size_v, rwork_data.get(), info_data); + ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(x_copy.get(), x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvals_data, x_cols_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_left_data, x_size_bytes); ABSL_ANNOTATE_MEMORY_IS_INITIALIZED(eigvecs_right_data, x_size_bytes); @@ -1611,6 +1627,59 @@ template struct Gehrd; template struct Gehrd>; template struct Gehrd>; +// FFI Kernel + +template +ffi::Error HessenbergDecomposition::Kernel( + ffi::Buffer x, lapack_int low, lapack_int high, + ffi::ResultBuffer x_out, ffi::ResultBuffer tau, + ffi::ResultBuffer info) { + FFI_ASSIGN_OR_RETURN((auto [batch_count, x_rows, x_cols]), + SplitBatch2D(x.dimensions())); + + CopyIfDiffBuffer(x, x_out); + + ValueType* x_out_data = x_out->typed_data(); + ValueType* tau_data = tau->typed_data(); + lapack_int* info_data = info->typed_data(); + FFI_ASSIGN_OR_RETURN(auto x_cols_v, MaybeCastNoOverflow(x_cols)); + FFI_ASSIGN_OR_RETURN(auto x_leading_dim_v, + MaybeCastNoOverflow(x_rows)); + // Prepare LAPACK workspaces. + int64_t work_size = GetWorkspaceSize(x_rows, x_cols, low, high); + FFI_ASSIGN_OR_RETURN(auto work_size_v, + MaybeCastNoOverflow(work_size)); + auto work_data = AllocateScratchMemory(work_size); + + int64_t x_size{x_rows * x_cols}; + for (int64_t i = 0; i < batch_count; ++i) { + fn(&x_cols_v, &low, &high, x_out_data, &x_leading_dim_v, tau_data, + work_data.get(), &work_size_v, info_data); + x_out_data += x_size; + tau_data += x_cols - 1; + ++info_data; + } + return ffi::Error::Success(); +} + +template +int64_t HessenbergDecomposition::GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + lapack_int low, + lapack_int high) { + ValueType optimal_size = {}; + lapack_int workspace_query = -1; + lapack_int info = 0; + fn(&x_cols, &low, &high, nullptr, &x_rows, nullptr, &optimal_size, + &workspace_query, &info); + return info == 0 ? 
static_cast(std::real(optimal_size)) : -1; +} + +template struct HessenbergDecomposition; +template struct HessenbergDecomposition; +template struct HessenbergDecomposition; +template struct HessenbergDecomposition; + //== Tridiagonal Reduction ==// // lapack sytrd/hetrd @@ -1696,25 +1765,21 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*ipiv*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) -#define JAX_CPU_DEFINE_GEQRF(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, QrFactorization::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/)) +#define JAX_CPU_DEFINE_GEQRF(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, QrFactorization::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/)) -#define JAX_CPU_DEFINE_ORGQR(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, OrthogonalQr::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Arg<::xla::ffi::Buffer>(/*tau*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/)) +#define JAX_CPU_DEFINE_ORGQR(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, OrthogonalQr::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Arg<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/)) #define JAX_CPU_DEFINE_POTRF(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ @@ -1725,33 +1790,28 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*info*/)) -#define JAX_CPU_DEFINE_GESDD(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, SingularValueDecomposition::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer>(/*s*/) \ - .Ret<::xla::ffi::Buffer>(/*u*/) \ - .Ret<::xla::ffi::Buffer>(/*vt*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ +#define JAX_CPU_DEFINE_GESDD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*s*/) \ + .Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) -#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, SingularValueDecompositionComplex::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ - .Ret<::xla::ffi::Buffer>(/*u*/) \ - .Ret<::xla::ffi::Buffer>(/*vt*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ +#define JAX_CPU_DEFINE_GESDD_COMPLEX(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, SingularValueDecompositionComplex::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*s*/) \ + 
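HessenbergDecomposition::GetWorkspaceSize above uses the standard LAPACK workspace query: calling the routine with lwork == -1 makes it write the optimal workspace size into work[0] and return without factoring anything. A standalone sketch of the same query against a Fortran LAPACK symbol (the dgehrd_ prototype and the trailing-underscore naming are assumptions about the local LAPACK binding):

#include <cstdint>

extern "C" void dgehrd_(int* n, int* ilo, int* ihi, double* a, int* lda,
                        double* tau, double* work, int* lwork, int* info);

// Returns the optimal lwork for dgehrd, or -1 if the query itself fails.
std::int64_t QueryGehrdWorkspace(int n, int ilo, int ihi) {
  double optimal_size = 0.0;
  int lwork = -1;  // lwork == -1 selects workspace-query mode.
  int info = 0;
  dgehrd_(&n, &ilo, &ihi, /*a=*/nullptr, /*lda=*/&n, /*tau=*/nullptr,
          &optimal_size, &lwork, &info);
  return info == 0 ? static_cast<std::int64_t>(optimal_size) : -1;
}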
.Ret<::xla::ffi::Buffer>(/*u*/) \ + .Ret<::xla::ffi::Buffer>(/*vt*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) #define JAX_CPU_DEFINE_SYEVD(name, data_type) \ @@ -1763,23 +1823,18 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*x_out*/) \ .Ret<::xla::ffi::Buffer>(/*eigenvalues*/) \ .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ .Attr("mode")) -#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ - XLA_FFI_DEFINE_HANDLER_SYMBOL( \ - name, EigenvalueDecompositionHermitian::Kernel, \ - ::xla::ffi::Ffi::Bind() \ - .Arg<::xla::ffi::Buffer>(/*x*/) \ - .Attr("uplo") \ - .Ret<::xla::ffi::Buffer>(/*x_out*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*eigenvalues*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/) \ - .Ret<::xla::ffi::Buffer>(/*iwork*/) \ +#define JAX_CPU_DEFINE_HEEVD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, EigenvalueDecompositionHermitian::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("uplo") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ + /*eigenvalues*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/) \ .Attr("mode")) #define JAX_CPU_DEFINE_GEEV(name, data_type) \ @@ -1795,12 +1850,7 @@ template struct Sytrd>; /*eigvecs_left*/) \ .Ret<::xla::ffi::Buffer<::xla::ffi::ToComplex(data_type)>>( \ /*eigvecs_right*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*x_work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*work_eigvecs_left*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>( \ - /*work_eigvecs_right*/)) + .Ret<::xla::ffi::Buffer>(/*info*/)) #define JAX_CPU_DEFINE_GEEV_COMPLEX(name, data_type) \ XLA_FFI_DEFINE_HANDLER_SYMBOL( \ @@ -1812,9 +1862,18 @@ template struct Sytrd>; .Ret<::xla::ffi::Buffer>(/*eigvals*/) \ .Ret<::xla::ffi::Buffer>(/*eigvecs_left*/) \ .Ret<::xla::ffi::Buffer>(/*eigvecs_right*/) \ - .Ret<::xla::ffi::Buffer>(/*info*/) \ - .Ret<::xla::ffi::Buffer>(/*x_work*/) \ - .Ret<::xla::ffi::Buffer<::xla::ffi::ToReal(data_type)>>(/*rwork*/)) + .Ret<::xla::ffi::Buffer>(/*info*/)) + +#define JAX_CPU_DEFINE_GEHRD(name, data_type) \ + XLA_FFI_DEFINE_HANDLER_SYMBOL( \ + name, HessenbergDecomposition::Kernel, \ + ::xla::ffi::Ffi::Bind() \ + .Arg<::xla::ffi::Buffer>(/*x*/) \ + .Attr("low") \ + .Attr("high") \ + .Ret<::xla::ffi::Buffer>(/*x_out*/) \ + .Ret<::xla::ffi::Buffer>(/*tau*/) \ + .Ret<::xla::ffi::Buffer>(/*info*/)) // FFI Handlers @@ -1858,6 +1917,11 @@ JAX_CPU_DEFINE_GEEV(lapack_dgeev_ffi, ::xla::ffi::DataType::F64); JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_cgeev_ffi, ::xla::ffi::DataType::C64); JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128); +JAX_CPU_DEFINE_GEHRD(lapack_sgehrd_ffi, ::xla::ffi::DataType::F32); +JAX_CPU_DEFINE_GEHRD(lapack_dgehrd_ffi, ::xla::ffi::DataType::F64); +JAX_CPU_DEFINE_GEHRD(lapack_cgehrd_ffi, ::xla::ffi::DataType::C64); +JAX_CPU_DEFINE_GEHRD(lapack_zgehrd_ffi, ::xla::ffi::DataType::C128); + #undef JAX_CPU_DEFINE_TRSM #undef JAX_CPU_DEFINE_GETRF #undef JAX_CPU_DEFINE_GEQRF @@ -1869,5 +1933,6 @@ JAX_CPU_DEFINE_GEEV_COMPLEX(lapack_zgeev_ffi, ::xla::ffi::DataType::C128); #undef JAX_CPU_DEFINE_HEEVD #undef JAX_CPU_DEFINE_GEEV #undef JAX_CPU_DEFINE_GEEV_COMPLEX +#undef JAX_CPU_DEFINE_GEHRD } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels.h 
b/jaxlib/cpu/lapack_kernels.h index 5493ec8cbffc..b4f54b923478 100644 --- a/jaxlib/cpu/lapack_kernels.h +++ b/jaxlib/cpu/lapack_kernels.h @@ -20,8 +20,9 @@ limitations under the License. #include #include -#include "xla/ffi/api/ffi.h" +#include "absl/status/statusor.h" #include "xla/ffi/api/c_api.h" +#include "xla/ffi/api/ffi.h" #include "xla/service/custom_call_status.h" // Underlying function pointers (i.e., KERNEL_CLASS::Fn) are initialized either @@ -193,9 +194,7 @@ struct QrFactorization { static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer tau, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work); + ::xla::ffi::ResultBuffer tau); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols); }; @@ -227,9 +226,7 @@ struct OrthogonalQr { static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, ::xla::ffi::Buffer tau, - ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work); + ::xla::ffi::ResultBuffer x_out); static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, lapack_int tau_size); @@ -303,6 +300,7 @@ struct SingularValueDecomposition { static_assert(!::xla::ffi::IsComplexType(), "There exists a separate implementation for Complex types"); using ValueType = ::xla::ffi::NativeType; + using RealType = ValueType; using FnType = void(char* jobz, lapack_int* m, lapack_int* n, ValueType* a, lapack_int* lda, ValueType* s, ValueType* u, lapack_int* ldu, ValueType* vt, lapack_int* ldvt, @@ -315,12 +313,11 @@ struct SingularValueDecomposition { ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer singular_values, ::xla::ffi::ResultBuffer u, ::xla::ffi::ResultBuffer vt, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer iwork, - ::xla::ffi::ResultBuffer work, svd::ComputationMode mode); + ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); - static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, - svd::ComputationMode mode); + static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -341,13 +338,11 @@ struct SingularValueDecompositionComplex { ::xla::ffi::Buffer x, ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> singular_values, ::xla::ffi::ResultBuffer u, ::xla::ffi::ResultBuffer vt, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork, - ::xla::ffi::ResultBuffer iwork, - ::xla::ffi::ResultBuffer work, svd::ComputationMode mode); + ::xla::ffi::ResultBuffer info, svd::ComputationMode mode); - static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, - svd::ComputationMode mode); + static absl::StatusOr GetWorkspaceSize(lapack_int x_rows, + lapack_int x_cols, + svd::ComputationMode mode); }; namespace svd { @@ -357,9 +352,9 @@ using SVDType = std::conditional_t<::xla::ffi::IsComplexType(), SingularValueDecompositionComplex, SingularValueDecomposition>; -lapack_int GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); -lapack_int GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, - ComputationMode mode); +absl::StatusOr GetIntWorkspaceSize(int64_t x_rows, int64_t x_cols); +absl::StatusOr GetRealWorkspaceSize(int64_t x_rows, int64_t x_cols, + ComputationMode mode); } // namespace svd @@ -398,12 +393,16 @@ struct ComplexHeevd { namespace eig { // Eigenvalue Decomposition -lapack_int GetWorkspaceSize(int64_t x_cols, 
ComputationMode mode); -lapack_int GetIntWorkspaceSize(int64_t x_cols, ComputationMode mode); +absl::StatusOr GetWorkspaceSize(int64_t x_cols, + ComputationMode mode); +absl::StatusOr GetIntWorkspaceSize(int64_t x_cols, + ComputationMode mode); // Hermitian Eigenvalue Decomposition -lapack_int GetComplexWorkspaceSize(int64_t x_cols, ComputationMode mode); -lapack_int GetRealWorkspaceSize(int64_t x_cols, ComputationMode mode); +absl::StatusOr GetComplexWorkspaceSize(int64_t x_cols, + ComputationMode mode); +absl::StatusOr GetRealWorkspaceSize(int64_t x_cols, + ComputationMode mode); } // namespace eig @@ -420,14 +419,12 @@ struct EigenvalueDecompositionSymmetric { inline static FnType* fn = nullptr; - static ::xla::ffi::Error Kernel( - ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, - ::xla::ffi::ResultBuffer x_out, - ::xla::ffi::ResultBuffer eigenvalues, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work, - ::xla::ffi::ResultBuffer iwork, - eig::ComputationMode mode); + static ::xla::ffi::Error Kernel(::xla::ffi::Buffer x, + MatrixParams::UpLo uplo, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer eigenvalues, + ::xla::ffi::ResultBuffer info, + eig::ComputationMode mode); }; template <::xla::ffi::DataType dtype> @@ -447,11 +444,7 @@ struct EigenvalueDecompositionHermitian { ::xla::ffi::Buffer x, MatrixParams::UpLo uplo, ::xla::ffi::ResultBuffer x_out, ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> eigenvalues, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork, - ::xla::ffi::ResultBuffer iwork, - eig::ComputationMode mode); + ::xla::ffi::ResultBuffer info, eig::ComputationMode mode); }; // lapack geev @@ -499,10 +492,7 @@ struct EigenvalueDecomposition { ::xla::ffi::ResultBuffer eigvals_imag, ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_left, ::xla::ffi::ResultBuffer<::xla::ffi::ToComplex(dtype)> eigvecs_right, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer x_work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_left, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> work_eigvecs_right); + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_cols, eig::ComputationMode compute_left, @@ -529,9 +519,7 @@ struct EigenvalueDecompositionComplex { ::xla::ffi::ResultBuffer eigvals, ::xla::ffi::ResultBuffer eigvecs_left, ::xla::ffi::ResultBuffer eigvecs_right, - ::xla::ffi::ResultBuffer info, - ::xla::ffi::ResultBuffer x_work, - ::xla::ffi::ResultBuffer<::xla::ffi::ToReal(dtype)> rwork); + ::xla::ffi::ResultBuffer info); static int64_t GetWorkspaceSize(lapack_int x_cols, eig::ComputationMode compute_left, @@ -590,6 +578,27 @@ struct real_type> { typedef T type; }; +// FFI Kernel + +template <::xla::ffi::DataType dtype> +struct HessenbergDecomposition { + using ValueType = ::xla::ffi::NativeType; + using FnType = void(lapack_int* n, lapack_int* ilo, lapack_int* ihi, + ValueType* a, lapack_int* lda, ValueType* tau, + ValueType* work, lapack_int* lwork, lapack_int* info); + + inline static FnType* fn = nullptr; + + static ::xla::ffi::Error Kernel( + ::xla::ffi::Buffer x, lapack_int low, lapack_int high, + ::xla::ffi::ResultBuffer x_out, + ::xla::ffi::ResultBuffer tau, + ::xla::ffi::ResultBuffer info); + + static int64_t GetWorkspaceSize(lapack_int x_rows, lapack_int x_cols, + lapack_int low, lapack_int high); +}; + //== Tridiagonal Reduction ==// //== Reduces a Symmetric/Hermitian square matrix to tridiagonal form ==// @@ 
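The eig:: and svd:: workspace-size helpers above now return absl::StatusOr<lapack_int>, so a size that does not fit in lapack_int surfaces as an error instead of overflowing silently. A minimal sketch of that shape using the documented dsyevd minimum-workspace formulas (the function name is illustrative, and the real helpers may compute the size differently, e.g. via a runtime query):

#include <cstdint>
#include <limits>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"

using lapack_int = int;  // Assumes 32-bit LAPACK integers.

absl::StatusOr<lapack_int> SyevdRealWorkspaceSize(std::int64_t n,
                                                  bool compute_eigenvectors) {
  // Documented dsyevd minimums for n > 1: jobz == 'V' needs 1 + 6n + 2n^2,
  // jobz == 'N' needs 2n + 1.
  const std::int64_t size =
      compute_eigenvectors ? 1 + 6 * n + 2 * n * n : 2 * n + 1;
  if (size > std::numeric_limits<lapack_int>::max()) {
    return absl::InvalidArgumentError(absl::StrFormat(
        "syevd workspace size %d does not fit in lapack_int", size));
  }
  return static_cast<lapack_int>(size);
}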
-641,6 +650,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgeev_ffi); XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgeev_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_sgehrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_dgehrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_cgehrd_ffi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(lapack_zgehrd_ffi); } // namespace jax diff --git a/jaxlib/cpu/lapack_kernels_using_lapack.cc b/jaxlib/cpu/lapack_kernels_using_lapack.cc index 2a2597629b93..9f13bb99d582 100644 --- a/jaxlib/cpu/lapack_kernels_using_lapack.cc +++ b/jaxlib/cpu/lapack_kernels_using_lapack.cc @@ -71,10 +71,10 @@ jax::RealGees::FnType dgees_; jax::ComplexGees>::FnType cgees_; jax::ComplexGees>::FnType zgees_; -jax::Gehrd::FnType sgehrd_; -jax::Gehrd::FnType dgehrd_; -jax::Gehrd>::FnType cgehrd_; -jax::Gehrd>::FnType zgehrd_; +jax::HessenbergDecomposition::FnType sgehrd_; +jax::HessenbergDecomposition::FnType dgehrd_; +jax::HessenbergDecomposition::FnType cgehrd_; +jax::HessenbergDecomposition::FnType zgehrd_; jax::Sytrd::FnType ssytrd_; jax::Sytrd::FnType dsytrd_; @@ -211,6 +211,22 @@ static_assert( jax::EigenvalueDecompositionComplex::FnType, jax::ComplexGeev>::FnType>, JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); +static_assert( + std::is_same_v::FnType, + jax::Gehrd>::FnType>, + JAX_KERNEL_FNTYPE_MISMATCH_MSG); #undef JAX_KERNEL_FNTYPE_MISMATCH_MSG @@ -315,6 +331,11 @@ static auto init = []() -> int { AssignKernelFn>(cgeev_); AssignKernelFn>(zgeev_); + AssignKernelFn>(sgehrd_); + AssignKernelFn>(dgehrd_); + AssignKernelFn>(cgehrd_); + AssignKernelFn>(zgehrd_); + return 0; }(); diff --git a/jaxlib/cpu_feature_guard.c b/jaxlib/cpu_feature_guard.c index 7c8ff2951a79..d18478eb57d5 100644 --- a/jaxlib/cpu_feature_guard.c +++ b/jaxlib/cpu_feature_guard.c @@ -172,5 +172,12 @@ static struct PyModuleDef cpu_feature_guard_module = { #endif EXPORT_SYMBOL PyMODINIT_FUNC PyInit_cpu_feature_guard(void) { - return PyModule_Create(&cpu_feature_guard_module); + PyObject *module = PyModule_Create(&cpu_feature_guard_module); + if (module == NULL) { + return NULL; + } +#ifdef Py_GIL_DISABLED + PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED); +#endif + return module; } diff --git a/jaxlib/cuda/BUILD b/jaxlib/cuda/BUILD index e515de2d3a95..34e40d12d5be 100644 --- a/jaxlib/cuda/BUILD +++ b/jaxlib/cuda/BUILD @@ -26,7 +26,7 @@ licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) cc_library( @@ -37,9 +37,9 @@ cc_library( defines = ["JAX_GPU_CUDA=1"], visibility = ["//visibility:public"], deps = [ - "@xla//xla/tsl/cuda:cupti", "@local_config_cuda//cuda:cuda_headers", "@local_config_cuda//cuda:cudnn_header", + "@xla//xla/tsl/cuda:cupti", ], ) @@ -57,9 +57,6 @@ cc_library( features = ["-use_header_modules"], deps = [ ":cuda_vendor", - "@xla//xla/tsl/cuda:cupti", - "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/cuda:cusparse", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", @@ -69,6 +66,19 @@ cc_library( "@com_google_absl//absl/strings:str_format", 
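The cpu_feature_guard change above is the standard opt-in for free-threaded CPython: without the PyUnstable_Module_SetGIL call, a Py_GIL_DISABLED interpreter re-enables the GIL when the extension is imported. A minimal sketch of the same init pattern on a throwaway module (module name and contents are placeholders):

#include <Python.h>

static struct PyModuleDef example_module = {
    PyModuleDef_HEAD_INIT, "example", /*m_doc=*/nullptr, /*m_size=*/-1,
    /*m_methods=*/nullptr,
};

PyMODINIT_FUNC PyInit_example(void) {
  PyObject* module = PyModule_Create(&example_module);
  if (module == nullptr) {
    return nullptr;
  }
#ifdef Py_GIL_DISABLED
  // Declare that this module does not rely on the GIL for thread safety.
  PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
#endif
  return module;
}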
"@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cupti", + "@xla//xla/tsl/cuda:cusolver", + "@xla//xla/tsl/cuda:cusparse", + ], +) + +cuda_library( + name = "cuda_make_batch_pointers", + srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//jaxlib/gpu:make_batch_pointers.h"], + deps = [ + ":cuda_vendor", + "@local_config_cuda//cuda:cuda_headers", ], ) @@ -80,11 +90,11 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -95,11 +105,9 @@ cc_library( deps = [ ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", + ":cuda_make_batch_pointers", ":cuda_vendor", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -111,22 +119,9 @@ cc_library( "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cublas_headers", "@local_config_cuda//cuda:cuda_headers", - ], -) - -cc_library( - name = "cublas_kernels_ffi", - srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"], - deps = [ - ":cuda_blas_handle_pool", - ":cuda_gpu_kernel_helpers", - ":cuda_vendor", - "//jaxlib:ffi_helpers", - "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", - "@com_google_absl//absl/status", ], ) @@ -148,15 +143,14 @@ pybind_extension( module_name = "_blas", deps = [ ":cublas_kernels", - ":cublas_kernels_ffi", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cublas", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", "@nanobind", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -169,13 +163,13 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cudnn", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cudnn", ], ) @@ -207,11 +201,11 @@ cc_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:handle_pool", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", ], ) @@ -224,12 +218,28 @@ cc_library( ":cuda_solver_handle_pool", ":cuda_vendor", "//jaxlib:kernel_helpers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/service:custom_call_status", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", + ], +) + +cc_library( + name = "cusolver_interface", + srcs = ["//jaxlib/gpu:solver_interface.cc"], + hdrs = ["//jaxlib/gpu:solver_interface.h"], + 
deps = [ + ":cuda_gpu_kernel_helpers", + ":cuda_vendor", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", - "@local_config_cuda//cuda:cuda_headers", + "@com_google_absl//absl/strings:str_format", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", ], ) @@ -238,15 +248,20 @@ cc_library( srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"], deps = [ + ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", + ":cuda_make_batch_pointers", ":cuda_solver_handle_pool", ":cuda_vendor", + ":cusolver_interface", "//jaxlib:ffi_helpers", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", "@xla//xla/ffi/api:ffi", + "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cusolver", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", ], ) @@ -274,14 +289,15 @@ pybind_extension( ":cusolver_kernels", ":cusolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusolver", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cublas", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusolver", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -294,13 +310,14 @@ cc_library( ":cuda_vendor", "//jaxlib:handle_pool", "//jaxlib:kernel_helpers", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusparse", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusparse", ], ) @@ -324,10 +341,8 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_vendor", ":cusparse_kernels", + "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/cuda:cusparse", - "@xla//xla/tsl/python/lib/core:numpy", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", "@com_google_absl//absl/base:core_headers", @@ -339,6 +354,9 @@ pybind_extension( "@com_google_absl//absl/synchronization", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/cuda:cusparse", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -350,36 +368,34 @@ cc_library( hdrs = ["//jaxlib/gpu:linalg_kernels.h"], features = ["-use_header_modules"], deps = [ + ":cuda_blas_handle_pool", ":cuda_gpu_kernel_helpers", ":cuda_linalg_kernels_impl", ":cuda_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cublas", ], ) cuda_library( name = "cuda_linalg_kernels_impl", - srcs = [ - "//jaxlib/gpu:linalg_kernels.cu.cc", - ], - hdrs = [ - "//jaxlib/gpu:linalg_kernels.h", - ], + srcs 
= ["//jaxlib/gpu:linalg_kernels.cu.cc"], + hdrs = ["//jaxlib/gpu:linalg_kernels.h"], deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "//jaxlib:ffi_helpers", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", - "@local_config_cuda//cuda:cuda_headers", ], ) @@ -397,10 +413,10 @@ pybind_extension( ":cuda_linalg_kernels", ":cuda_vendor", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", - "@xla//xla/tsl/python/lib/core:numpy", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", + "@xla//xla/tsl/python/lib/core:numpy", ], ) @@ -416,12 +432,12 @@ cc_library( ":cuda_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@xla//xla/ffi/api:c_api", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/ffi/api:c_api", + "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -435,9 +451,9 @@ cuda_library( ":cuda_gpu_kernel_helpers", ":cuda_vendor", "//jaxlib:kernel_helpers", + "@local_config_cuda//cuda:cuda_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", - "@local_config_cuda//cuda:cuda_headers", ], ) @@ -454,9 +470,9 @@ pybind_extension( ":cuda_gpu_kernel_helpers", ":cuda_prng_kernels", "//jaxlib:kernel_nanobind_helpers", - "@xla//xla/tsl/cuda:cudart", "@local_config_cuda//cuda:cuda_headers", "@nanobind", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -466,7 +482,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":cublas_kernels", - ":cublas_kernels_ffi", ":cuda_linalg_kernels", ":cuda_prng_kernels", ":cuda_vendor", @@ -491,10 +506,6 @@ cc_library( ":cuda_vendor", ":triton_utils", "//jaxlib/gpu:triton_cc_proto", - "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/cuda:cuda_asm_compiler", - "@xla//xla/tsl/cuda:cudart", - "@tsl//tsl/platform:env", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", @@ -505,6 +516,10 @@ cc_library( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", + "@xla//xla/service:custom_call_status", + "@xla//xla/stream_executor/cuda:cuda_asm_compiler", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -564,6 +579,7 @@ cc_library( deps = [ ":cuda_gpu_kernel_helpers", ":cuda_vendor", + "@com_google_absl//absl/base:dynamic_annotations", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -571,7 +587,6 @@ cc_library( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", - "@com_google_absl//absl/base:dynamic_annotations", ], ) @@ -602,6 +617,8 @@ pybind_extension( ":versions_helpers", "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", + "@com_google_absl//absl/status:statusor", + "@nanobind", "@xla//xla/tsl/cuda:cublas", "@xla//xla/tsl/cuda:cudart", "@xla//xla/tsl/cuda:cudnn", @@ -609,8 +626,6 @@ pybind_extension( "@xla//xla/tsl/cuda:cupti", "@xla//xla/tsl/cuda:cusolver", "@xla//xla/tsl/cuda:cusparse", - "@com_google_absl//absl/status:statusor", - "@nanobind", ], ) diff --git a/jaxlib/cuda_plugin_extension.cc b/jaxlib/cuda_plugin_extension.cc index 0bb8cbbace65..ea81109b36c0 100644 --- a/jaxlib/cuda_plugin_extension.cc +++ b/jaxlib/cuda_plugin_extension.cc @@ -38,7 +38,7 @@ namespace xla { 
namespace { absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, const char* fn_name_c_str, - size_t fn_name_size, nb::capsule fn, + size_t fn_name_size, nb::object fn, int api_version, XLA_FFI_Handler_Traits traits) { if (c_api->extension_start == nullptr) { @@ -54,6 +54,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, if (next == nullptr) { return Unimplemented("The plugin does not have a custom call extension."); } + PJRT_Gpu_Register_Custom_Call* register_custom_call = + reinterpret_cast(next)->custom_call; if (traits != 0) { return Unimplemented("The plugin does not support custom call traits."); @@ -63,14 +65,73 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; args.function_name = fn_name_c_str; args.function_name_size = fn_name_size; + #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; #endif - args.custom_call_function = static_cast(fn.data()); - RETURN_STATUS_IF_PJRT_ERROR( - reinterpret_cast(next)->custom_call(&args), - c_api); + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration requires handlers as PyCapsules"); + } + return capsule; + }; + +#if PJRT_API_GPU_EXTENSION_VERSION <= 1 + TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); + args.custom_call_function = fn_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); return absl::OkStatus(); +#else + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = nullptr; + + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); + args.handler_execute = capsule_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + auto capsule_execute = as_capsule(fn); + if (capsule_execute.ok()) { + args.handler_execute = capsule_execute->data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); + return capsule.data(); + }; + + TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. 
" + "Supported versions are 0 and 1.", + api_version)); +#endif } nb::dict Registrations() { @@ -97,7 +158,7 @@ NB_MODULE(cuda_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::object fn_name_py, nb::capsule fn, + [](nb::capsule c_api, nb::object fn_name_py, nb::object fn, nb::str xla_platform_name, int api_version, XLA_FFI_Handler_Traits traits) { const char* fn_name_c_str; diff --git a/jaxlib/ffi_helpers.h b/jaxlib/ffi_helpers.h index 2374680c20eb..47505020f3b8 100644 --- a/jaxlib/ffi_helpers.h +++ b/jaxlib/ffi_helpers.h @@ -1,11 +1,16 @@ #ifndef JAXLIB_FFI_HELPERS_H_ #define JAXLIB_FFI_HELPERS_H_ +#include #include #include #include +#include #include +#include #include +#include +#include #include "absl/algorithm/container.h" #include "absl/base/optimization.h" @@ -17,12 +22,7 @@ namespace jax { -#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \ - if (ABSL_PREDICT_FALSE(!rhs.ok())) { \ - return ::jax::AsFfiError(rhs.status()); \ - } \ - lhs = rhs.value() - +// Returns from the function if the argument is an ffi::Error. #define FFI_RETURN_IF_ERROR(...) \ do { \ ::xla::ffi::Error err = (__VA_ARGS__); \ @@ -31,6 +31,8 @@ namespace jax { } \ } while (0) +// Returns from the function with an ffi::Error if the argument is an +// absl::Status. #define FFI_RETURN_IF_ERROR_STATUS(...) \ do { \ ::absl::Status status = (__VA_ARGS__); \ @@ -39,6 +41,37 @@ namespace jax { } \ } while (0) +// Returns from the function with an ffi::Error if the RHS is an absl::Status, +// otherwise assigns to the LHS. Most of the complication here stems from the +// fact that we want to support having the LHS wrapped in parentheses (when +// unpacking a tuple, for example). +#define FFI_ASSIGN_OR_RETURN(lhs, rhs) \ + FFI_ASSIGN_OR_RETURN_IMPL_( \ + FFI_ASSIGN_OR_RETURN_CONCAT_(_status_or_value, __LINE__), lhs, rhs) + +#define FFI_ASSIGN_OR_RETURN_IMPL_(statusor, lhs, rhs) \ + auto statusor = (rhs); \ + if (ABSL_PREDICT_FALSE(!statusor.ok())) { \ + return ::jax::AsFfiError(statusor.status()); \ + } \ + FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(lhs) = \ + (*std::move(statusor)) + +#define FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) x##y +#define FFI_ASSIGN_OR_RETURN_CONCAT_(x, y) \ + FFI_ASSIGN_OR_RETURN_CONCAT_INNER_(x, y) + +// All the macros below here are to handle the case in FFI_ASSIGN_OR_RETURN +// where the LHS is wrapped in parentheses. See a more detailed discussion at +// https://stackoverflow.com/a/62984543 +#define FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED(X) \ + FFI_ASSIGN_OR_RETURN_ESCAPE(FFI_ASSIGN_OR_RETURN_EMPTY X) +#define FFI_ASSIGN_OR_RETURN_EMPTY(...) FFI_ASSIGN_OR_RETURN_EMPTY __VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_ESCAPE(...) \ + FFI_ASSIGN_OR_RETURN_ESCAPE_(__VA_ARGS__) +#define FFI_ASSIGN_OR_RETURN_ESCAPE_(...) 
FFI_ASSIGN_OR_RETURN_##__VA_ARGS__ +#define FFI_ASSIGN_OR_RETURN_FFI_ASSIGN_OR_RETURN_EMPTY + template inline absl::StatusOr MaybeCastNoOverflow( std::int64_t value, const std::string& source = __FILE__) { @@ -55,30 +88,100 @@ inline absl::StatusOr MaybeCastNoOverflow( } } -inline xla::ffi::Error AsFfiError(const absl::Status& status) { +inline ::xla::ffi::Error AsFfiError(const absl::Status& status) { if (ABSL_PREDICT_FALSE(!status.ok())) { - return xla::ffi::Error(static_cast(status.code()), - std::string(status.message())); + return ::xla::ffi::Error(static_cast(status.code()), + std::string(status.message())); } else { - return xla::ffi::Error::Success(); + return ::xla::ffi::Error::Success(); } } -template -xla::ffi::Error CheckMatrixDimensions(xla::ffi::Span dims) { +inline int64_t GetBatchSize(::xla::ffi::Span dims) { + return absl::c_accumulate(dims, 1, std::multiplies()); +} + +inline absl::StatusOr> SplitBatch1D( + ::xla::ffi::Span dims, + const std::string& source = __FILE__) { + if (dims.size() < 1) { + return absl::InvalidArgumentError( + absl::StrFormat("%s: Argument must have at least 1 dimension", source)); + } + return std::make_pair(GetBatchSize(dims.first(dims.size() - 1)), dims.back()); +} + +inline absl::StatusOr> SplitBatch2D( + ::xla::ffi::Span dims, + const std::string& source = __FILE__) { if (dims.size() < 2) { - return xla::ffi::Error(xla::ffi::ErrorCode::kInvalidArgument, - "Matrix must have at least 2 dimensions"); + return absl::InvalidArgumentError(absl::StrFormat( + "%s: Argument must have at least 2 dimensions", source)); } - return xla::ffi::Error::Success(); + auto trailingDims = dims.last(2); + return std::make_tuple(GetBatchSize(dims.first(dims.size() - 2)), + trailingDims.front(), trailingDims.back()); } -template -std::tuple SplitBatch2D(xla::ffi::Span dims) { - auto matrix_dims = dims.last(2); - return std::make_tuple(absl::c_accumulate(dims.first(dims.size() - 2), 1, - std::multiplies()), - matrix_dims.front(), matrix_dims.back()); +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + int64_t expected_batch, + std::string_view name, + std::string_view op) { + auto batch = GetBatchSize(dimensions); + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + return ::xla::ffi::Error::Success(); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + std::tuple shape, + std::string_view name, + std::string_view op) { + FFI_ASSIGN_OR_RETURN((auto [batch, size]), SplitBatch1D(dimensions)); + auto [expected_batch, expected_size] = shape; + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + if (batch != expected_batch || size != expected_size) { + return ::xla::ffi::Error::InvalidArgument( + absl::StrFormat("Invalid trailing dimension for input %s " + "to %s. 
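The FFI_ASSIGN_OR_RETURN_UNPARENTHESIZE_IF_PARENTHESIZED machinery above exists so the macro can accept either a plain declaration or a parenthesized structured binding on the left-hand side. A self-contained demo of the same token-pasting trick, with illustrative DEMO_* names rather than the real macros:

#include <utility>

#define DEMO_EMPTY(...) DEMO_EMPTY __VA_ARGS__
#define DEMO_ESCAPE_(...) DEMO_##__VA_ARGS__
#define DEMO_ESCAPE(...) DEMO_ESCAPE_(__VA_ARGS__)
#define DEMO_DEMO_EMPTY  // Swallows the marker token left by DEMO_EMPTY.
#define DEMO_UNWRAP(x) DEMO_ESCAPE(DEMO_EMPTY x)

int main() {
  DEMO_UNWRAP(int y) = 42;                             // int y = 42;
  DEMO_UNWRAP((auto [a, b])) = std::make_pair(1, 2);   // auto [a, b] = ...;
  return (y + a + b) == 45 ? 0 : 1;
}

Wrapping the left-hand side in parentheses is what lets a caller write FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(...)) even though the structured binding itself contains commas.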
Expected %d, got %d.", + name, op, expected_size, size)); + } + return ::xla::ffi::Error::Success(); +} + +inline ::xla::ffi::Error CheckShape(::xla::ffi::Span dimensions, + std::tuple shape, + std::string_view name, + std::string_view op) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), SplitBatch2D(dimensions)); + auto [expected_batch, expected_rows, expected_cols] = shape; + if (batch != expected_batch) { + return ::xla::ffi::Error::InvalidArgument(absl::StrFormat( + "Invalid total batch size for input %s to %s. Expected %d, got %d.", + name, op, expected_batch, batch)); + } + if (rows != expected_rows || cols != expected_cols) { + return ::xla::ffi::Error::InvalidArgument( + absl::StrFormat("Invalid matrix dimensions for input %s to %s. " + "Expected (%d, %d), got (%d, %d).", + name, op, expected_rows, expected_cols, rows, cols)); + } + return ::xla::ffi::Error::Success(); +} + +template <::xla::ffi::DataType dtype> +auto AllocateScratchMemory(std::size_t size) + -> std::unique_ptr>[]> { + // TODO(paruzelp): use std::make_unique_for_overwrite when C++20 is available. + using ValueType = std::remove_extent_t<::xla::ffi::NativeType>; + return std::unique_ptr(new ValueType[size]); } } // namespace jax diff --git a/jaxlib/gpu/BUILD b/jaxlib/gpu/BUILD index daa03aa5be24..048ea23a9cff 100644 --- a/jaxlib/gpu/BUILD +++ b/jaxlib/gpu/BUILD @@ -14,13 +14,18 @@ # Shared CUDA/ROCM GPU kernels. -load("//jaxlib:jax.bzl", "cc_proto_library") +load( + "//jaxlib:jax.bzl", + "cc_proto_library", + "jax_visibility", + "xla_py_proto_library", +) licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) exports_files(srcs = [ @@ -29,8 +34,6 @@ exports_files(srcs = [ "blas_handle_pool.h", "blas_kernels.cc", "blas_kernels.h", - "blas_kernels_ffi.cc", - "blas_kernels_ffi.h", "gpu_kernel_helpers.cc", "gpu_kernel_helpers.h", "gpu_kernels.cc", @@ -38,6 +41,8 @@ exports_files(srcs = [ "linalg_kernels.cc", "linalg_kernels.cu.cc", "linalg_kernels.h", + "make_batch_pointers.cu.cc", + "make_batch_pointers.h", "prng.cc", "prng_kernels.cc", "prng_kernels.cu.cc", @@ -48,6 +53,8 @@ exports_files(srcs = [ "solver.cc", "solver_handle_pool.cc", "solver_handle_pool.h", + "solver_interface.cc", + "solver_interface.h", "solver_kernels.cc", "solver_kernels.h", "solver_kernels_ffi.cc", @@ -72,3 +79,10 @@ cc_proto_library( name = "triton_cc_proto", deps = [":triton_proto"], ) + +xla_py_proto_library( + name = "triton_py_pb2", + api_version = 2, + visibility = jax_visibility("triton_proto_py_users"), + deps = [":triton_proto"], +) diff --git a/jaxlib/gpu/blas.cc b/jaxlib/gpu/blas.cc index 62a1bbc94790..e8761bd32ac9 100644 --- a/jaxlib/gpu/blas.cc +++ b/jaxlib/gpu/blas.cc @@ -22,7 +22,6 @@ limitations under the License. 
#include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/blas_kernels_ffi.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_nanobind_helpers.h" #include "xla/tsl/python/lib/core/numpy.h" @@ -70,9 +69,6 @@ nb::dict Registrations() { nb::dict dict; dict[JAX_GPU_PREFIX "blas_getrf_batched"] = EncapsulateFunction(GetrfBatched); dict[JAX_GPU_PREFIX "blas_geqrf_batched"] = EncapsulateFunction(GeqrfBatched); - - dict[JAX_GPU_PREFIX "blas_getrf_batched_ffi"] = - EncapsulateFfiHandler(GetrfBatchedFfi); return dict; } diff --git a/jaxlib/gpu/blas_kernels.cc b/jaxlib/gpu/blas_kernels.cc index a963aa3fd762..ac30aa9cc520 100644 --- a/jaxlib/gpu/blas_kernels.cc +++ b/jaxlib/gpu/blas_kernels.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/strings/str_format.h" #include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -69,13 +70,9 @@ static absl::Status GetrfBatched_(gpuStream_t stream, void** buffers, int* ipiv = static_cast(buffers[2]); int* info = static_cast(buffers[3]); - auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[4], d.batch, - SizeOfBlasType(d.type) * d.n * d.n); - JAX_RETURN_IF_ERROR(a_ptrs_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, buffers[1], buffers[4], d.batch, + SizeOfBlasType(d.type) * d.n * d.n); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); switch (d.type) { case BlasType::F32: { float** batch_ptrs = static_cast(buffers[4]); @@ -132,17 +129,12 @@ static absl::Status GeqrfBatched_(gpuStream_t stream, void** buffers, } std::vector info(d.batch); - auto a_ptrs_host = MakeBatchPointers(stream, buffers[1], buffers[3], d.batch, - SizeOfBlasType(d.type) * d.m * d.n); - JAX_RETURN_IF_ERROR(a_ptrs_host.status()); - auto tau_ptrs_host = - MakeBatchPointers(stream, buffers[2], buffers[4], d.batch, - SizeOfBlasType(d.type) * std::min(d.m, d.n)); - JAX_RETURN_IF_ERROR(tau_ptrs_host.status()); - // TODO(phawkins): ideally we would not need to synchronize here, but to - // avoid it we need a way to keep the host-side buffer alive until the copy - // completes. - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); + MakeBatchPointersAsync(stream, buffers[1], buffers[3], d.batch, + SizeOfBlasType(d.type) * d.m * d.n); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); + MakeBatchPointersAsync(stream, buffers[2], buffers[4], d.batch, + SizeOfBlasType(d.type) * std::min(d.m, d.n)); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); switch (d.type) { case BlasType::F32: { float** a_batch_ptrs = static_cast(buffers[3]); diff --git a/jaxlib/gpu/blas_kernels_ffi.cc b/jaxlib/gpu/blas_kernels_ffi.cc deleted file mode 100644 index 610ce105260e..000000000000 --- a/jaxlib/gpu/blas_kernels_ffi.cc +++ /dev/null @@ -1,133 +0,0 @@ -/* Copyright 2024 The JAX Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "jaxlib/gpu/blas_kernels_ffi.h" - -#include "absl/status/status.h" -#include "jaxlib/ffi_helpers.h" -#include "jaxlib/gpu/blas_handle_pool.h" -#include "jaxlib/gpu/gpu_kernel_helpers.h" -#include "jaxlib/gpu/vendor.h" -#include "xla/ffi/api/ffi.h" - -namespace jax { -namespace JAX_GPU_NAMESPACE { - -namespace ffi = ::xla::ffi; - -namespace { -#define GETRF_BATCHED_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfBatchedKernel { \ - static absl::Status Run(gpublasHandle_t handle, int n, type** a, int lda, \ - int* ipiv, int* info, int batch) { \ - return JAX_AS_STATUS(name(handle, n, a, lda, ipiv, info, batch)); \ - } \ - } - -template -struct GetrfBatchedKernel; -GETRF_BATCHED_KERNEL_IMPL(float, gpublasSgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(double, gpublasDgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasComplex, gpublasCgetrfBatched); -GETRF_BATCHED_KERNEL_IMPL(gpublasDoubleComplex, gpublasZgetrfBatched); -#undef GETRF_BATCHED_KERNEL_IMPL - -template -ffi::Error GetrfBatchedImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, - ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); - if (rows != cols) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "getrf_batched only supports square matrices"); - } - FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); - FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); - - auto maybe_workspace = scratch.Allocate(sizeof(void*) * batch); - if (!maybe_workspace.has_value()) { - return ffi::Error(ffi::ErrorCode::kUnknown, - "Unable to allocate workspace for batched getrf"); - } - auto workspace = maybe_workspace.value(); - - auto a_data = a.untyped_data(); - auto out_data = out->untyped_data(); - auto ipiv_data = ipiv->typed_data(); - auto info_data = info->typed_data(); - if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * cols * cols, - gpuMemcpyDeviceToDevice, stream))); - } - - FFI_ASSIGN_OR_RETURN( - auto a_ptrs_host, - MakeBatchPointers(stream, out_data, workspace, batch, sizeof(T) * n * n)); - // TODO(phawkins, danfm): ideally we would not need to synchronize here, but - // to avoid it we need a way to keep the host-side buffer alive until the copy - // completes. 
- FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuStreamSynchronize(stream))); - - auto batch_ptrs = static_cast(workspace); - FFI_RETURN_IF_ERROR_STATUS(GetrfBatchedKernel::Run( - handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); - - return ffi::Error::Success(); -} - -ffi::Error GetrfBatchedDispatch( - gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a, - ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { - auto dataType = a.element_type(); - if (dataType != out->element_type()) { - return ffi::Error( - ffi::ErrorCode::kInvalidArgument, - "Input and output to getrf_batched must have the same element type"); - } - if (dataType == ffi::DataType::F32) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::F64) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::C64) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, - info); - } else if (dataType == ffi::DataType::C128) { - return GetrfBatchedImpl(stream, scratch, a, out, ipiv, - info); - } - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "Unsupported element type for getrf"); -} -} // namespace - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - GetrfBatchedFfi, GetrfBatchedDispatch, - ffi::Ffi::Bind() - .Ctx>() - .Ctx() - .Arg() // a - .Ret() // out - .Ret>() // ipiv - .Ret>() // info -); - -} // namespace JAX_GPU_NAMESPACE -} // namespace jax diff --git a/jaxlib/gpu/gpu_kernel_helpers.cc b/jaxlib/gpu/gpu_kernel_helpers.cc index f43122f2efaa..5a434f4b6ad5 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.cc +++ b/jaxlib/gpu/gpu_kernel_helpers.cc @@ -313,20 +313,5 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line, } #endif -absl::StatusOr> MakeBatchPointers( - gpuStream_t stream, void* buffer, void* dev_ptrs, int batch, - int batch_elem_size) { - char* ptr = static_cast(buffer); - auto host_ptrs = absl::make_unique(batch); - for (int i = 0; i < batch; ++i) { - host_ptrs[i] = ptr; - ptr += batch_elem_size; - } - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( - gpuMemcpyAsync(dev_ptrs, host_ptrs.get(), sizeof(void*) * batch, - gpuMemcpyHostToDevice, stream))); - return std::move(host_ptrs); -} - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/gpu_kernel_helpers.h b/jaxlib/gpu/gpu_kernel_helpers.h index 46fca7bc4bd4..aecb8a4fdcf1 100644 --- a/jaxlib/gpu/gpu_kernel_helpers.h +++ b/jaxlib/gpu/gpu_kernel_helpers.h @@ -67,16 +67,6 @@ absl::Status AsStatus(cufftResult error, const char* file, std::int64_t line, const char* expr); #endif -// Builds an array of pointers to each array in a batch, in device memory. -// Caution: the return value must be kept alive (e.g., via a stream -// synchronization) until the copy enqueued by MakeBatchPointers on `stream` -// completes. -absl::StatusOr> MakeBatchPointers(gpuStream_t stream, - void* buffer, - void* dev_ptrs, - int batch, - int batch_elem_size); - } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/gpu_kernels.cc b/jaxlib/gpu/gpu_kernels.cc index ccca8e157b98..62977c5f57a1 100644 --- a/jaxlib/gpu/gpu_kernels.cc +++ b/jaxlib/gpu/gpu_kernels.cc @@ -17,7 +17,6 @@ limitations under the License. // JAX-generated HLO code from outside of JAX. 
#include "jaxlib/gpu/blas_kernels.h" -#include "jaxlib/gpu/blas_kernels_ffi.h" #include "jaxlib/gpu/linalg_kernels.h" #include "jaxlib/gpu/prng_kernels.h" #include "jaxlib/gpu/rnn_kernels.h" @@ -36,8 +35,6 @@ namespace { XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_getrf_batched", GetrfBatched, "CUDA"); -XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cublas_getrf_batched_ffi", "CUDA", - GetrfBatchedFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cublas_geqrf_batched", GeqrfBatched, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cudnn_rnn", RNNForward, "CUDA"); @@ -49,15 +46,25 @@ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cu_threefry2x32", ThreeFry2x32, XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_getrf", Getrf, "CUDA"); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_getrf_ffi", "CUDA", GetrfFfi); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syrk_ffi", "CUDA", + SyrkFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_geqrf", Geqrf, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_geqrf_ffi", "CUDA", + GeqrfFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_csrlsvqr", Csrlsvqr, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_orgqr", Orgqr, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_orgqr_ffi", "CUDA", + OrgqrFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevd", Syevd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_syevj", Syevj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cusolver_syevd_ffi", "CUDA", + SyevdFfi); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_sytrd", Sytrd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvd", Gesvd, "CUDA"); XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("cusolver_gesvdj", Gesvdj, "CUDA"); +XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_cholesky_update_ffi", "CUDA", + CholeskyUpdateFfi); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_lu_pivots_to_permutation", "CUDA", LuPivotsToPermutation); XLA_FFI_REGISTER_HANDLER(XLA_FFI_GetApi(), "cu_threefry2x32_ffi", "CUDA", diff --git a/jaxlib/gpu/linalg.cc b/jaxlib/gpu/linalg.cc index 189d3a01e382..0ab2b87a290d 100644 --- a/jaxlib/gpu/linalg.cc +++ b/jaxlib/gpu/linalg.cc @@ -41,6 +41,8 @@ NB_MODULE(_linalg, m) { EncapsulateFfiHandler(LuPivotsToPermutation); dict[JAX_GPU_PREFIX "_cholesky_update"] = EncapsulateFunction(CholeskyUpdate); + dict[JAX_GPU_PREFIX "_cholesky_update_ffi"] = + EncapsulateFunction(CholeskyUpdateFfi); return dict; }); m.def("build_cholesky_update_descriptor", &BuildCholeskyUpdateDescriptor); diff --git a/jaxlib/gpu/linalg_kernels.cc b/jaxlib/gpu/linalg_kernels.cc index 6636f5654180..039a9b5c1019 100644 --- a/jaxlib/gpu/linalg_kernels.cc +++ b/jaxlib/gpu/linalg_kernels.cc @@ -16,13 +16,9 @@ limitations under the License. 
#include "jaxlib/gpu/linalg_kernels.h" #include -#include -#include -#include #include #include -#include "absl/algorithm/container.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" @@ -44,7 +40,8 @@ absl::Status CholeskyUpdateImpl(gpuStream_t stream, void** buffers, auto s = UnpackDescriptor(opaque, opaque_len); JAX_RETURN_IF_ERROR(s.status()); const CholeskyUpdateDescriptor& d = **s; - LaunchCholeskyUpdateKernel(stream, buffers, d); + JAX_RETURN_IF_ERROR( + JAX_AS_STATUS(LaunchCholeskyUpdateKernel(stream, buffers, d))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuGetLastError())); return absl::OkStatus(); } @@ -59,23 +56,84 @@ void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, } } +namespace { +ffi::Error CholeskyUpdateFfiImpl(gpuStream_t stream, ffi::AnyBuffer matrix_in, + ffi::AnyBuffer vector_in, + ffi::Result matrix_out, + ffi::Result vector_out) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(matrix_in.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The matrix input to Cholesky update must be square."); + } + FFI_RETURN_IF_ERROR(CheckShape(vector_in.dimensions(), {batch, cols}, + "vector", "cholesky_update")); + FFI_RETURN_IF_ERROR(CheckShape(matrix_out->dimensions(), {batch, rows, cols}, + "matrix_out", "cholesky_update")); + FFI_RETURN_IF_ERROR(CheckShape(vector_out->dimensions(), {batch, cols}, + "vector_out", "cholesky_update")); + FFI_ASSIGN_OR_RETURN(auto size, MaybeCastNoOverflow(cols)); + auto dtype = matrix_in.element_type(); + if (dtype != ffi::F32 && dtype != ffi::F64) { + return ffi::Error::InvalidArgument( + "Invalid input type for Cholesky update; must be float32 or float64."); + } + if (vector_in.element_type() != dtype || + matrix_out->element_type() != dtype || + vector_out->element_type() != dtype) { + return ffi::Error::InvalidArgument( + "All input and output types for Cholesky update must match."); + } + bool is_single_precision = dtype == ffi::F32; + auto matrix = matrix_out->untyped_data(); + if (matrix_in.untyped_data() != matrix) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(matrix, matrix_in.untyped_data(), matrix_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream))); + } + auto vector = vector_out->untyped_data(); + if (vector_in.untyped_data() != vector) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( + gpuMemcpyAsync(vector, vector_in.untyped_data(), vector_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream))); + } + for (auto n = 0; n < batch; ++n) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(LaunchCholeskyUpdateFfiKernel( + stream, matrix, vector, size, is_single_precision))); + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuGetLastError())); + } + return ffi::Error::Success(); +} +} // namespace + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CholeskyUpdateFfi, CholeskyUpdateFfiImpl, + ffi::Ffi::Bind() + .Ctx>() + .Arg() + .Arg() + .Ret() + .Ret()); + namespace { ffi::Error LuPivotsToPermutationImpl( - gpuStream_t stream, std::int32_t permutation_size, + gpuStream_t stream, ffi::Dictionary /* unused */, ffi::Buffer pivots, ffi::Result> permutation) { - auto dims = pivots.dimensions(); - - if (dims.size() < 1) { + FFI_ASSIGN_OR_RETURN((auto [batch_size, pivot_size]), + SplitBatch1D(pivots.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [permutation_batch, permutation_size]), + SplitBatch1D(permutation->dimensions())); + if (permutation_batch != batch_size) { return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "pivots must have at 
least one dimension"); + "pivots and permutation must have the same batch size."); } - FFI_ASSIGN_OR_RETURN(std::int32_t pivot_size, - MaybeCastNoOverflow(dims.back())); - std::int64_t batch_size = 1; - if (dims.size() >= 2) { - batch_size = - absl::c_accumulate(dims.first(dims.size() - 1), 1, std::multiplies<>()); + if (permutation_size < pivot_size) { + return ffi::Error( + ffi::ErrorCode::kInvalidArgument, + absl::StrFormat("Output permutation size %d must match or exceed the " + "trailing dimension of the input pivots %d.", + permutation_size, pivot_size)); } LaunchLuPivotsToPermutationKernel(stream, batch_size, pivot_size, permutation_size, pivots.typed_data(), @@ -88,7 +146,10 @@ ffi::Error LuPivotsToPermutationImpl( XLA_FFI_DEFINE_HANDLER_SYMBOL(LuPivotsToPermutation, LuPivotsToPermutationImpl, ffi::Ffi::Bind() .Ctx>() - .Attr("permutation_size") + // TODO(b/358275922): remove Attrs (and the + // unused Dictionary above) 12 weeks after + // release of jaxlib v0.4.32. + .Attrs() .Arg>() .Ret>()); diff --git a/jaxlib/gpu/linalg_kernels.cu.cc b/jaxlib/gpu/linalg_kernels.cu.cc index 8aa769bb5735..7f87d66fb4ef 100644 --- a/jaxlib/gpu/linalg_kernels.cu.cc +++ b/jaxlib/gpu/linalg_kernels.cu.cc @@ -15,18 +15,11 @@ limitations under the License. #include "jaxlib/gpu/linalg_kernels.h" -#include +#include #include -#include #include "jaxlib/gpu/vendor.h" -#ifdef JAX_GPU_HIP -#include "rocm/include/hip/amd_detail/amd_hip_cooperative_groups.h" -#else // JAX_GPU_CUDA -#include "third_party/gpus/cuda/include/cooperative_groups.h" -#endif - namespace cg = cooperative_groups; namespace jax { @@ -47,7 +40,6 @@ __device__ void drotg(T* da, T* db, T* c, T* s) { T rh = rhypot(a, b); *c = a * rh; *s = -(b * rh); - return; } template @@ -75,8 +67,9 @@ __global__ void CholeskyUpdateKernel(T* rMatrix, T* uVector, int nSize) { } // namespace template -void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, - int grid_dim, int block_dim, int nSize) { +gpuError_t LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, + int grid_dim, int block_dim, + int nSize) { T* rMatrix = reinterpret_cast(buffers[2]); T* uVector = reinterpret_cast(buffers[3]); @@ -85,43 +78,72 @@ void LaunchCholeskyUpdateKernelBody(gpuStream_t stream, void** buffers, reinterpret_cast(&uVector), reinterpret_cast(&nSize), }; -#ifdef JAX_GPU_HIP - hipLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); -#else // JAX_GPU_CUDA - cudaLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, - block_dim, arg_ptrs, - /*dynamic_shared_mem_bytes=*/0, stream); -#endif + return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); } -void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, - CholeskyUpdateDescriptor descriptor) { +gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor) { int nSize = descriptor.matrix_size; LinalgType type = descriptor.linalg_type; int dev = 0; -#ifdef JAX_GPU_HIP - hipDeviceProp_t deviceProp; - hipGetDeviceProperties(&deviceProp, dev); -#else // JAX_GPU_CUDA - cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, dev); -#endif + gpuDeviceProp deviceProp; + gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev); + if (err != gpuSuccess) { + return err; + } int block_dim = deviceProp.maxThreadsPerBlock; int grid_dim = deviceProp.multiProcessorCount; switch 
(type) { case LinalgType::F64: - LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, - block_dim, nSize); - break; + return LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, + block_dim, nSize); case LinalgType::F32: - LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, - block_dim, nSize); - break; + return LaunchCholeskyUpdateKernelBody(stream, buffers, grid_dim, + block_dim, nSize); + } +} + +template +gpuError_t LaunchCholeskyUpdateFfiKernelBody(gpuStream_t stream, void* matrix, + void* vector, int grid_dim, + int block_dim, int nSize) { + T* rMatrix = reinterpret_cast(matrix); + T* uVector = reinterpret_cast(vector); + + void* arg_ptrs[3] = { + reinterpret_cast(&rMatrix), + reinterpret_cast(&uVector), + reinterpret_cast(&nSize), + }; + return gpuLaunchCooperativeKernel((void*)CholeskyUpdateKernel, grid_dim, + block_dim, arg_ptrs, + /*dynamic_shared_mem_bytes=*/0, stream); +} + +gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision) { + int dev = 0; + gpuDeviceProp deviceProp; + gpuError_t err = gpuGetDeviceProperties(&deviceProp, dev); + if (err != gpuSuccess) { + return err; + } + + int block_dim = deviceProp.maxThreadsPerBlock; + int grid_dim = deviceProp.multiProcessorCount; + + if (is_single_precision) { + return LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, + grid_dim, block_dim, size); + } else { + return LaunchCholeskyUpdateFfiKernelBody(stream, matrix, vector, + grid_dim, block_dim, size); } } diff --git a/jaxlib/gpu/linalg_kernels.h b/jaxlib/gpu/linalg_kernels.h index 73a0ac173d41..2c41b7f4350d 100644 --- a/jaxlib/gpu/linalg_kernels.h +++ b/jaxlib/gpu/linalg_kernels.h @@ -26,8 +26,6 @@ limitations under the License. namespace jax { namespace JAX_GPU_NAMESPACE { -namespace ffi = xla::ffi; - enum LinalgType { F32 = 0, F64 = 1, @@ -38,19 +36,23 @@ struct CholeskyUpdateDescriptor { std::int64_t matrix_size; // leading dim (N) for a square (NxN)matrix }; -void LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, - CholeskyUpdateDescriptor descriptor); +gpuError_t LaunchCholeskyUpdateKernel(gpuStream_t stream, void** buffers, + CholeskyUpdateDescriptor descriptor); void CholeskyUpdate(gpuStream_t stream, void** buffers, const char* opaque, size_t opaque_len, XlaCustomCallStatus* status); +gpuError_t LaunchCholeskyUpdateFfiKernel(gpuStream_t stream, void* matrix, + void* vector, int size, + bool is_single_precision); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CholeskyUpdateFfi); + void LaunchLuPivotsToPermutationKernel(gpuStream_t stream, std::int64_t batch_size, std::int32_t pivot_size, std::int32_t permutation_size, const std::int32_t* pivots, std::int32_t* permutation); - XLA_FFI_DECLARE_HANDLER_SYMBOL(LuPivotsToPermutation); } // namespace JAX_GPU_NAMESPACE diff --git a/jaxlib/gpu/make_batch_pointers.cu.cc b/jaxlib/gpu/make_batch_pointers.cu.cc new file mode 100644 index 000000000000..b10655645924 --- /dev/null +++ b/jaxlib/gpu/make_batch_pointers.cu.cc @@ -0,0 +1,46 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "jaxlib/gpu/make_batch_pointers.h"
+
+#include <algorithm>
+
+#include "jaxlib/gpu/vendor.h"
+
+namespace jax {
+namespace JAX_GPU_NAMESPACE {
+
+namespace {
+__global__ void MakeBatchPointersAsyncKernel(char* buffer_in, void** buffer_out,
+                                             int batch, int batch_elem_size) {
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < batch;
+       idx += blockDim.x * gridDim.x) {
+    buffer_out[idx] = buffer_in + idx * batch_elem_size;
+  }
+}
+}  // namespace
+
+void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
+                            void* buffer_out, int batch, int batch_elem_size) {
+  const int block_dim = 128;
+  const std::size_t grid_dim =
+      std::min(1024, (batch + block_dim - 1) / block_dim);
+  MakeBatchPointersAsyncKernel<<<grid_dim, block_dim, 0, stream>>>(
+      static_cast<char*>(buffer_in), static_cast<void**>(buffer_out), batch,
+      batch_elem_size);
+}
+
+}  // namespace JAX_GPU_NAMESPACE
+}  // namespace jax
diff --git a/jaxlib/gpu/blas_kernels_ffi.h b/jaxlib/gpu/make_batch_pointers.h
similarity index 74%
rename from jaxlib/gpu/blas_kernels_ffi.h
rename to jaxlib/gpu/make_batch_pointers.h
index ad3bf90120e9..f2fd064961e8 100644
--- a/jaxlib/gpu/blas_kernels_ffi.h
+++ b/jaxlib/gpu/make_batch_pointers.h
@@ -13,18 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef JAXLIB_GPU_BLAS_KERNELS_FFI_H_
-#define JAXLIB_GPU_BLAS_KERNELS_FFI_H_
+#ifndef JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
+#define JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
 
 #include "jaxlib/gpu/vendor.h"
-#include "xla/ffi/api/ffi.h"
 
 namespace jax {
 namespace JAX_GPU_NAMESPACE {
 
-XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfBatchedFfi);
+void MakeBatchPointersAsync(gpuStream_t stream, void* buffer_in,
+                            void* buffer_out, int batch, int batch_elem_size);
 
 }  // namespace JAX_GPU_NAMESPACE
 }  // namespace jax
 
-#endif  // JAXLIB_GPU_BLAS_KERNELS_FFI_H_
+#endif  // JAXLIB_GPU_MAKE_BATCH_POINTERS_H_
diff --git a/jaxlib/gpu/solver.cc b/jaxlib/gpu/solver.cc
index 223c8a9798be..38936ee497cf 100644
--- a/jaxlib/gpu/solver.cc
+++ b/jaxlib/gpu/solver.cc
@@ -20,8 +20,8 @@ limitations under the License.
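MakeBatchPointersAsync fills, on the given stream, a device array whose i-th entry points at the i-th fixed-size slice of buffer_in; this is the pointer-array layout that batched cuBLAS/cuSOLVER entry points consume. A minimal call-site sketch, modeled on the batched getrf path later in this patch; the wrapper name PrepareBatchPointers is hypothetical.

#include "absl/status/status.h"
#include "jaxlib/gpu/gpu_kernel_helpers.h"
#include "jaxlib/gpu/make_batch_pointers.h"
#include "jaxlib/gpu/vendor.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

// Builds batch_ptrs[i] = data + i * n * n for `batch` square n x n matrices
// of type T, then surfaces any launch failure as an absl::Status.
template <typename T>
absl::Status PrepareBatchPointers(gpuStream_t stream, T* data, T** batch_ptrs,
                                  int batch, int n) {
  MakeBatchPointersAsync(stream, data, batch_ptrs, batch, sizeof(T) * n * n);
  return JAX_AS_STATUS(gpuGetLastError());
}

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax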
#include "nanobind/nanobind.h" #include "nanobind/stl/pair.h" #include "absl/container/flat_hash_map.h" -#include "absl/strings/str_format.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/solver_kernels.h" @@ -473,9 +473,19 @@ nb::dict Registrations() { #ifdef JAX_GPU_CUDA dict["cusolver_csrlsvqr"] = EncapsulateFunction(Csrlsvqr); dict["cusolver_gesvdj"] = EncapsulateFunction(Gesvdj); + #endif // JAX_GPU_CUDA dict[JAX_GPU_PREFIX "solver_getrf_ffi"] = EncapsulateFfiHandler(GetrfFfi); + dict[JAX_GPU_PREFIX "solver_geqrf_ffi"] = EncapsulateFfiHandler(GeqrfFfi); + dict[JAX_GPU_PREFIX "solver_orgqr_ffi"] = EncapsulateFfiHandler(OrgqrFfi); + dict[JAX_GPU_PREFIX "solver_syevd_ffi"] = EncapsulateFfiHandler(SyevdFfi); + dict[JAX_GPU_PREFIX "solver_syrk_ffi"] = EncapsulateFfiHandler(SyrkFfi); + dict[JAX_GPU_PREFIX "solver_gesvd_ffi"] = EncapsulateFfiHandler(GesvdFfi); + +#ifdef JAX_GPU_CUDA + dict[JAX_GPU_PREFIX "solver_gesvdj_ffi"] = EncapsulateFfiHandler(GesvdjFfi); +#endif // JAX_GPU_CUDA return dict; } diff --git a/jaxlib/gpu/solver_interface.cc b/jaxlib/gpu/solver_interface.cc new file mode 100644 index 000000000000..4d1af3c50d76 --- /dev/null +++ b/jaxlib/gpu/solver_interface.cc @@ -0,0 +1,322 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "jaxlib/gpu/solver_interface.h" + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace solver { + +// LU decomposition: getrf + +#define JAX_GPU_DEFINE_GETRF(Type, Name) \ + template <> \ + absl::StatusOr GetrfBufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Getrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \ + Type *workspace, int lwork, int *ipiv, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \ + } + +JAX_GPU_DEFINE_GETRF(float, gpusolverDnSgetrf); +JAX_GPU_DEFINE_GETRF(double, gpusolverDnDgetrf); +JAX_GPU_DEFINE_GETRF(gpuComplex, gpusolverDnCgetrf); +JAX_GPU_DEFINE_GETRF(gpuDoubleComplex, gpusolverDnZgetrf); +#undef JAX_GPU_DEFINE_GETRF + +#define JAX_GPU_DEFINE_GETRF_BATCHED(Type, Name) \ + template <> \ + absl::Status GetrfBatched(gpublasHandle_t handle, int n, Type **a, \ + int lda, int *ipiv, int *info, int batch) { \ + return JAX_AS_STATUS(Name(handle, n, a, lda, ipiv, info, batch)); \ + } + +JAX_GPU_DEFINE_GETRF_BATCHED(float, gpublasSgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(double, gpublasDgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpublasComplex, gpublasCgetrfBatched); +JAX_GPU_DEFINE_GETRF_BATCHED(gpublasDoubleComplex, gpublasZgetrfBatched); +#undef JAX_GPU_DEFINE_GETRF_BATCHED + +// QR decomposition: geqrf + +#define JAX_GPU_DEFINE_GEQRF(Type, Name) \ + template <> \ + absl::StatusOr GeqrfBufferSize(gpusolverDnHandle_t handle, int m, \ + int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, m, n, /*A=*/nullptr, m, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Geqrf(gpusolverDnHandle_t handle, int m, int n, Type *a, \ + Type *tau, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, a, m, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_GEQRF(float, gpusolverDnSgeqrf); +JAX_GPU_DEFINE_GEQRF(double, gpusolverDnDgeqrf); +JAX_GPU_DEFINE_GEQRF(gpuComplex, gpusolverDnCgeqrf); +JAX_GPU_DEFINE_GEQRF(gpuDoubleComplex, gpusolverDnZgeqrf); +#undef JAX_GPU_DEFINE_GEQRF + +#define JAX_GPU_DEFINE_GEQRF_BATCHED(Type, Name) \ + template <> \ + absl::Status GeqrfBatched(gpublasHandle_t handle, int m, int n, \ + Type **a, Type **tau, int *info, \ + int batch) { \ + return JAX_AS_STATUS(Name(handle, m, n, a, m, tau, info, batch)); \ + } + +JAX_GPU_DEFINE_GEQRF_BATCHED(float, gpublasSgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(double, gpublasDgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasComplex, gpublasCgeqrfBatched); +JAX_GPU_DEFINE_GEQRF_BATCHED(gpublasDoubleComplex, gpublasZgeqrfBatched); +#undef JAX_GPU_DEFINE_GEQRF_BATCHED + +// Householder transformations: orgqr + +#define JAX_GPU_DEFINE_ORGQR(Type, Name) \ + template <> \ + absl::StatusOr OrgqrBufferSize(gpusolverDnHandle_t handle, int m, \ + int n, int k) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, m, n, k, /*A=*/nullptr, /*lda=*/m, /*tau=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Orgqr(gpusolverDnHandle_t handle, int m, int n, int k, \ + Type *a, 
Type *tau, Type *workspace, int lwork, \ + int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, m, n, k, a, m, tau, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_ORGQR(float, gpusolverDnSorgqr); +JAX_GPU_DEFINE_ORGQR(double, gpusolverDnDorgqr); +JAX_GPU_DEFINE_ORGQR(gpuComplex, gpusolverDnCungqr); +JAX_GPU_DEFINE_ORGQR(gpuDoubleComplex, gpusolverDnZungqr); +#undef JAX_GPU_DEFINE_ORGQR + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd + +#define JAX_GPU_DEFINE_SYEVJ(Type, Name) \ + template <> \ + absl::StatusOr SyevjBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Syevj( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info, gpuSyevjInfo_t params) { \ + return JAX_AS_STATUS( \ + Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info, params)); \ + } + +JAX_GPU_DEFINE_SYEVJ(float, gpusolverDnSsyevj); +JAX_GPU_DEFINE_SYEVJ(double, gpusolverDnDsyevj); +JAX_GPU_DEFINE_SYEVJ(gpuComplex, gpusolverDnCheevj); +JAX_GPU_DEFINE_SYEVJ(gpuDoubleComplex, gpusolverDnZheevj); +#undef JAX_GPU_DEFINE_SYEVJ + +#define JAX_GPU_DEFINE_SYEVJ_BATCHED(Type, Name) \ + template <> \ + absl::StatusOr SyevjBatchedBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, /*lda=*/n, \ + /*w=*/nullptr, &lwork, params, batch))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status SyevjBatched( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info, gpuSyevjInfo_t params, \ + int batch) { \ + return JAX_AS_STATUS(Name(handle, jobz, uplo, n, a, n, w, workspace, \ + lwork, info, params, batch)); \ + } + +JAX_GPU_DEFINE_SYEVJ_BATCHED(float, gpusolverDnSsyevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(double, gpusolverDnDsyevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuComplex, gpusolverDnCheevjBatched); +JAX_GPU_DEFINE_SYEVJ_BATCHED(gpuDoubleComplex, gpusolverDnZheevjBatched); +#undef JAX_GPU_DEFINE_SYEVJ_BATCHED + +#define JAX_GPU_DEFINE_SYEVD(Type, Name) \ + template <> \ + absl::StatusOr SyevdBufferSize(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR( \ + JAX_AS_STATUS(Name##_bufferSize(handle, jobz, uplo, n, /*A=*/nullptr, \ + /*lda=*/n, /*w=*/nullptr, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Syevd(gpusolverDnHandle_t handle, \ + gpusolverEigMode_t jobz, gpusolverFillMode_t uplo, \ + int n, Type *a, RealType::value *w, \ + Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS( \ + Name(handle, jobz, uplo, n, a, n, w, workspace, lwork, info)); \ + } + +JAX_GPU_DEFINE_SYEVD(float, gpusolverDnSsyevd); +JAX_GPU_DEFINE_SYEVD(double, gpusolverDnDsyevd); +JAX_GPU_DEFINE_SYEVD(gpuComplex, gpusolverDnCheevd); +JAX_GPU_DEFINE_SYEVD(gpuDoubleComplex, gpusolverDnZheevd); +#undef 
JAX_GPU_DEFINE_SYEVD + +// Symmetric rank-k update: syrk + +#define JAX_GPU_DEFINE_SYRK(Type, Name) \ + template <> \ + absl::Status Syrk(gpublasHandle_t handle, gpublasFillMode_t uplo, \ + gpublasOperation_t trans, int n, int k, \ + const Type *alpha, const Type *a, const Type *beta, \ + Type *c) { \ + int lda = trans == GPUBLAS_OP_N ? n : k; \ + return JAX_AS_STATUS( \ + Name(handle, uplo, trans, n, k, alpha, a, lda, beta, c, n)); \ + } + +JAX_GPU_DEFINE_SYRK(float, gpublasSsyrk); +JAX_GPU_DEFINE_SYRK(double, gpublasDsyrk); +JAX_GPU_DEFINE_SYRK(gpublasComplex, gpublasCsyrk); +JAX_GPU_DEFINE_SYRK(gpublasDoubleComplex, gpublasZsyrk); +#undef JAX_GPU_DEFINE_SYRK + +// Singular Value Decomposition: gesvd + +#define JAX_GPU_DEFINE_GESVD(Type, Name) \ + template <> \ + absl::StatusOr GesvdBufferSize(gpusolverDnHandle_t handle, \ + signed char job, int m, int n) { \ + int lwork; \ + JAX_RETURN_IF_ERROR( \ + JAX_AS_STATUS(Name##_bufferSize(handle, job, job, m, n, &lwork))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Gesvd(gpusolverDnHandle_t handle, signed char job, int m, \ + int n, Type *a, RealType::value *s, Type *u, \ + Type *vt, Type *workspace, int lwork, int *info) { \ + return JAX_AS_STATUS(Name(handle, job, job, m, n, a, m, s, u, m, vt, n, \ + workspace, lwork, /*rwork=*/nullptr, info)); \ + } + +JAX_GPU_DEFINE_GESVD(float, gpusolverDnSgesvd); +JAX_GPU_DEFINE_GESVD(double, gpusolverDnDgesvd); +JAX_GPU_DEFINE_GESVD(gpuComplex, gpusolverDnCgesvd); +JAX_GPU_DEFINE_GESVD(gpuDoubleComplex, gpusolverDnZgesvd); +#undef JAX_GPU_DEFINE_GESVD + +#ifdef JAX_GPU_CUDA + +#define JAX_GPU_DEFINE_GESVDJ(Type, Name) \ + template <> \ + absl::StatusOr GesvdjBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \ + int n, gpuGesvdjInfo_t params) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(Name##_bufferSize( \ + handle, job, econ, m, n, /*a=*/nullptr, /*lda=*/m, /*s=*/nullptr, \ + /*u=*/nullptr, /*ldu=*/m, /*v=*/nullptr, /*ldv=*/n, &lwork, params))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status Gesvdj( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, \ + int n, Type *a, RealType::value *s, Type *u, Type *v, \ + Type *workspace, int lwork, int *info, gpuGesvdjInfo_t params) { \ + return JAX_AS_STATUS(Name(handle, job, econ, m, n, a, m, s, u, m, v, n, \ + workspace, lwork, info, params)); \ + } + +JAX_GPU_DEFINE_GESVDJ(float, gpusolverDnSgesvdj); +JAX_GPU_DEFINE_GESVDJ(double, gpusolverDnDgesvdj); +JAX_GPU_DEFINE_GESVDJ(gpuComplex, gpusolverDnCgesvdj); +JAX_GPU_DEFINE_GESVDJ(gpuDoubleComplex, gpusolverDnZgesvdj); +#undef JAX_GPU_DEFINE_GESVDJ + +#define JAX_GPU_DEFINE_GESVDJ_BATCHED(Type, Name) \ + template <> \ + absl::StatusOr GesvdjBatchedBufferSize( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \ + gpuGesvdjInfo_t params, int batch) { \ + int lwork; \ + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ + Name##_bufferSize(handle, job, m, n, /*a=*/nullptr, /*lda=*/m, \ + /*s=*/nullptr, /*u=*/nullptr, /*ldu=*/m, \ + /*v=*/nullptr, /*ldv=*/n, &lwork, params, batch))); \ + return lwork; \ + } \ + \ + template <> \ + absl::Status GesvdjBatched( \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \ + Type *a, RealType::value *s, Type *u, Type *v, Type *workspace, \ + int lwork, int *info, gpuGesvdjInfo_t params, int batch) { \ + return JAX_AS_STATUS(Name(handle, job, m, n, a, m, s, u, m, v, n, \ + workspace, lwork, info, params, batch)); \ + } + +JAX_GPU_DEFINE_GESVDJ_BATCHED(float, 
gpusolverDnSgesvdjBatched); +JAX_GPU_DEFINE_GESVDJ_BATCHED(double, gpusolverDnDgesvdjBatched); +JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuComplex, gpusolverDnCgesvdjBatched); +JAX_GPU_DEFINE_GESVDJ_BATCHED(gpuDoubleComplex, gpusolverDnZgesvdjBatched); +#undef JAX_GPU_DEFINE_GESVDJ_BATCHED + +#endif // JAX_GPU_CUDA + +} // namespace solver +} // namespace JAX_GPU_NAMESPACE +} // namespace jax diff --git a/jaxlib/gpu/solver_interface.h b/jaxlib/gpu/solver_interface.h new file mode 100644 index 000000000000..336480e2e13b --- /dev/null +++ b/jaxlib/gpu/solver_interface.h @@ -0,0 +1,217 @@ +/* Copyright 2024 The JAX Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines a standard interface to the GPU linear algebra libraries. + +#ifndef JAXLIB_GPU_SOLVER_INTERFACE_H_ +#define JAXLIB_GPU_SOLVER_INTERFACE_H_ + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_format.h" +#include "jaxlib/gpu/vendor.h" + +namespace jax { +namespace JAX_GPU_NAMESPACE { +namespace solver { + +template +struct RealType { + using value = T; +}; + +template <> +struct RealType { + using value = float; +}; + +template <> +struct RealType { + using value = double; +}; + +#define JAX_GPU_SOLVER_EXPAND_DEFINITION(ReturnType, FunctionName) \ + template \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(T, typename RealType::value)) { \ + return absl::UnimplementedError(absl::StrFormat( \ + #FunctionName " not implemented for type %s", typeid(T).name())); \ + } \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(float, float)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(double, double)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuComplex, float)); \ + template <> \ + ReturnType FunctionName( \ + JAX_GPU_SOLVER_##FunctionName##_ARGS(gpuDoubleComplex, double)) + +// LU decomposition: getrf + +#define JAX_GPU_SOLVER_GetrfBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GetrfBufferSize); +#undef JAX_GPU_SOLVER_GetrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Getrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, Type *a, Type *workspace, \ + int lwork, int *ipiv, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Getrf); +#undef JAX_GPU_SOLVER_Getrf_ARGS + +#define JAX_GPU_SOLVER_GetrfBatched_ARGS(Type, ...) \ + gpublasHandle_t handle, int n, Type **a, int lda, int *ipiv, int *info, \ + int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GetrfBatched); +#undef JAX_GPU_SOLVER_GetrfBatched_ARGS + +// QR decomposition: geqrf + +#define JAX_GPU_SOLVER_GeqrfBufferSize_ARGS(Type, ...) 
\ + gpusolverDnHandle_t handle, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GeqrfBufferSize); +#undef JAX_GPU_SOLVER_GeqrfBufferSize_ARGS + +#define JAX_GPU_SOLVER_Geqrf_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, Type *a, Type *tau, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Geqrf); +#undef JAX_GPU_SOLVER_Geqrf_ARGS + +#define JAX_GPU_SOLVER_GeqrfBatched_ARGS(Type, ...) \ + gpublasHandle_t handle, int m, int n, Type **a, Type **tau, int *info, \ + int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GeqrfBatched); +#undef JAX_GPU_SOLVER_GeqrfBatched_ARGS + +// Householder transformations: orgqr + +#define JAX_GPU_SOLVER_OrgqrBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, int k +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, OrgqrBufferSize); +#undef JAX_GPU_SOLVER_OrgqrBufferSize_ARGS + +#define JAX_GPU_SOLVER_Orgqr_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, int m, int n, int k, Type *a, Type *tau, \ + Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Orgqr); +#undef JAX_GPU_SOLVER_Orgqr_ARGS + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd + +#define JAX_GPU_SOLVER_SyevjBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevjBufferSize); +#undef JAX_GPU_SOLVER_SyevjBufferSize_ARGS + +#define JAX_GPU_SOLVER_Syevj_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info, gpuSyevjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevj); +#undef JAX_GPU_SOLVER_Syevj_ARGS + +#define JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, gpuSyevjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevjBatchedBufferSize); +#undef JAX_GPU_SOLVER_SyevjBatchedBufferSize_ARGS + +#define JAX_GPU_SOLVER_SyevjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info, gpuSyevjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, SyevjBatched); +#undef JAX_GPU_SOLVER_SyevjBatched_ARGS + +#define JAX_GPU_SOLVER_SyevdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, SyevdBufferSize); +#undef JAX_GPU_SOLVER_SyevdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Syevd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t jobz, \ + gpusolverFillMode_t uplo, int n, Type *a, Real *w, Type *workspace, \ + int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syevd); +#undef JAX_GPU_SOLVER_Syevd_ARGS + +// Symmetric rank-k update: syrk + +#define JAX_GPU_SOLVER_Syrk_ARGS(Type, ...) 
\ + gpublasHandle_t handle, gpublasFillMode_t uplo, gpublasOperation_t trans, \ + int n, int k, const Type *alpha, const Type *a, const Type *beta, \ + Type *c +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Syrk); +#undef JAX_GPU_SOLVER_Syrk_ARGS + +// Singular Value Decomposition: gesvd + +#define JAX_GPU_SOLVER_GesvdBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, signed char job, int m, int n +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdBufferSize); +#undef JAX_GPU_SOLVER_GesvdBufferSize_ARGS + +#define JAX_GPU_SOLVER_Gesvd_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, signed char job, int m, int n, Type *a, Real *s, \ + Type *u, Type *vt, Type *workspace, int lwork, int *info +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvd); +#undef JAX_GPU_SOLVER_Gesvd_ARGS + +#ifdef JAX_GPU_CUDA + +#define JAX_GPU_SOLVER_GesvdjBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ + gesvdjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBufferSize); +#undef JAX_GPU_SOLVER_GesvdjBufferSize_ARGS + +#define JAX_GPU_SOLVER_Gesvdj_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int econ, int m, int n, \ + Type *a, Real *s, Type *u, Type *v, Type *workspace, \ + int lwork, int *info, gesvdjInfo_t params +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, Gesvdj); +#undef JAX_GPU_SOLVER_Gesvdj_ARGS + +#define JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS(Type, ...) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, \ + gpuGesvdjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::StatusOr, GesvdjBatchedBufferSize); +#undef JAX_GPU_SOLVER_GesvdjBatchedBufferSize_ARGS + +#define JAX_GPU_SOLVER_GesvdjBatched_ARGS(Type, Real) \ + gpusolverDnHandle_t handle, gpusolverEigMode_t job, int m, int n, Type *a, \ + Real *s, Type *u, Type *v, Type *workspace, int lwork, \ + int *info, gpuGesvdjInfo_t params, int batch +JAX_GPU_SOLVER_EXPAND_DEFINITION(absl::Status, GesvdjBatched); +#undef JAX_GPU_SOLVER_GesvdjBatched_ARGS + +#endif // JAX_GPU_CUDA + +#undef JAX_GPU_SOLVER_EXPAND_DEFINITION + +} // namespace solver +} // namespace JAX_GPU_NAMESPACE +} // namespace jax + +#endif // JAXLIB_GPU_SOLVER_INTERFACE_H_ diff --git a/jaxlib/gpu/solver_kernels.cc b/jaxlib/gpu/solver_kernels.cc index 8d90c70537c7..8c22dfcdbca7 100644 --- a/jaxlib/gpu/solver_kernels.cc +++ b/jaxlib/gpu/solver_kernels.cc @@ -23,8 +23,8 @@ limitations under the License. 
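solver_interface.h declares one templated entry point per LAPACK-style routine; JAX_GPU_SOLVER_EXPAND_DEFINITION emits an UnimplementedError default plus the four explicit specializations defined in solver_interface.cc. A small sketch of how a caller consumes that interface; the wrapper name QueryGetrfWorkspace is hypothetical.

#include "absl/status/statusor.h"
#include "jaxlib/gpu/solver_interface.h"

namespace jax {
namespace JAX_GPU_NAMESPACE {

// Returns the scratch size (in elements of T) that getrf needs, or the
// UnimplementedError default if T has no gpusolverDn* specialization.
template <typename T>
absl::StatusOr<int> QueryGetrfWorkspace(gpusolverDnHandle_t handle, int m,
                                        int n) {
  absl::StatusOr<int> lwork = solver::GetrfBufferSize<T>(handle, m, n);
  if (!lwork.ok()) return lwork.status();
  return *lwork;
}

}  // namespace JAX_GPU_NAMESPACE
}  // namespace jax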
#include "absl/status/status.h" #include "absl/status/statusor.h" -#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/solver_handle_pool.h" #include "jaxlib/gpu/vendor.h" #include "jaxlib/kernel_helpers.h" #include "xla/service/custom_call_status.h" @@ -421,9 +421,9 @@ static absl::Status Syevd_(gpuStream_t stream, void** buffers, int output_idx = 1; // with static shapes buffers[1] is the first output if (d.batch == -1) { // the batch is passed as a second operand - gpuMemcpyAsync((void*)&batch, - reinterpret_cast(buffers[1]), - sizeof(batch), gpuMemcpyDeviceToHost, stream); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuMemcpyAsync( + (void*)&batch, reinterpret_cast(buffers[1]), + sizeof(batch), gpuMemcpyDeviceToHost, stream))); JAX_RETURN_IF_ERROR(JAX_AS_STATUS(gpuStreamSynchronize(stream))); output_idx = 2; } diff --git a/jaxlib/gpu/solver_kernels_ffi.cc b/jaxlib/gpu/solver_kernels_ffi.cc index 414e159b2aac..32cd97565f5e 100644 --- a/jaxlib/gpu/solver_kernels_ffi.cc +++ b/jaxlib/gpu/solver_kernels_ffi.cc @@ -16,122 +16,906 @@ limitations under the License. #include "jaxlib/gpu/solver_kernels_ffi.h" #include +#include +#include +#include +#include #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_format.h" #include "jaxlib/ffi_helpers.h" +#include "jaxlib/gpu/blas_handle_pool.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" +#include "jaxlib/gpu/make_batch_pointers.h" #include "jaxlib/gpu/solver_handle_pool.h" +#include "jaxlib/gpu/solver_interface.h" #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" +#if JAX_GPU_64_BIT +#include +#endif + +#ifdef JAX_GPU_CUDA +#include +#endif + +#define JAX_FFI_RETURN_IF_GPU_ERROR(...) \ + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(__VA_ARGS__)) + +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(jax::JAX_GPU_NAMESPACE::SyevdAlgorithm); + namespace jax { namespace JAX_GPU_NAMESPACE { namespace ffi = ::xla::ffi; -namespace { -#define GETRF_KERNEL_IMPL(type, name) \ - template <> \ - struct GetrfKernel { \ - static absl::StatusOr BufferSize(gpusolverDnHandle_t handle, int m, \ - int n) { \ - int lwork; \ - JAX_RETURN_IF_ERROR(JAX_AS_STATUS( \ - name##_bufferSize(handle, m, n, /*A=*/nullptr, /*lda=*/m, &lwork))); \ - return lwork; \ - } \ - static absl::Status Run(gpusolverDnHandle_t handle, int m, int n, type* a, \ - type* workspace, int lwork, int* ipiv, \ - int* info) { \ - return JAX_AS_STATUS( \ - name(handle, m, n, a, m, workspace, lwork, ipiv, info)); \ - } \ +template +inline absl::StatusOr AllocateWorkspace(ffi::ScratchAllocator& scratch, + int64_t size, + std::string_view name) { + auto maybe_workspace = scratch.Allocate(sizeof(T) * size); + if (!maybe_workspace.has_value()) { + return absl::Status( + absl::StatusCode::kResourceExhausted, + absl::StrFormat("Unable to allocate workspace for %s", name)); + } + return static_cast(maybe_workspace.value()); +} + +#define SOLVER_DISPATCH_IMPL(impl, ...) \ + switch (dataType) { \ + case ffi::F32: \ + return impl(__VA_ARGS__); \ + case ffi::F64: \ + return impl(__VA_ARGS__); \ + case ffi::C64: \ + return impl(__VA_ARGS__); \ + case ffi::C128: \ + return impl(__VA_ARGS__); \ + default: \ + break; \ } -template -struct GetrfKernel; -GETRF_KERNEL_IMPL(float, gpusolverDnSgetrf); -GETRF_KERNEL_IMPL(double, gpusolverDnDgetrf); -GETRF_KERNEL_IMPL(gpuComplex, gpusolverDnCgetrf); -GETRF_KERNEL_IMPL(gpuDoubleComplex, gpusolverDnZgetrf); -#undef GETRF_KERNEL_IMPL +#define SOLVER_BLAS_DISPATCH_IMPL(impl, ...) 
\ + switch (dataType) { \ + case ffi::F32: \ + return impl(__VA_ARGS__); \ + case ffi::F64: \ + return impl(__VA_ARGS__); \ + case ffi::C64: \ + return impl(__VA_ARGS__); \ + case ffi::C128: \ + return impl(__VA_ARGS__); \ + default: \ + break; \ + } + +// LU decomposition: getrf template -ffi::Error GetrfImpl(gpuStream_t stream, ffi::ScratchAllocator& scratch, +ffi::Error GetrfImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { - FFI_RETURN_IF_ERROR(CheckMatrixDimensions(a.dimensions())); - auto [batch, rows, cols] = SplitBatch2D(a.dimensions()); + ffi::Result> ipiv, + ffi::Result> info) { FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); FFI_ASSIGN_OR_RETURN(int lwork, - GetrfKernel::BufferSize(handle.get(), m, n)); - - auto maybe_workspace = scratch.Allocate(sizeof(T) * lwork); - if (!maybe_workspace.has_value()) { - return ffi::Error(ffi::ErrorCode::kUnknown, - "Unable to allocate workspace for getrf"); - } - auto workspace = static_cast(maybe_workspace.value()); + solver::GetrfBufferSize(handle.get(), m, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "getrf")); auto a_data = static_cast(a.untyped_data()); auto out_data = static_cast(out->untyped_data()); auto ipiv_data = ipiv->typed_data(); auto info_data = info->typed_data(); if (a_data != out_data) { - FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS( - gpuMemcpyAsync(out_data, a_data, sizeof(T) * batch * rows * cols, - gpuMemcpyDeviceToDevice, stream))); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); } - for (int i = 0; i < batch; ++i) { - FFI_RETURN_IF_ERROR_STATUS(GetrfKernel::Run( + int ipiv_step = std::min(m, n); + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Getrf( handle.get(), m, n, out_data, workspace, lwork, ipiv_data, info_data)); out_data += m * n; - ipiv_data += std::min(m, n); + ipiv_data += ipiv_step; ++info_data; } return ffi::Error::Success(); } +template +ffi::Error GetrfBatchedImpl(int64_t batch, int64_t cols, gpuStream_t stream, + ffi::ScratchAllocator& scratch, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result> ipiv, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(auto batch_ptrs, + AllocateWorkspace(scratch, batch, "batched getrf")); + + auto a_data = a.untyped_data(); + auto out_data = out->untyped_data(); + auto ipiv_data = ipiv->typed_data(); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + MakeBatchPointersAsync(stream, out_data, batch_ptrs, batch, + sizeof(T) * n * n); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); + + FFI_RETURN_IF_ERROR_STATUS(solver::GetrfBatched( + handle.get(), n, batch_ptrs, n, ipiv_data, info_data, batch)); + + return ffi::Error::Success(); +} + ffi::Error GetrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, ffi::AnyBuffer a, ffi::Result out, - ffi::Result> ipiv, - ffi::Result> info) { + ffi::Result> ipiv, + ffi::Result> info) { auto dataType = a.element_type(); if (dataType != out->element_type()) { - return ffi::Error( - 
ffi::ErrorCode::kInvalidArgument, + return ffi::Error::InvalidArgument( "The input and output to getrf must have the same element type"); } - if (dataType == ffi::DataType::F32) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::F64) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::C64) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } else if (dataType == ffi::DataType::C128) { - return GetrfImpl(stream, scratch, a, out, ipiv, info); - } - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "Unsupported element type for getrf"); -} -} // namespace - -XLA_FFI_DEFINE_HANDLER_SYMBOL( - GetrfFfi, GetrfDispatch, - ffi::Ffi::Bind() - .Ctx>() - .Ctx() - .Arg() // a - .Ret() // out - .Ret>() // ipiv - .Ret>() // info + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "getrf")); + FFI_RETURN_IF_ERROR(CheckShape( + ipiv->dimensions(), {batch, std::min(rows, cols)}, "ipiv", "getrf")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "getrf")); + if (batch > 1 && rows == cols && rows / batch <= 128) { + SOLVER_BLAS_DISPATCH_IMPL(GetrfBatchedImpl, batch, cols, stream, scratch, a, + out, ipiv, info); + } else { + SOLVER_DISPATCH_IMPL(GetrfImpl, batch, rows, cols, stream, scratch, a, out, + ipiv, info); + } + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in getrf", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GetrfFfi, GetrfDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Ret() // out + .Ret>() // ipiv + .Ret>() // info +); + +// QR decomposition: geqrf + +template +ffi::Error GeqrfImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result tau) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(int lwork, + solver::GeqrfBufferSize(handle.get(), m, n)); + + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "geqrf")); + // Note: We ignore the returned value of info because it is only used for + // shape checking (which we already do ourselves), but it is expected to be + // in device memory, so we need to allocate it. 
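The pattern in use here, binding an ffi::ScratchAllocator context and drawing temporaries (including the device-resident info scalar cuSOLVER requires) from it via AllocateWorkspace, repeats in every handler in this file. A compact sketch of that pattern in isolation, relying on the AllocateWorkspace helper defined earlier in this file; the handler name ExampleScratchImpl and the float workspace are placeholders.

static ffi::Error ExampleScratchImpl(gpuStream_t stream,
                                     ffi::ScratchAllocator scratch,
                                     ffi::AnyBuffer x,
                                     ffi::Result<ffi::AnyBuffer> y) {
  // Device workspace sized to the input; only valid for the duration of the
  // call, so nothing needs to be freed explicitly.
  FFI_ASSIGN_OR_RETURN(auto workspace,
                       AllocateWorkspace<float>(scratch, x.element_count(),
                                                "example"));
  // Device-resident info scalar, required by cuSOLVER even when its value is
  // not inspected by the caller.
  FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace<int>(scratch, 1, "example"));
  (void)workspace;
  (void)info;
  // ... launch work on `stream` that writes `y` from `x` ...
  return ffi::Error::Success();
}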
+ FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace(scratch, 1, "geqrf")); + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto tau_data = static_cast(tau->untyped_data()); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + int out_step = m * n; + int tau_step = std::min(m, n); + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Geqrf( + handle.get(), m, n, out_data, tau_data, workspace, lwork, info)); + out_data += out_step; + tau_data += tau_step; + } + return ffi::Error::Success(); +} + +template +ffi::Error GeqrfBatchedImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result tau) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(auto out_batch_ptrs, + AllocateWorkspace(scratch, batch, "batched geqrf")); + FFI_ASSIGN_OR_RETURN(auto tau_batch_ptrs, + AllocateWorkspace(scratch, batch, "batched geqrf")); + + auto a_data = a.untyped_data(); + auto out_data = out->untyped_data(); + auto tau_data = tau->untyped_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + MakeBatchPointersAsync(stream, out_data, out_batch_ptrs, batch, + sizeof(T) * m * n); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); + MakeBatchPointersAsync(stream, tau_data, tau_batch_ptrs, batch, + sizeof(T) * std::min(m, n)); + JAX_FFI_RETURN_IF_GPU_ERROR(gpuGetLastError()); + + // We ignore the output value of `info` because it is only used for shape + // checking. 
+ int info; + FFI_RETURN_IF_ERROR_STATUS(solver::GeqrfBatched( + handle.get(), m, n, out_batch_ptrs, tau_batch_ptrs, &info, batch)); + + return ffi::Error::Success(); +} + +ffi::Error GeqrfDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result tau) { + auto dataType = a.element_type(); + if (dataType != out->element_type() || dataType != tau->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to geqrf must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "geqrf")); + FFI_RETURN_IF_ERROR(CheckShape( + tau->dimensions(), {batch, std::min(rows, cols)}, "tau", "geqrf")); + if (batch > 1 && rows / batch <= 128 && cols / batch <= 128) { + SOLVER_BLAS_DISPATCH_IMPL(GeqrfBatchedImpl, batch, rows, cols, stream, + scratch, a, out, tau); + } else { + SOLVER_DISPATCH_IMPL(GeqrfImpl, batch, rows, cols, stream, scratch, a, out, + tau); + } + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in geqrf", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GeqrfFfi, GeqrfDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Ret() // out + .Ret() // tau +); + +// Householder transformations: orgqr + +template +ffi::Error OrgqrImpl(int64_t batch, int64_t rows, int64_t cols, int64_t size, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + ffi::AnyBuffer a, ffi::AnyBuffer tau, + ffi::Result out) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto k, MaybeCastNoOverflow(size)); + + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + FFI_ASSIGN_OR_RETURN(int lwork, + solver::OrgqrBufferSize(handle.get(), m, n, k)); + + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "orgqr")); + // Note: We ignore the returned value of info because it is only used for + // shape checking (which we already do ourselves), but it is expected to be + // in device memory, so we need to allocate it. 
+ FFI_ASSIGN_OR_RETURN(auto info, AllocateWorkspace(scratch, 1, "orgqr")); + + auto a_data = static_cast(a.untyped_data()); + auto tau_data = static_cast(tau.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + int out_step = m * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Orgqr( + handle.get(), m, n, k, out_data, tau_data, workspace, lwork, info)); + out_data += out_step; + tau_data += k; + } + return ffi::Error::Success(); +} + +ffi::Error OrgqrDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + ffi::AnyBuffer a, ffi::AnyBuffer tau, + ffi::Result out) { + auto dataType = a.element_type(); + if (dataType != tau.element_type() || dataType != out->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to orgqr must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + FFI_ASSIGN_OR_RETURN((auto [tau_batch, size]), + SplitBatch1D(tau.dimensions())); + if (tau_batch != batch) { + return ffi::Error::InvalidArgument( + "The batch dimensions of the inputs to orgqr must match"); + } + if (size > cols) { + return ffi::Error::InvalidArgument( + "The trailing dimension of the tau input to orgqr must be less than or " + "equal to the number of columns of the input matrix"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "orgqr")); + SOLVER_DISPATCH_IMPL(OrgqrImpl, batch, rows, cols, size, stream, scratch, a, + tau, out); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in orgqr", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(OrgqrFfi, OrgqrDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Arg() // a + .Arg() // tau + .Ret() // out +); + +// Symmetric (Hermitian) eigendecomposition: +// * Jacobi algorithm: syevj/heevj (batches of matrices up to 32) +// * QR algorithm: syevd/heevd +// For historical reasons, the target is called "syevd" even though it +// dispatches dynamically to both syevd and syevj depending on the problem +// size and the algorithm selected by the user via the `algorithm` attribute. + +template +ffi::Error SyevdImpl(int64_t batch, int64_t size, gpuStream_t stream, + ffi::ScratchAllocator& scratch, SyevdAlgorithm algorithm, + bool lower, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result w, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(size)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + gpusolverEigMode_t jobz = GPUSOLVER_EIG_MODE_VECTOR; + gpusolverFillMode_t uplo = + lower ? 
GPUSOLVER_FILL_MODE_LOWER : GPUSOLVER_FILL_MODE_UPPER; + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto w_data = static_cast::value*>(w->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + if (algorithm == SyevdAlgorithm::kJacobi || + (algorithm == SyevdAlgorithm::kDefault && size <= 32)) { + gpuSyevjInfo_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateSyevjInfo(¶ms)); + std::unique_ptr params_cleanup( + params, [](gpuSyevjInfo_t p) { gpusolverDnDestroySyevjInfo(p); }); + + if (batch == 1) { + FFI_ASSIGN_OR_RETURN(int lwork, solver::SyevjBufferSize( + handle.get(), jobz, uplo, n, params)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevj")); + FFI_RETURN_IF_ERROR_STATUS(solver::Syevj(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data, params)); + } else { + FFI_ASSIGN_OR_RETURN( + int lwork, solver::SyevjBatchedBufferSize(handle.get(), jobz, uplo, + n, params, batch)); + FFI_ASSIGN_OR_RETURN( + auto workspace, + AllocateWorkspace(scratch, lwork, "syevj_batched")); + FFI_RETURN_IF_ERROR_STATUS( + solver::SyevjBatched(handle.get(), jobz, uplo, n, out_data, w_data, + workspace, lwork, info_data, params, batch)); + } + } else { + FFI_ASSIGN_OR_RETURN( + int lwork, solver::SyevdBufferSize(handle.get(), jobz, uplo, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "syevd")); + int out_step = n * n; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Syevd(handle.get(), jobz, uplo, n, + out_data, w_data, workspace, + lwork, info_data)); + out_data += out_step; + w_data += n; + ++info_data; + } + } + return ffi::Error::Success(); +} + +ffi::Error SyevdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + SyevdAlgorithm algorithm, bool lower, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result w, + ffi::Result> info) { + auto dataType = a.element_type(); + if (dataType != out->element_type() || + ffi::ToReal(dataType) != w->element_type()) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to syevd must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + if (rows != cols) { + return ffi::Error::InvalidArgument( + "The input matrix to syevd must be square"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "syevd")); + FFI_RETURN_IF_ERROR(CheckShape(w->dimensions(), {batch, cols}, "w", "syevd")); + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "syevd")); + SOLVER_DISPATCH_IMPL(SyevdImpl, batch, cols, stream, scratch, algorithm, + lower, a, out, w, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in syevd", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SyevdFfi, SyevdDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("algorithm") + .Attr("lower") + .Arg() // a + .Ret() // out + .Ret() // w + .Ret>() // info ); +// Symmetric rank-k update: syrk + +template +ffi::Error SyrkImpl(gpuStream_t stream, bool transpose, ffi::AnyBuffer a, + ffi::AnyBuffer c_in, ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, ffi::Result c_out) { + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + if (alpha.element_count() != 1 || beta.element_count() != 1) { 
+ return ffi::Error::InvalidArgument( + "The alpha and beta inputs to syrk must be scalars"); + } + auto size = transpose ? cols : rows; + FFI_RETURN_IF_ERROR( + CheckShape(c_in.dimensions(), {batch, size, size}, "c_in", "syrk")); + FFI_RETURN_IF_ERROR( + CheckShape(c_out->dimensions(), {batch, size, size}, "c_out", "syrk")); + + FFI_ASSIGN_OR_RETURN(auto n, + MaybeCastNoOverflow(transpose ? cols : rows)); + FFI_ASSIGN_OR_RETURN(auto k, + MaybeCastNoOverflow(transpose ? rows : cols)); + gpublasFillMode_t uplo = GPUSOLVER_FILL_MODE_UPPER; + gpublasOperation_t trans = transpose ? GPUBLAS_OP_N : GPUBLAS_OP_T; + + const T* a_data = static_cast(a.untyped_data()); + T* c_data = static_cast(c_in.untyped_data()); + T* c_out_data = static_cast(c_out->untyped_data()); + + // with alpha or beta provided as device_pointers, cublassyrk will SIGSEGV + T host_alpha; + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_alpha, alpha.untyped_data(), + sizeof(T), gpuMemcpyDeviceToHost, + stream)); + + T host_beta; + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync(&host_beta, beta.untyped_data(), + sizeof(T), gpuMemcpyDeviceToHost, + stream)); + + if (c_data != c_out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR( + gpuMemcpyAsync(c_out_data, c_data, c_in.size_bytes(), + gpuMemcpyDeviceToDevice, stream)); + } + FFI_ASSIGN_OR_RETURN(auto handle, BlasHandlePool::Borrow(stream)); + for (int i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Syrk(handle.get(), uplo, trans, n, k, + &host_alpha, a_data, &host_beta, + c_out_data)); + a_data += k * n; + c_out_data += n * n; + } + return ffi::Error::Success(); +} + +ffi::Error SyrkDispatch(gpuStream_t stream, bool transpose, ffi::AnyBuffer a, + ffi::AnyBuffer c_in, ffi::AnyBuffer alpha, + ffi::AnyBuffer beta, + ffi::Result c_out) { + auto dataType = a.element_type(); + SOLVER_BLAS_DISPATCH_IMPL(SyrkImpl, stream, transpose, a, c_in, alpha, beta, + c_out); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in syrk", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(SyrkFfi, SyrkDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Attr("transpose") // transpose + .Arg() // a + .Arg() // c_in + .Arg() // alpha + .Arg() // beta + .Ret() // c_out +); + +// Singular Value Decomposition: gesvd + +#if JAX_GPU_64_BIT + +ffi::Error Gesvd64Impl(int64_t batch, int64_t m, int64_t n, gpuStream_t stream, + ffi::ScratchAllocator& scratch, bool full_matrices, + bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result vt, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + signed char job = compute_uv ? (full_matrices ? 
'A' : 'S') : 'N'; + + auto dataType = a.element_type(); + gpuDataType aType, sType; + switch (dataType) { + case ffi::F32: + aType = GPU_R_32F; + sType = GPU_R_32F; + break; + case ffi::F64: + aType = GPU_R_64F; + sType = GPU_R_64F; + break; + case ffi::C64: + aType = GPU_C_32F; + sType = GPU_R_32F; + break; + case ffi::C128: + aType = GPU_C_64F; + sType = GPU_R_64F; + break; + default: + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType))); + } + + gpusolverDnParams_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateParams(¶ms)); + std::unique_ptr + params_cleanup( + params, [](gpusolverDnParams_t p) { gpusolverDnDestroyParams(p); }); + + size_t workspaceInBytesOnDevice, workspaceInBytesOnHost; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd_bufferSize( + handle.get(), params, job, job, m, n, aType, /*a=*/nullptr, m, sType, + /*s=*/nullptr, aType, /*u=*/nullptr, m, aType, /*vt=*/nullptr, n, aType, + &workspaceInBytesOnDevice, &workspaceInBytesOnHost)); + + auto maybe_workspace = scratch.Allocate(workspaceInBytesOnDevice); + if (!maybe_workspace.has_value()) { + return ffi::Error(ffi::ErrorCode::kResourceExhausted, + "Unable to allocate device workspace for gesvd"); + } + auto workspaceOnDevice = maybe_workspace.value(); + auto workspaceOnHost = + std::unique_ptr(new char[workspaceInBytesOnHost]); + + const char* a_data = static_cast(a.untyped_data()); + char* out_data = static_cast(out->untyped_data()); + char* s_data = static_cast(s->untyped_data()); + char* u_data = static_cast(u->untyped_data()); + char* vt_data = static_cast(vt->untyped_data()); + int* info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + size_t out_step = m * n * ffi::ByteWidth(dataType); + size_t s_step = n * ffi::ByteWidth(ffi::ToReal(dataType)); + size_t u_step = 0; + size_t vt_step = 0; + if (compute_uv) { + u_step = m * (full_matrices ? m : n) * ffi::ByteWidth(dataType); + vt_step = n * n * ffi::ByteWidth(dataType); + } + for (auto i = 0; i < batch; ++i) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnXgesvd( + handle.get(), params, job, job, m, n, aType, out_data, m, sType, s_data, + aType, u_data, m, aType, vt_data, n, aType, workspaceOnDevice, + workspaceInBytesOnDevice, workspaceOnHost.get(), workspaceInBytesOnHost, + info_data)); + out_data += out_step; + s_data += s_step; + u_data += u_step; + vt_data += vt_step; + ++info_data; + } + + return ffi::Error::Success(); +} + +#else + +template +ffi::Error GesvdImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result vt, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + signed char job = compute_uv ? (full_matrices ? 'A' : 'S') : 'N'; + + FFI_ASSIGN_OR_RETURN(int lwork, + solver::GesvdBufferSize(handle.get(), job, m, n)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "gesvd")); + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto s_data = static_cast::value*>(s->untyped_data()); + auto u_data = compute_uv ? 
static_cast(u->untyped_data()) : nullptr; + auto vt_data = compute_uv ? static_cast(vt->untyped_data()) : nullptr; + auto info_data = info->typed_data(); + if (a_data != out_data) { + FFI_RETURN_IF_ERROR_STATUS(JAX_AS_STATUS(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream))); + } + + int out_step = m * n; + int u_step = compute_uv ? m * (full_matrices ? m : n) : 0; + int vt_step = compute_uv ? n * n : 0; + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS( + solver::Gesvd(handle.get(), job, m, n, out_data, s_data, u_data, + vt_data, workspace, lwork, info_data)); + out_data += out_step; + s_data += n; // n is always less than m because of the logic in dispatch. + u_data += u_step; + vt_data += vt_step; + ++info_data; + } + return ffi::Error::Success(); +} + +#endif // JAX_GPU_64_BIT + +ffi::Error GesvdDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool full_matrices, bool compute_uv, bool transposed, + ffi::AnyBuffer a, ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result vt, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + s->element_type() != ffi::ToReal(dataType) || + u->element_type() != dataType || vt->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gesvd must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + int64_t m = transposed ? cols : rows; + int64_t n = transposed ? rows : cols; + if (n > m) { + return ffi::Error::InvalidArgument( + "The GPU implementation of gesvd requires that the input matrix be m x " + "n with m >= n"); + } + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvd")); + FFI_RETURN_IF_ERROR(CheckShape(s->dimensions(), {batch, n}, "s", "gesvd")); + if (compute_uv) { + if (full_matrices) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, m, m}, "u", "gesvd")); + } else { + if (transposed) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, n, m}, "u", "gesvd")); + } else { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, m, n}, "u", "gesvd")); + } + } + FFI_RETURN_IF_ERROR( + CheckShape(vt->dimensions(), {batch, n, n}, "vt", "gesvd")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvd")); + +#if JAX_GPU_64_BIT + return Gesvd64Impl(batch, m, n, stream, scratch, full_matrices, compute_uv, a, + out, s, u, vt, info); +#else + SOLVER_DISPATCH_IMPL(GesvdImpl, batch, m, n, stream, scratch, full_matrices, + compute_uv, a, out, s, u, vt, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gesvd", absl::FormatStreamed(dataType))); +#endif +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdFfi, GesvdDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("full_matrices") + .Attr("compute_uv") + .Attr("transposed") + .Arg() // a + .Ret() // out + .Ret() // s + .Ret() // u + .Ret() // vt + .Ret>() // info +); + +#ifdef JAX_GPU_CUDA + +template +ffi::Error GesvdjImpl(int64_t batch, int64_t rows, int64_t cols, + gpuStream_t stream, ffi::ScratchAllocator& scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result v, + ffi::Result> info) { + FFI_ASSIGN_OR_RETURN(auto m, MaybeCastNoOverflow(rows)); + FFI_ASSIGN_OR_RETURN(auto n, MaybeCastNoOverflow(cols)); + FFI_ASSIGN_OR_RETURN(auto handle, SolverHandlePool::Borrow(stream)); + + 
gpusolverEigMode_t job = + compute_uv ? GPUSOLVER_EIG_MODE_VECTOR : GPUSOLVER_EIG_MODE_NOVECTOR; + int econ = full_matrices ? 0 : 1; + + gpuGesvdjInfo_t params; + JAX_FFI_RETURN_IF_GPU_ERROR(gpusolverDnCreateGesvdjInfo(¶ms)); + std::unique_ptr params_cleanup( + params, [](gpuGesvdjInfo_t p) { gpusolverDnDestroyGesvdjInfo(p); }); + + auto a_data = static_cast(a.untyped_data()); + auto out_data = static_cast(out->untyped_data()); + auto s_data = static_cast::value*>(s->untyped_data()); + auto u_data = static_cast(u->untyped_data()); + auto v_data = static_cast(v->untyped_data()); + auto info_data = info->typed_data(); + if (a_data != out_data) { + JAX_FFI_RETURN_IF_GPU_ERROR(gpuMemcpyAsync( + out_data, a_data, a.size_bytes(), gpuMemcpyDeviceToDevice, stream)); + } + + if (batch <= 1 || batch > std::numeric_limits::max() || m > 32 || + n > 32 || econ) { + FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBufferSize( + handle.get(), job, econ, m, n, params)); + FFI_ASSIGN_OR_RETURN(auto workspace, + AllocateWorkspace(scratch, lwork, "gesvdj")); + int k = std::min(m, n); + int out_step = m * n; + int u_step = m * (full_matrices ? m : k); + int v_step = n * (full_matrices ? n : k); + for (auto i = 0; i < batch; ++i) { + FFI_RETURN_IF_ERROR_STATUS(solver::Gesvdj( + handle.get(), job, econ, m, n, out_data, s_data, u_data, v_data, + workspace, lwork, info_data, params)); + out_data += out_step; + s_data += k; + u_data += u_step; + v_data += v_step; + ++info_data; + } + } else { + FFI_ASSIGN_OR_RETURN(int lwork, solver::GesvdjBatchedBufferSize( + handle.get(), job, m, n, params, + static_cast(batch))); + FFI_ASSIGN_OR_RETURN( + auto workspace, AllocateWorkspace(scratch, lwork, "gesvdj_batched")); + FFI_RETURN_IF_ERROR_STATUS(solver::GesvdjBatched( + handle.get(), job, m, n, out_data, s_data, u_data, v_data, workspace, + lwork, info_data, params, static_cast(batch))); + } + return ffi::Error::Success(); +} + +ffi::Error GesvdjDispatch(gpuStream_t stream, ffi::ScratchAllocator scratch, + bool full_matrices, bool compute_uv, ffi::AnyBuffer a, + ffi::Result out, + ffi::Result s, + ffi::Result u, + ffi::Result v, + ffi::Result> info) { + auto dataType = a.element_type(); + if (out->element_type() != dataType || + s->element_type() != ffi::ToReal(dataType) || + u->element_type() != dataType || v->element_type() != dataType) { + return ffi::Error::InvalidArgument( + "The inputs and outputs to gesvdj must have the same element type"); + } + FFI_ASSIGN_OR_RETURN((auto [batch, rows, cols]), + SplitBatch2D(a.dimensions())); + int64_t size = std::min(rows, cols); + FFI_RETURN_IF_ERROR( + CheckShape(out->dimensions(), {batch, rows, cols}, "out", "gesvdj")); + FFI_RETURN_IF_ERROR( + CheckShape(s->dimensions(), {batch, size}, "s", "gesvdj")); + // U and V must always be allocated even if compute_uv is false. 
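+  // As the shape checks below imply, gesvdj returns V rather than V^T: v is
+  // checked against (cols, cols) in the full case and (cols, min(rows, cols))
+  // otherwise.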
+ if (full_matrices) { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, rows, rows}, "u", "gesvdj")); + FFI_RETURN_IF_ERROR( + CheckShape(v->dimensions(), {batch, cols, cols}, "v", "gesvdj")); + } else { + FFI_RETURN_IF_ERROR( + CheckShape(u->dimensions(), {batch, rows, size}, "u", "gesvdj")); + FFI_RETURN_IF_ERROR( + CheckShape(v->dimensions(), {batch, cols, size}, "v", "gesvdj")); + } + FFI_RETURN_IF_ERROR(CheckShape(info->dimensions(), batch, "info", "gesvdj")); + + SOLVER_DISPATCH_IMPL(GesvdjImpl, batch, rows, cols, stream, scratch, + full_matrices, compute_uv, a, out, s, u, v, info); + return ffi::Error::InvalidArgument(absl::StrFormat( + "Unsupported dtype %s in gesvdj", absl::FormatStreamed(dataType))); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GesvdjFfi, GesvdjDispatch, + ffi::Ffi::Bind() + .Ctx>() + .Ctx() + .Attr("full_matrices") + .Attr("compute_uv") + .Arg() // a + .Ret() // out + .Ret() // s + .Ret() // u + .Ret() // v + .Ret>() // info +); + +#endif // JAX_GPU_CUDA + +#undef SOLVER_DISPATCH_IMPL +#undef SOLVER_BLAS_DISPATCH_IMPL + } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/solver_kernels_ffi.h b/jaxlib/gpu/solver_kernels_ffi.h index 64fb1baba56a..022564eb108c 100644 --- a/jaxlib/gpu/solver_kernels_ffi.h +++ b/jaxlib/gpu/solver_kernels_ffi.h @@ -16,13 +16,30 @@ limitations under the License. #ifndef JAXLIB_GPU_SOLVER_KERNELS_FFI_H_ #define JAXLIB_GPU_SOLVER_KERNELS_FFI_H_ +#include + #include "jaxlib/gpu/vendor.h" #include "xla/ffi/api/ffi.h" namespace jax { namespace JAX_GPU_NAMESPACE { +enum class SyevdAlgorithm : uint8_t { + kDefault = 0, + kDivideAndConquer, + kJacobi, +}; + XLA_FFI_DECLARE_HANDLER_SYMBOL(GetrfFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GeqrfFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(OrgqrFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SyevdFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(SyrkFfi); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdFfi); + +#ifdef JAX_GPU_CUDA +XLA_FFI_DECLARE_HANDLER_SYMBOL(GesvdjFfi); +#endif // JAX_GPU_CUDA } // namespace JAX_GPU_NAMESPACE } // namespace jax diff --git a/jaxlib/gpu/sparse.cc b/jaxlib/gpu/sparse.cc index b9eb51388fa9..2eeb94e309ce 100644 --- a/jaxlib/gpu/sparse.cc +++ b/jaxlib/gpu/sparse.cc @@ -26,6 +26,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/strings/str_format.h" #include "absl/synchronization/mutex.h" +#include "jaxlib/absl_status_casters.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/sparse_kernels.h" #include "jaxlib/gpu/vendor.h" @@ -57,14 +58,19 @@ gpusparseIndexType_t DtypeToCuSparseIndexType(const dtype& np_type) { gpuDataType DtypeToCudaDataType(const dtype& np_type) { static auto* types = new absl::flat_hash_map, gpuDataType>({ - {{'f', 2}, GPU_R_16F}, {{'c', 4}, GPU_C_16F}, {{'f', 4}, GPU_R_32F}, - {{'c', 8}, GPU_C_32F}, {{'f', 8}, GPU_R_64F}, - {{'c', 16}, GPU_C_64F}, + {{'f', 2}, GPU_R_16F}, + {{'c', 4}, GPU_C_16F}, + {{'f', 4}, GPU_R_32F}, + {{'c', 8}, GPU_C_32F}, + {{'f', 8}, GPU_R_64F}, + {{'c', 16}, GPU_C_64F}, #ifdef JAX_GPU_CUDA - {{'i', 1}, CUDA_R_8I}, {{'u', 1}, CUDA_R_8U}, - {{'i', 4}, CUDA_R_32I}, {{'u', 4}, CUDA_R_32U}, + {{'i', 1}, CUDA_R_8I}, + {{'u', 1}, CUDA_R_8U}, + {{'i', 4}, CUDA_R_32I}, + {{'u', 4}, CUDA_R_32U}, #if JAX_GPU_HAVE_SPARSE - {{'V', 2}, CUDA_R_16BF}, + {{'V', 2}, CUDA_R_16BF}, #endif // JAX_GPU_HAVE_SPARSE #endif // JAX_GPU_CUDA }); @@ -78,9 +84,8 @@ gpuDataType DtypeToCudaDataType(const dtype& np_type) { } // Returns the descriptor for a Sparse matrix. 
SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, - const dtype& index_dtype, - int rows, int cols, int nnz, - int batch_count, + const dtype& index_dtype, int rows, + int cols, int nnz, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); gpusparseIndexType_t index_type = DtypeToCuSparseIndexType(index_dtype); @@ -89,16 +94,15 @@ SparseMatDescriptor BuildSparseMatDescriptor(const dtype& data_dtype, } // Returns the descriptor for a Dense matrix. -DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, - int rows, int cols, int batch_count, +DenseMatDescriptor BuildDenseMatDescriptor(const dtype& data_dtype, int rows, + int cols, int batch_count, int batch_stride) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseMatDescriptor{value_type, rows, cols, batch_count, batch_stride}; } // Returns the descriptor for a Dense vector. -DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, - int size) { +DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, int size) { gpuDataType value_type = DtypeToCudaDataType(data_dtype); return DenseVecDescriptor{value_type, size}; } @@ -107,9 +111,10 @@ DenseVecDescriptor BuildDenseVecDescriptor(const dtype& data_dtype, // CsrToDense: Convert CSR matrix to dense matrix // Returns the descriptor for a Sparse matrix. -std::pair BuildCsrToDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { +std::pair BuildCsrToDenseDescriptor(const dtype& data_dtype, + const dtype& index_dtype, + int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -185,8 +190,8 @@ void CsrToDense(gpuStream_t stream, void** buffers, const char* opaque, // Returns the descriptor for a CsrFromDense operation. std::pair BuildCsrFromDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -261,9 +266,8 @@ void CsrFromDense(gpuStream_t stream, void** buffers, const char* opaque, // Returns the descriptor for a CsrMatvec operation. std::pair BuildCsrMatvecDescriptor( - const dtype& data_dtype, const dtype& x_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -292,7 +296,7 @@ std::pair BuildCsrMatvecDescriptor( JAX_THROW_IF_ERROR( JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - SparseConst alpha = ConstOne(y.type); + SparseConst alpha = ValueOrThrow(ConstOne(y.type)); SparseConst beta = ConstZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, @@ -309,9 +313,9 @@ std::pair BuildCsrMatvecDescriptor( // Returns the descriptor for a CsrMatmat operation. 
std::pair BuildCsrMatmatDescriptor( - const dtype& data_dtype, const dtype& b_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int BCcols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, + bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -344,7 +348,7 @@ std::pair BuildCsrMatmatDescriptor( JAX_AS_STATUS(gpusparseCreateDnMat(&mat_c, C.rows, C.cols, /*ld=*/C.cols, empty, C.type, GPUSPARSE_ORDER_ROW))); size_t buffer_size; - SparseConst alpha = ConstOne(C.type); + SparseConst alpha = ValueOrThrow(ConstOne(C.type)); SparseConst beta = ConstZero(C.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, @@ -360,9 +364,10 @@ std::pair BuildCsrMatmatDescriptor( // CooToDense: Convert COO matrix to dense matrix // Returns the descriptor for a CooToDense operation. -std::pair BuildCooToDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { +std::pair BuildCooToDenseDescriptor(const dtype& data_dtype, + const dtype& index_dtype, + int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -398,8 +403,8 @@ std::pair BuildCooToDenseDescriptor( // Returns the descriptor for a CooFromDense operation. std::pair BuildCooFromDenseDescriptor( - const dtype& data_dtype, const dtype& index_dtype, int rows, - int cols, int nnz) { + const dtype& data_dtype, const dtype& index_dtype, int rows, int cols, + int nnz) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -434,9 +439,8 @@ std::pair BuildCooFromDenseDescriptor( // Returns the descriptor for a CooMatvec operation. std::pair BuildCooMatvecDescriptor( - const dtype& data_dtype, const dtype& x_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int nnz, bool transpose) { + const dtype& data_dtype, const dtype& x_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int nnz, bool transpose) { auto h = SparseHandlePool::Borrow(/*stream=*/nullptr); JAX_THROW_IF_ERROR(h.status()); auto& handle = *h; @@ -465,7 +469,7 @@ std::pair BuildCooMatvecDescriptor( JAX_THROW_IF_ERROR( JAX_AS_STATUS(gpusparseCreateDnVec(&vec_y, y.size, empty, y.type))); size_t buffer_size; - SparseConst alpha = ConstOne(y.type); + SparseConst alpha = ValueOrThrow(ConstOne(y.type)); SparseConst beta = ConstZero(y.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMV_bufferSize( handle.get(), op, &alpha, mat_a, vec_x, &beta, vec_y, y.type, @@ -482,10 +486,10 @@ std::pair BuildCooMatvecDescriptor( // Returns the descriptor for a CooMatmat operation. 
std::pair BuildCooMatmatDescriptor( - const dtype& data_dtype, const dtype& b_dtype, - const dtype& compute_dtype, const dtype& index_dtype, int rows, - int cols, int BCcols, int nnz, bool transpose, int batch_count, - int lhs_batch_stride, int rhs_batch_stride) { + const dtype& data_dtype, const dtype& b_dtype, const dtype& compute_dtype, + const dtype& index_dtype, int rows, int cols, int BCcols, int nnz, + bool transpose, int batch_count, int lhs_batch_stride, + int rhs_batch_stride) { // Three batch modes are supported, C_i = A_i B, C_i = A B_i, and // Ci = A_i B_i, where `i` denotes the batch dimension. // All three matrices A, B, and C must have the same batch count. @@ -535,7 +539,7 @@ std::pair BuildCooMatmatDescriptor( JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseDnMatSetStridedBatch( mat_c, /*batchCount=*/batch_count, /*batchStride=*/C.batch_stride))); size_t buffer_size; - SparseConst alpha = ConstOne(C.type); + SparseConst alpha = ValueOrThrow(ConstOne(C.type)); SparseConst beta = ConstZero(C.type); JAX_THROW_IF_ERROR(JAX_AS_STATUS(gpusparseSpMM_bufferSize( handle.get(), op_A, GPUSPARSE_OPERATION_NON_TRANSPOSE, &alpha, mat_a, diff --git a/jaxlib/gpu/sparse_kernels.cc b/jaxlib/gpu/sparse_kernels.cc index 93c6aef17008..a44d4b33149d 100644 --- a/jaxlib/gpu/sparse_kernels.cc +++ b/jaxlib/gpu/sparse_kernels.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/synchronization/mutex.h" #include "jaxlib/gpu/gpu_kernel_helpers.h" #include "jaxlib/gpu/vendor.h" @@ -58,7 +59,7 @@ SparseConst ConstZero(gpuDataType type) { return c; } -SparseConst ConstOne(gpuDataType type) { +absl::StatusOr ConstOne(gpuDataType type) { SparseConst c; std::memset(&c, 0, sizeof(c)); switch (type) { @@ -138,6 +139,9 @@ SparseConst ConstOne(gpuDataType type) { case GPU_C_64F: c.f64[0] = 1.0; break; + default: + return absl::InvalidArgumentError( + absl::StrCat("Unsupported data type: ", type)); } return c; } @@ -248,7 +252,7 @@ static absl::Status CsrMatvec_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.y.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type)); SparseConst beta = ConstZero(d.y.type); gpusparseSpMatDescr_t mat_a = 0; @@ -305,7 +309,7 @@ static absl::Status CsrMatmat_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.C.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type)); SparseConst beta = ConstZero(d.C.type); gpusparseSpMatDescr_t mat_a = 0; @@ -446,7 +450,7 @@ static absl::Status CooMatvec_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. // Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.y.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.y.type)); SparseConst beta = ConstZero(d.y.type); gpusparseSpMatDescr_t mat_a = 0; @@ -502,7 +506,7 @@ static absl::Status CooMatmat_(gpuStream_t stream, void** buffers, // are sufficient for basic matvec operations. 
// Note that, contrary to cusparse docs, alpha and beta must be host pointers // or else the operation will segfault. - SparseConst alpha = ConstOne(d.C.type); + JAX_ASSIGN_OR_RETURN(SparseConst alpha, ConstOne(d.C.type)); SparseConst beta = ConstZero(d.C.type); gpusparseSpMatDescr_t mat_a = 0; @@ -567,7 +571,7 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, T* du = static_cast(buffers[2]); T* B = static_cast(buffers[3]); T* X = static_cast(buffers[4]); - void* buffer = static_cast(buffers[5]); + void* buffer = static_cast(buffers[5]); // The solution X is written in place to B. We need to therefore copy the // contents of B into the output buffer X and pass that into the kernel as B. @@ -581,8 +585,8 @@ static absl::Status gtsv2(F computeGtsv2, gpuStream_t stream, void** buffers, gpuMemcpyAsync(X, B, B_bytes, gpuMemcpyDeviceToDevice, stream))); } for (int i = 0; i < batch; ++i) { - JAX_RETURN_IF_ERROR(JAX_AS_STATUS(computeGtsv2( - handle.get(), m, n, dl, d, du, X, ldb, buffer))); + JAX_RETURN_IF_ERROR(JAX_AS_STATUS( + computeGtsv2(handle.get(), m, n, dl, d, du, X, ldb, buffer))); dl += m; d += m; du += m; diff --git a/jaxlib/gpu/sparse_kernels.h b/jaxlib/gpu/sparse_kernels.h index 2180767b0cf7..48433b3d6eaa 100644 --- a/jaxlib/gpu/sparse_kernels.h +++ b/jaxlib/gpu/sparse_kernels.h @@ -51,7 +51,7 @@ union SparseConst { }; SparseConst ConstZero(gpuDataType type); -SparseConst ConstOne(gpuDataType type); +absl::StatusOr ConstOne(gpuDataType type); struct SparseMatDescriptor { gpuDataType value_type; diff --git a/jaxlib/gpu/triton.cc b/jaxlib/gpu/triton.cc index 1274eeba466b..500034af3ebb 100644 --- a/jaxlib/gpu/triton.cc +++ b/jaxlib/gpu/triton.cc @@ -132,6 +132,18 @@ NB_MODULE(_triton, m) { return major * 10 + minor; })); + m.def( + "get_arch_details", + ValueOrThrowWrapper([](int device) -> absl::StatusOr { +#ifdef JAX_GPU_HIP + hipDeviceProp_t prop; + hipGetDeviceProperties(&prop, 0); + return prop.gcnArchName; +#else + return absl::UnimplementedError("Not a HIP GPU"); +#endif + })); + m.def("get_serialized_metadata", ValueOrThrowWrapper( [](nb::bytes opaque) -> absl::StatusOr { diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index 89d804511313..c4a9af5ffe2e 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -31,11 +31,14 @@ #include "jaxlib/gpu/triton_utils.h" #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" -#include "tsl/platform/env.h" #ifdef JAX_GPU_CUDA #include "xla/stream_executor/cuda/cuda_asm_compiler.h" -#endif +#endif // JAX_GPU_CUDA + +#ifdef JAX_GPU_HIP +#include "tsl/platform/env.h" +#endif // JAX_GPU_HIP #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) @@ -45,7 +48,12 @@ namespace { constexpr float kBenchmarkTimeMillis = 10.; struct gpuModuleDeleter { - void operator()(gpuModule_t module) { gpuModuleUnload(module); } + void operator()(gpuModule_t module) { + absl::Status status = JAX_AS_STATUS(gpuModuleUnload(module)); + if (!status.ok()) { + LOG(WARNING) << "Failed to unload GPU module: " << status; + } + } }; using OwnedGPUmodule = @@ -53,11 +61,11 @@ using OwnedGPUmodule = absl::StatusOr GetStreamDevice(gpuStream_t stream) { gpuDevice_t device; - gpuContext_t context; #ifdef JAX_GPU_HIP int device_id = gpuGetStreamDeviceId(stream); GPU_RETURN_IF_ERROR(gpuDeviceGet(&device, device_id)); #else // JAX_GPU_CUDA + gpuContext_t context; GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context)); GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); 
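  // Pushing the stream's context makes it current for the driver calls that
  // follow; the cleanup just below pops it again when this function returns.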
absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; @@ -137,14 +145,18 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, gpuStream_t stream, void** buffers) { static absl::Mutex mutex; static auto& kernel_calls = - *new absl::flat_hash_map> + *new absl::flat_hash_map>> ABSL_GUARDED_BY(mutex); { // Fast path uses reader lock (as hash map look-up is relatively slow). absl::ReaderMutexLock lock(&mutex); auto it = kernel_calls.find(opaque); - if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) return it->second.get(); + if (ABSL_PREDICT_TRUE(it != kernel_calls.end())) { + JAX_RETURN_IF_ERROR(it->second.status()); + return it->second->get(); + } } if (opaque.empty()) { @@ -152,37 +164,41 @@ absl::StatusOr GetKernelCall(absl::string_view opaque, } absl::MutexLock lock(&mutex); - std::unique_ptr& kernel_call = kernel_calls[opaque]; - // We released the reader lock, so it may have been written by another thread. - if (kernel_call != nullptr) return kernel_call.get(); - // The opaque data is a zlib compressed protobuf. - JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); + auto get_kernel_call = [&]() -> absl::StatusOr> { + // The opaque data is a zlib compressed protobuf. + JAX_ASSIGN_OR_RETURN(std::string serialized, ZlibUncompress(opaque)); - jax_triton::TritonAnyKernelCall proto; - if (!proto.ParseFromString(serialized)) { - return absl::InvalidArgumentError("Failed to parse serialized data."); - } + jax_triton::TritonAnyKernelCall proto; + if (!proto.ParseFromString(serialized)) { + return absl::InvalidArgumentError("Failed to parse serialized data."); + } - if (proto.has_kernel_call()) { - JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, - KernelCall::FromProto(proto.kernel_call())); - kernel_call = std::make_unique(std::move(kernel_call_)); - } else if (proto.has_autotuned_kernel_call()) { - JAX_ASSIGN_OR_RETURN( - AutotunedKernelCall autotuned_call, - AutotunedKernelCall::FromProto(proto.autotuned_kernel_call())); - { + if (proto.has_kernel_call()) { JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, - AutotunedKernelCall::Autotune( - std::move(autotuned_call), stream, buffers)); - kernel_call = std::make_unique(std::move(kernel_call_)); + KernelCall::FromProto(proto.kernel_call())); + return std::make_unique(std::move(kernel_call_)); + } else if (proto.has_autotuned_kernel_call()) { + JAX_ASSIGN_OR_RETURN( + AutotunedKernelCall autotuned_call, + AutotunedKernelCall::FromProto(proto.autotuned_kernel_call())); + { + JAX_ASSIGN_OR_RETURN(KernelCall kernel_call_, + AutotunedKernelCall::Autotune( + std::move(autotuned_call), stream, buffers)); + return std::make_unique(std::move(kernel_call_)); + } + } else { + return absl::InvalidArgumentError("Unknown kernel call type."); } - } else { - return absl::InvalidArgumentError("Unknown kernel call type."); - } + }; - return kernel_call.get(); + // We released the reader lock, so it may have been written by another thread. + // Create a new entry if it already exists or create a new one. 
+ auto it = kernel_calls.emplace(std::string(opaque), get_kernel_call()).first; + + JAX_RETURN_IF_ERROR(it->second.status()); + return it->second->get(); } } // namespace @@ -203,7 +219,12 @@ class ModuleImage { } GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); - absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; + absl::Cleanup ctx_restorer = [] { + absl::Status status = JAX_AS_STATUS(gpuCtxPopCurrent(nullptr)); + if (!status.ok()) { + LOG(WARNING) << "Failed to pop GPU context: " << status; + } + }; gpuModule_t module; GPU_RETURN_IF_ERROR(gpuModuleLoadData(&module, module_image_.data())); diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 96266ca93378..fa247b08b207 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -22,17 +22,20 @@ limitations under the License. #if defined(JAX_GPU_CUDA) -#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuComplex.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cublas_v2.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda_fp8.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cuda_runtime_api.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cufft.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusolverDn.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusolver_common.h" // IWYU pragma: export -#include "third_party/gpus/cuda/include/cusparse.h" // IWYU pragma: export -#include "third_party/gpus/cudnn/cudnn.h" // IWYU pragma: export +// IWYU pragma: begin_exports +#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h" +#include "third_party/gpus/cuda/include/cooperative_groups.h" +#include "third_party/gpus/cuda/include/cuComplex.h" +#include "third_party/gpus/cuda/include/cublas_v2.h" +#include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_fp8.h" +#include "third_party/gpus/cuda/include/cuda_runtime_api.h" +#include "third_party/gpus/cuda/include/cufft.h" +#include "third_party/gpus/cuda/include/cusolverDn.h" +#include "third_party/gpus/cuda/include/cusolver_common.h" +#include "third_party/gpus/cuda/include/cusparse.h" +#include "third_party/gpus/cudnn/cudnn.h" +// IWYU pragma: end_exports #if CUDA_VERSION < 11080 #error "JAX requires CUDA 11.8 or newer." 
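// vendor.h maps a common gpu*/GPU_* vocabulary onto whichever toolkit is
// selected: each alias added below resolves to the CUDA symbol (e.g.
// cublasOperation_t, CUBLAS_OP_N) under JAX_GPU_CUDA and to the ROCm symbol
// (e.g. hipblasOperation_t, HIPBLAS_OP_N) under JAX_GPU_HIP, so the solver
// kernels above compile against either toolkit unchanged.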
@@ -54,6 +57,9 @@ typedef cuDoubleComplex gpublasDoubleComplex; typedef cublasFillMode_t gpusolverFillMode_t; typedef cublasStatus_t gpublasStatus_t; typedef cublasHandle_t gpublasHandle_t; +typedef cublasOperation_t gpublasOperation_t; +typedef cublasFillMode_t gpublasFillMode_t; + typedef CUcontext gpuContext_t; typedef CUstreamCaptureMode gpustreamCaptureMode_t; typedef CUstreamCaptureStatus gpustreamCaptureStatus_t; @@ -72,6 +78,8 @@ typedef cusolverStatus_t gpusolverStatus_t; typedef cusolverEigMode_t gpusolverEigMode_t; typedef syevjInfo gpuSyevjInfo; typedef syevjInfo_t gpuSyevjInfo_t; +typedef gesvdjInfo gpuGesvdjInfo; +typedef gesvdjInfo_t gpuGesvdjInfo_t; typedef cusparseIndexType_t gpusparseIndexType_t; typedef cusparseHandle_t gpusparseHandle_t; typedef cusparseOperation_t gpusparseOperation_t; @@ -98,6 +106,11 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpublasCgetrfBatched cublasCgetrfBatched #define gpublasZgetrfBatched cublasZgetrfBatched +#define gpublasSsyrk cublasSsyrk +#define gpublasDsyrk cublasDsyrk +#define gpublasCsyrk cublasCsyrk +#define gpublasZsyrk cublasZsyrk + #define GPUBLAS_STATUS_SUCCESS CUBLAS_STATUS_SUCCESS #define gpudnnCreate cudnnCreate @@ -109,6 +122,8 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpusolverDnSetStream cusolverDnSetStream #define gpusolverDnCreateSyevjInfo cusolverDnCreateSyevjInfo #define gpusolverDnDestroySyevjInfo cusolverDnDestroySyevjInfo +#define gpusolverDnCreateGesvdjInfo cusolverDnCreateGesvdjInfo +#define gpusolverDnDestroyGesvdjInfo cusolverDnDestroyGesvdjInfo #define gpusolverDnSgeqrf cusolverDnSgeqrf #define gpusolverDnDgeqrf cusolverDnDgeqrf #define gpusolverDnCgeqrf cusolverDnCgeqrf @@ -173,6 +188,22 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; cusolverDnCgesvd_bufferSize(h, m, n, lwork) #define gpusolverDnZgesvd_bufferSize(h, jobu, jobvt, m, n, lwork) \ cusolverDnZgesvd_bufferSize(h, m, n, lwork) +#define gpusolverDnSgesvdj cusolverDnSgesvdj +#define gpusolverDnDgesvdj cusolverDnDgesvdj +#define gpusolverDnCgesvdj cusolverDnCgesvdj +#define gpusolverDnZgesvdj cusolverDnZgesvdj +#define gpusolverDnSgesvdj_bufferSize cusolverDnSgesvdj_bufferSize +#define gpusolverDnDgesvdj_bufferSize cusolverDnDgesvdj_bufferSize +#define gpusolverDnCgesvdj_bufferSize cusolverDnCgesvdj_bufferSize +#define gpusolverDnZgesvdj_bufferSize cusolverDnZgesvdj_bufferSize +#define gpusolverDnSgesvdjBatched cusolverDnSgesvdjBatched +#define gpusolverDnDgesvdjBatched cusolverDnDgesvdjBatched +#define gpusolverDnCgesvdjBatched cusolverDnCgesvdjBatched +#define gpusolverDnZgesvdjBatched cusolverDnZgesvdjBatched +#define gpusolverDnSgesvdjBatched_bufferSize cusolverDnSgesvdjBatched_bufferSize +#define gpusolverDnDgesvdjBatched_bufferSize cusolverDnDgesvdjBatched_bufferSize +#define gpusolverDnCgesvdjBatched_bufferSize cusolverDnCgesvdjBatched_bufferSize +#define gpusolverDnZgesvdjBatched_bufferSize cusolverDnZgesvdjBatched_bufferSize #define gpusolverDnSsytrd_bufferSize cusolverDnSsytrd_bufferSize #define gpusolverDnDsytrd_bufferSize cusolverDnDsytrd_bufferSize #define gpusolverDnChetrd_bufferSize cusolverDnChetrd_bufferSize @@ -185,8 +216,13 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_FILL_MODE_LOWER CUBLAS_FILL_MODE_LOWER #define GPUSOLVER_FILL_MODE_UPPER CUBLAS_FILL_MODE_UPPER #define GPUSOLVER_EIG_MODE_VECTOR CUSOLVER_EIG_MODE_VECTOR +#define GPUSOLVER_EIG_MODE_NOVECTOR CUSOLVER_EIG_MODE_NOVECTOR #define GPUSOLVER_STATUS_SUCCESS CUSOLVER_STATUS_SUCCESS +#define GPUBLAS_OP_N 
CUBLAS_OP_N +#define GPUBLAS_OP_T CUBLAS_OP_T +#define GPUBLAS_OP_C CUBLAS_OP_C + #define gpusparseCooSetStridedBatch cusparseCooSetStridedBatch #define gpusparseCreate cusparseCreate #define gpusparseCreateCoo cusparseCreateCoo @@ -292,6 +328,26 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuStreamWaitEvent cudaStreamWaitEvent #define gpuSuccess cudaSuccess +#define gpuDeviceProp cudaDeviceProp +#define gpuGetDeviceProperties cudaGetDeviceProperties +#define gpuLaunchCooperativeKernel cudaLaunchCooperativeKernel + +#define JAX_GPU_64_BIT 1 + +#define GPU_R_32F CUDA_R_32F +#define GPU_R_64F CUDA_R_64F +#define GPU_C_32F CUDA_C_32F +#define GPU_C_64F CUDA_C_64F + +typedef cudaDataType gpuDataType; +typedef cusolverDnParams gpusolverDnParams; +typedef cusolverDnParams_t gpusolverDnParams_t; +#define gpusolverDnCreateParams cusolverDnCreateParams +#define gpusolverDnDestroyParams cusolverDnDestroyParams + +#define gpusolverDnXgesvd_bufferSize cusolverDnXgesvd_bufferSize +#define gpusolverDnXgesvd cusolverDnXgesvd + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 32; @@ -300,15 +356,19 @@ constexpr uint32_t kNumThreadsPerWarp = 32; #elif defined(JAX_GPU_HIP) +// IWYU pragma: begin_exports +#include "rocm/include/hip/hip_cooperative_groups.h" #include "rocm/include/hip/hip_runtime_api.h" #include "rocm/include/hipblas/hipblas.h" #include "rocm/include/hipsolver/hipsolver.h" #include "rocm/include/hipsparse/hipsparse.h" +// IWYU pragma: end_exports #define JAX_GPU_NAMESPACE hip #define JAX_GPU_PREFIX "hip" #define JAX_GPU_HAVE_SPARSE 1 +#define JAX_GPU_64_BIT 0 #define JAX_GPU_HAVE_FP8 0 typedef hipFloatComplex gpuComplex; @@ -320,6 +380,7 @@ typedef hipsolverHandle_t gpusolverDnHandle_t; typedef hipblasFillMode_t gpublasFillMode_t; typedef hipsolverFillMode_t gpusolverFillMode_t; typedef hipblasHandle_t gpublasHandle_t; +typedef hipblasOperation_t gpublasOperation_t; typedef hipblasStatus_t gpublasStatus_t; typedef hipCtx_t gpuContext_t; typedef hipStreamCaptureMode gpustreamCaptureMode_t; @@ -362,6 +423,11 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpublasCgetrfBatched hipblasCgetrfBatched #define gpublasZgetrfBatched hipblasZgetrfBatched +#define gpublasSsyrk hipblasSsyrk +#define gpublasDsyrk hipblasDsyrk +#define gpublasCsyrk hipblasCsyrk +#define gpublasZsyrk hipblasZsyrk + #define GPUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS #define gpusolverDnCreate hipsolverCreate @@ -444,8 +510,13 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSOLVER_FILL_MODE_LOWER HIPSOLVER_FILL_MODE_LOWER #define GPUSOLVER_FILL_MODE_UPPER HIPSOLVER_FILL_MODE_UPPER #define GPUSOLVER_EIG_MODE_VECTOR HIPSOLVER_EIG_MODE_VECTOR +#define GPUSOLVER_EIG_MODE_NOVECTOR HIPSOLVER_EIG_MODE_NOVECTOR #define GPUSOLVER_STATUS_SUCCESS HIPSOLVER_STATUS_SUCCESS +#define GPUBLAS_OP_N HIPBLAS_OP_N +#define GPUBLAS_OP_T HIPBLAS_OP_T +#define GPUBLAS_OP_C HIPBLAS_OP_C + #define gpusparseCooSetStridedBatch hipsparseCooSetStridedBatch #define gpusparseCreate hipsparseCreate #define gpusparseSetStream hipsparseSetStream @@ -541,6 +612,10 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; HIP_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES #define GPU_EVENT_DEFAULT hipEventDefault +#define gpuDeviceProp hipDeviceProp_t +#define gpuGetDeviceProperties hipGetDeviceProperties +#define gpuLaunchCooperativeKernel hipLaunchCooperativeKernel + namespace jax::JAX_GPU_NAMESPACE { namespace { constexpr uint32_t kNumThreadsPerWarp = 64; diff --git 
a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index af79b3ae756f..9dedc86e4355 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools -from functools import partial import importlib -import numpy as np -import operator - -import jaxlib.mlir.ir as ir - -from .hlo_helpers import custom_call -from .gpu_common_utils import GpuLibNotLinkedError from jaxlib import xla_client @@ -37,7 +28,9 @@ if _cuda_linalg: for _name, _value in _cuda_linalg.registrations().items(): - api_version = 0 if _name == "cu_cholesky_update" else 1 + api_version = (1 + if _name.endswith("lu_pivots_to_permutation") + or _name.endswith("_ffi") else 0) xla_client.register_custom_call_target( _name, _value, platform="CUDA", api_version=api_version ) @@ -54,72 +47,9 @@ if _hip_linalg: for _name, _value in _hip_linalg.registrations().items(): + api_version = (1 + if _name.endswith("lu_pivots_to_permutation") + or _name.endswith("_ffi") else 0) xla_client.register_custom_call_target( - _name, _value, platform="ROCM", api_version=1 + _name, _value, platform="ROCM", api_version=api_version ) - -_prod = lambda xs: functools.reduce(operator.mul, xs, 1) - - -def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_size): - """Kernel for the transformation of pivots to permutations on GPU.""" - typ = ir.RankedTensorType(pivots.type) - dims = typ.shape - i32_type = ir.IntegerType.get_signless(32) - - assert typ.element_type == i32_type, typ - - if not gpu_linalg: - raise GpuLibNotLinkedError() - - pivots_layout = tuple(range(len(dims) - 1, -1, -1)) - permutations_layout = pivots_layout - permutations_dims = list(dims) - permutations_dims[-1] = permutation_size - permutations_type = ir.RankedTensorType.get(permutations_dims, i32_type) - return custom_call( - f"{platform}_lu_pivots_to_permutation", - api_version=4, - result_types=[permutations_type], - operands=[pivots], - backend_config=dict( - permutation_size=ir.IntegerAttr.get(i32_type, permutation_size), - ), - operand_layouts=[pivots_layout], - result_layouts=[permutations_layout], - ).results - -cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu", - _cuda_linalg) -hip_lu_pivots_to_permutation = partial( - _lu_pivots_to_permutation_hlo, "hip", _hip_linalg) - - - -def _cholesky_update_hlo(platform, gpu_linalg, r_matrix, w_vector, dtype): - """Cholesky update.""" - del platform - r_type = ir.RankedTensorType(r_matrix.type) - dims = r_type.shape - assert dims[0] == dims[1] - n = dims[0] - - if not gpu_linalg: - raise GpuLibNotLinkedError() - - np_type = np.dtype(dtype) - opaque = gpu_linalg.build_cholesky_update_descriptor(np_type, n) - - return custom_call( - "cu_cholesky_update", - operands = [r_matrix, w_vector], - result_types=[ - ir.RankedTensorType.get((n, n), r_type.element_type), - ir.RankedTensorType.get((n,), r_type.element_type), - ], - operand_output_aliases={0: 0, 1: 1}, - backend_config=opaque, - ).results[:1] - - -cuda_cholesky_update = partial(_cholesky_update_hlo, "cu", _cuda_linalg) diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 87171fdb4611..ff1e5570bb04 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -43,10 +43,7 @@ if _cublas: for _name, _value in _cublas.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. 
- api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="CUDA", - api_version=api_version) + xla_client.register_custom_call_target(_name, _value, platform="CUDA") for cuda_module_name in [".cuda", "jax_cuda12_plugin"]: try: @@ -78,10 +75,7 @@ if _hipblas: for _name, _value in _hipblas.registrations().items(): - # TODO(danfm): Clean up after all legacy custom calls are ported. - api_version = 1 if _name.endswith("_ffi") else 0 - xla_client.register_custom_call_target(_name, _value, platform="ROCM", - api_version=api_version) + xla_client.register_custom_call_target(_name, _value, platform="ROCM") for rocm_module_name in [".rocm", "jax_rocm60_plugin"]: try: @@ -105,7 +99,8 @@ def _real_type(dtype): return np.finfo(dtype).dtype -def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a): +# TODO(b/357034884): Remove this function after the forward compat window. +def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): """LU decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -115,63 +110,41 @@ def _getrf_hlo(platform, gpu_blas, gpu_solver, ctx, dtype, a): num_bd = len(batch_dims) i32_type = ir.IntegerType.get_signless(32) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - batch = math.prod(batch_dims) - use_batched = batch > 1 and m == n and m // batch <= 128 - - # TODO(b/357034884): Remove after 3 week forward compatibility window. - if ctx.is_forward_compat(): - if not gpu_blas: - raise GpuLibNotLinkedError() - - if use_batched: - lwork, opaque = gpu_blas.build_getrf_batched_descriptor( - np.dtype(dtype), batch, m) - workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) - kernel = f"{platform}blas_getrf_batched" - else: - lwork, opaque = gpu_solver.build_getrf_descriptor( - np.dtype(dtype), batch, m, n) - workspace = ir.RankedTensorType.get([lwork], a_type.element_type) - kernel = f"{platform}solver_getrf" - out = custom_call( - kernel, - result_types=[ - a.type, - ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), - ir.RankedTensorType.get(batch_dims, i32_type), - workspace, - ], - operands=[a], - backend_config=opaque, - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], - ], - operand_output_aliases={0: 0}).results - return out[:3] + if not gpu_blas: + raise GpuLibNotLinkedError() - target = "blas_getrf_batched_ffi" if use_batched else "solver_getrf_ffi" - return custom_call( - f"{platform}{target}", + batch = math.prod(batch_dims) + if batch > 1 and m == n and m // batch <= 128: + lwork, opaque = gpu_blas.build_getrf_batched_descriptor( + np.dtype(dtype), batch, m) + workspace = ir.RankedTensorType.get([lwork], ir.IntegerType.get_signless(8)) + kernel = f"{platform}blas_getrf_batched" + else: + lwork, opaque = gpu_solver.build_getrf_descriptor( + np.dtype(dtype), batch, m, n) + workspace = ir.RankedTensorType.get([lwork], a_type.element_type) + kernel = f"{platform}solver_getrf" + + out = custom_call( + kernel, result_types=[ a.type, ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), ir.RankedTensorType.get(batch_dims, i32_type), + workspace, ], operands=[a], + backend_config=opaque, operand_layouts=[layout], result_layouts=[ layout, tuple(range(num_bd, -1, -1)), tuple(range(num_bd - 1, -1, -1)), + [0], ], - operand_output_aliases={0: 0}, - backend_config={}, - api_version=4).results + operand_output_aliases={0: 0}).results + return out[:3] cuda_getrf = 
partial(_getrf_hlo, "cu", _cublas, _cusolver) diff --git a/jaxlib/gpu_triton.py b/jaxlib/gpu_triton.py index f2d37bfec03d..77f315e5b4b1 100644 --- a/jaxlib/gpu_triton.py +++ b/jaxlib/gpu_triton.py @@ -35,6 +35,7 @@ create_array_parameter = _cuda_triton.create_array_parameter create_scalar_parameter = _cuda_triton.create_scalar_parameter get_compute_capability = _cuda_triton.get_compute_capability + get_arch_details = _cuda_triton.get_arch_details get_custom_call = _cuda_triton.get_custom_call get_serialized_metadata = _cuda_triton.get_serialized_metadata @@ -58,5 +59,6 @@ create_array_parameter = _hip_triton.create_array_parameter create_scalar_parameter = _hip_triton.create_scalar_parameter get_compute_capability = _hip_triton.get_compute_capability + get_arch_details = _hip_triton.get_arch_details get_custom_call = _hip_triton.get_custom_call get_serialized_metadata = _hip_triton.get_serialized_metadata diff --git a/jaxlib/jax.bzl b/jaxlib/jax.bzl index d1decfd3a885..2e37e694b506 100644 --- a/jaxlib/jax.bzl +++ b/jaxlib/jax.bzl @@ -19,6 +19,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", _cuda_library = "cuda_library", load("@local_config_rocm//rocm:build_defs.bzl", _if_rocm_is_configured = "if_rocm_is_configured", _rocm_library = "rocm_library") load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") load("@rules_cc//cc:defs.bzl", _cc_proto_library = "cc_proto_library") +load("@rules_python//python:defs.bzl", "py_test") load("@tsl//tsl/platform:build_config_root.bzl", _tf_cuda_tests_tags = "tf_cuda_tests_tags", _tf_exec_properties = "tf_exec_properties") load("@xla//xla/tsl:tsl.bzl", _if_windows = "if_windows", _pybind_extension = "tsl_pybind_extension_opensource") @@ -222,7 +223,7 @@ def if_building_jaxlib( }) # buildifier: disable=function-docstring -def jax_test( +def jax_multiplatform_test( name, srcs, args = [], @@ -230,15 +231,22 @@ def jax_test( shard_count = None, deps = [], data = [], - disable_backends = None, # buildifier: disable=unused-variable + enable_backends = None, backend_variant_args = {}, # buildifier: disable=unused-variable backend_tags = {}, # buildifier: disable=unused-variable disable_configs = None, # buildifier: disable=unused-variable - enable_configs = None, # buildifier: disable=unused-variable + enable_configs = [], config_tags_overrides = None, # buildifier: disable=unused-variable tags = [], main = None, pjrt_c_api_bypass = False): # buildifier: disable=unused-variable + # enable_configs and disable_configs do not do anything in OSS, only in Google's CI. + # The order in which `enable_backends`, `enable_configs`, and `disable_configs` are applied is + # as follows: + # 1. `enable_backends` is applied first, enabling all test configs for the given backends. + # 2. `disable_configs` is applied second, disabling the named test configs. + # 3. `enable_configs` is applied last, enabling the named test configs. 
+ if main == None: if len(srcs) == 1: main = srcs[0] @@ -255,7 +263,7 @@ def jax_test( "--jax_platform_name=" + backend, ] test_tags = list(tags) + ["jax_test_%s" % backend] + backend_tags.get(backend, []) - if disable_backends and backend in disable_backends: + if enable_backends != None and backend not in enable_backends and not any([config.startswith(backend) for config in enable_configs]): test_tags += ["manual"] if backend == "gpu": test_tags += tf_cuda_tests_tags() @@ -267,10 +275,10 @@ def jax_test( deps = [ "//jax", "//jax:test_util", - ] + deps + if_building_jaxlib(["//jaxlib/cuda:gpu_only_test_deps"]) + select({ - "//jax:enable_build_cuda_plugin_from_source": ["//jax_plugins:gpu_plugin_only_test_deps"], - "//conditions:default": [], - }), + ] + deps + if_building_jaxlib([ + "//jaxlib/cuda:gpu_only_test_deps", + "//jax_plugins:gpu_plugin_only_test_deps", + ]), data = data, shard_count = test_shards, tags = test_tags, @@ -297,3 +305,15 @@ def jax_generate_backend_suites(backends = []): ) jax_test_file_visibility = [] + +def xla_py_proto_library(*args, **kw): # buildifier: disable=unused-variable + pass + +def jax_py_test( + name, + env = {}, + **kwargs): + env = dict(env) + if "PYTHONWARNINGS" not in env: + env["PYTHONWARNINGS"] = "error" + py_test(name = name, env = env, **kwargs) diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 11ba6803d9df..e23cc0075139 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -17,6 +17,7 @@ from collections.abc import Sequence from enum import Enum +from typing import Optional import numpy as np @@ -25,12 +26,13 @@ from jaxlib import xla_client +from .cpu import _lapack +from .cpu._lapack import eig from .hlo_helpers import ( custom_call, hlo_u8, hlo_s32, ensure_hlo_s32, hlo_add, hlo_min, DimensionSize, ShapeTypePair, mk_result_types_and_shapes, ) -from .cpu import _lapack for _name, _value in _lapack.registrations().items(): xla_client.register_custom_call_target( @@ -69,6 +71,23 @@ def _matrix_diagonal_attr(*, unit_diag: bool): return _char_attr("U" if unit_diag else "N") +def _svd_computation_attr( + *, compute_uv: bool, full_matrices: Optional[bool] = True +): + mode = "A" + if full_matrices is None: + full_matrices = True + if not compute_uv: + # We should assert that `full_matrices` is never True here. + # This should never happen because `full_matrices` can only be computed when + # `compute_uv` is True. However, at this point there are too many tests that + # rely on this behavior. 
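+    # (For reference, the mapping this function produces is: "N" when
+    # compute_uv is False, "A" for the full factorization, and "S" for the
+    # thin one.)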
+ mode = "N" + elif not full_matrices: + mode = "S" + return _char_attr(mode) + + LAPACK_DTYPE_PREFIX = { np.float32: "s", np.float64: "d", @@ -77,6 +96,12 @@ def _matrix_diagonal_attr(*, unit_diag: bool): } +def prepare_lapack_call(fn_base, dtype): + """Initializes the LAPACK library and returns the LAPACK target name.""" + _lapack.initialize() + return build_lapack_fn_target(fn_base, dtype) + + def build_lapack_fn_target(fn_base: str, dtype) -> str: """Builds the target name for a LAPACK function custom call.""" try: @@ -139,15 +164,13 @@ def trsm_hlo(dtype, alpha, a, b, # # ?getrf: LU decomposition -def getrf_hlo(ctx, dtype, a: ir.Value, *, - a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() +def getrf_hlo(dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...]): a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 batch_dims_vals = a_shape_vals[:-2] num_bd = len(a_shape_vals) - 2 m, n = a_shape_vals[-2:] - fn_base = build_lapack_fn_target(fn_base="getrf", dtype=dtype) + fn = prepare_lapack_call(fn_base="getrf", dtype=dtype) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) @@ -159,49 +182,31 @@ def getrf_hlo(ctx, dtype, a: ir.Value, *, ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - if ctx.is_forward_compat(): - fn = fn_base - scalar_layout = [] - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - return custom_call( - fn, - result_types=result_types, - operands=[batch_size_val, ensure_hlo_s32(m), ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={3: 0}, - result_shapes=result_shapes, - ).results - else: - fn = fn_base + "_ffi" - return custom_call( - fn, - result_types=result_types, - operands=[a], - operand_layouts=[layout], - result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - ], - operand_output_aliases={0: 0}, - result_shapes=result_shapes, - backend_config={}, - api_version=4, - ).results + return custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, ensure_hlo_s32(m), ensure_hlo_s32(n), a], + operand_layouts=[scalar_layout] * 3 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ], + operand_output_aliases={3: 0}, + result_shapes=result_shapes, + ).results # # ?geqrf: QR decomposition -def geqrf_hlo(dtype, a: ir.Value, *, - a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() + +def geqrf_hlo( + ctx, dtype, a: ir.Value, *, a_shape_vals: tuple[DimensionSize, ...] 
+): a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -210,58 +215,77 @@ def geqrf_hlo(dtype, a: ir.Value, *, batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) + fn_base = prepare_lapack_call(fn_base="geqrf", dtype=dtype) - if dtype == np.float32: - fn = "lapack_sgeqrf" - lwork = _lapack.lapack_sgeqrf_workspace(m, n) - elif dtype == np.float64: - fn = "lapack_dgeqrf" - lwork = _lapack.lapack_dgeqrf_workspace(m, n) - elif dtype == np.complex64: - fn = "lapack_cgeqrf" - lwork = _lapack.lapack_cgeqrf_workspace(m, n) - elif dtype == np.complex128: - fn = "lapack_zgeqrf" - lwork = _lapack.lapack_zgeqrf_workspace(m, n) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) i32_type = ir.IntegerType.get_signless(32) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + if ctx.is_forward_compat(): + fn = fn_base + if dtype == np.float32: + lwork = _lapack.lapack_sgeqrf_workspace(m, n) + elif dtype == np.float64: + lwork = _lapack.lapack_dgeqrf_workspace(m, n) + elif dtype == np.complex64: + lwork = _lapack.lapack_cgeqrf_workspace(m, n) + elif dtype == np.complex128: + lwork = _lapack.lapack_zgeqrf_workspace(m, n) + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + scalar_layout = [] + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + shape_type_pairs: Sequence[ShapeTypePair] = [ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (min(m, n),), a_type.element_type), + (batch_dims_vals, i32_type), + ([lwork], a_type.element_type), + ] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + return custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={4: 0}, + result_shapes=result_shapes, + ).results[:3] + fn = fn_base + "_ffi" shape_type_pairs: Sequence[ShapeTypePair] = [ (a_shape_vals, a_type.element_type), (batch_dims_vals + (min(m, n),), a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], - operand_layouts=[scalar_layout] * 4 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ - layout, - tuple(range(num_bd, -1, -1)), - tuple(range(num_bd - 1, -1, -1)), - [0], + layout, + tuple(range(num_bd, -1, -1)), ], - operand_output_aliases={4: 0}, + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={}, + api_version=4, ).results - return out[:3] # # ?orgqr: product of elementary Householder reflectors: -def orgqr_hlo(dtype, a: ir.Value, tau, *, +def orgqr_hlo(ctx, dtype, a: ir.Value, tau, *, a_shape_vals: tuple[DimensionSize, ...], tau_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() + fn_base = "un" if dtype == np.complex64 or dtype == np.complex128 else "or" + fn_base = prepare_lapack_call(fn_base=fn_base + "gqr", dtype=dtype) a_type = ir.RankedTensorType(a.type) dims = a_type.shape 
dims_vals = a_shape_vals @@ -271,64 +295,83 @@ def orgqr_hlo(dtype, a: ir.Value, tau, *, assert n != ir.ShapedType.get_dynamic_size() batch_dims_vals = dims_vals[:-2] num_bd = len(batch_dims_vals) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - k = tau_shape_vals[-1] assert type(k) is int - - if dtype == np.float32: - fn = "lapack_sorgqr" - lwork = _lapack.lapack_sorgqr_workspace(m, n, k) - elif dtype == np.float64: - fn = "lapack_dorgqr" - lwork = _lapack.lapack_dorgqr_workspace(m, n, k) - elif dtype == np.complex64: - fn = "lapack_cungqr" - lwork = _lapack.lapack_cungqr_workspace(m, n, k) - elif dtype == np.complex128: - fn = "lapack_zungqr" - lwork = _lapack.lapack_zungqr_workspace(m, n, k) - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) i32_type = ir.IntegerType.get_signless(32) + + if ctx.is_forward_compat(): + fn = fn_base + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + + if dtype == np.float32: + lwork = _lapack.lapack_sorgqr_workspace(m, n, k) + elif dtype == np.float64: + lwork = _lapack.lapack_dorgqr_workspace(m, n, k) + elif dtype == np.complex64: + lwork = _lapack.lapack_cungqr_workspace(m, n, k) + elif dtype == np.complex128: + lwork = _lapack.lapack_zungqr_workspace(m, n, k) + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + scalar_layout = [] + shape_type_pairs: Sequence[ShapeTypePair] = [ + (a_shape_vals, a_type.element_type), + (batch_dims_vals, i32_type), + ([lwork], a_type.element_type), + ] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + return custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k), + hlo_s32(lwork), a, tau], + operand_layouts=[scalar_layout] * 5 + [ + layout, + tuple(range(num_bd, -1, -1)), + ], + result_layouts=[ + layout, + tuple(range(num_bd - 1, -1, -1)), + [0], + ], + operand_output_aliases={5: 0}, + result_shapes=result_shapes, + ).results[:2] + fn = fn_base + "_ffi" shape_type_pairs: Sequence[ShapeTypePair] = [ (a_shape_vals, a_type.element_type), - (batch_dims_vals, i32_type), - ([lwork], a_type.element_type), ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[batch_size_val, hlo_s32(m), hlo_s32(n), hlo_s32(k), - hlo_s32(lwork), a, tau], - operand_layouts=[scalar_layout] * 5 + [ + operands=[ + a, tau + ], + operand_layouts=[ layout, tuple(range(num_bd, -1, -1)), ], result_layouts=[ layout, - tuple(range(num_bd - 1, -1, -1)), - [0], ], - operand_output_aliases={5: 0}, + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={}, + api_version=4, ).results - return out[:2] # ?potrf: Cholesky decomposition def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False, a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) - fn_base = build_lapack_fn_target(fn_base="potrf", dtype=dtype) + fn_base = prepare_lapack_call(fn_base="potrf", dtype=dtype) batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) @@ -375,9 +418,8 @@ def potrf_hlo(ctx, dtype, a: ir.Value, *, lower=False, # # ?gesdd: Singular value decomposition -def 
gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, +def gesdd_hlo(ctx, dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, a_shape_vals: tuple[DimensionSize, ...]): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -385,89 +427,127 @@ def gesdd_hlo(dtype, a: ir.Value, *, full_matrices=True, compute_uv=True, assert type(n) is int batch_dims_vals = a_shape_vals[:-2] num_bd = len(batch_dims_vals) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - + fn_base = prepare_lapack_call(fn_base="gesdd", dtype=dtype) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_sgesdd" - singular_vals_type = ir.F32Type.get() - lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0]] - elif dtype == np.float64: - fn = "lapack_dgesdd" - singular_vals_type = ir.F64Type.get() - lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0]] - elif dtype == np.complex64: - fn = "lapack_cgesdd" + + # TODO(b/344892332): Remove the old kernel after the compatibility period. + if ctx.is_forward_compat(): + fn = fn_base + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + if dtype == np.float32: + singular_vals_type = ir.F32Type.get() + lwork = _lapack.sgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0]] + elif dtype == np.float64: + singular_vals_type = ir.F64Type.get() + lwork = _lapack.dgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0]] + elif dtype == np.complex64: + singular_vals_type = ir.F32Type.get() + lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0], [0]] + elif dtype == np.complex128: + singular_vals_type = ir.F64Type.get() + lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices) + workspace = [ + ([_lapack.gesdd_iwork_size(m, n)], i32_type), + ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()), + ([lwork], a_type.element_type), + ] + workspace_layouts = [[0], [0], [0]] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + scalar_layout = [] + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + + shape_type_pairs: Sequence[ShapeTypePair] = [ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (min(m, n),), singular_vals_type), + (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type), + (batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type), + (batch_dims_vals, i32_type), + ] + workspace + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + return custom_call( + fn, + result_types=result_types, + 
operands=[hlo_s32(int(full_matrices)), hlo_s32(int(compute_uv)), batch_size_val, + hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], + operand_layouts=[scalar_layout] * 6 + [layout], + result_layouts=[ + layout, + (num_bd,) + tuple(range(num_bd - 1, -1, -1)), + layout, + layout, + tuple(range(num_bd - 1, -1, -1)), + ] + workspace_layouts, + operand_output_aliases={6: 0}, + result_shapes=result_shapes + ).results[1:5] + fn = fn_base + "_ffi" + mode_attr = _svd_computation_attr( + compute_uv=compute_uv, full_matrices=full_matrices + ) + if dtype == np.float32 or dtype == np.complex64: singular_vals_type = ir.F32Type.get() - lwork = _lapack.cgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F32Type.get()), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0], [0]] - elif dtype == np.complex128: - fn = "lapack_zgesdd" + elif dtype == np.float64 or dtype == np.complex128: singular_vals_type = ir.F64Type.get() - lwork = _lapack.zgesdd_work_size(m, n, compute_uv, full_matrices) - workspace = [ - ([_lapack.gesdd_iwork_size(m, n)], i32_type), - ([_lapack.cgesdd_rwork_size(m, n, int(compute_uv))], ir.F64Type.get()), - ([lwork], a_type.element_type), - ] - workspace_layouts = [[0], [0], [0]] else: raise NotImplementedError(f"Unsupported dtype {dtype}") - scalar_layout = [] layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - + a_elem_type = a_type.element_type shape_type_pairs: Sequence[ShapeTypePair] = [ - (a_shape_vals, a_type.element_type), - (batch_dims_vals + (min(m, n),), singular_vals_type), - (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_type.element_type), - (batch_dims_vals + (n if full_matrices else min(m, n), n), a_type.element_type), - (batch_dims_vals, i32_type), - ] + workspace + (a_shape_vals, a_elem_type), + (batch_dims_vals + (min(m, n),), singular_vals_type), + (batch_dims_vals + (m, m if full_matrices else min(m, n)), a_elem_type), + (batch_dims_vals + (n if full_matrices else min(m, n), n), a_elem_type), + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[hlo_s32(int(full_matrices)), hlo_s32(int(compute_uv)), batch_size_val, - hlo_s32(m), hlo_s32(n), hlo_s32(lwork), a], - operand_layouts=[scalar_layout] * 6 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ layout, (num_bd,) + tuple(range(num_bd - 1, -1, -1)), layout, layout, tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={6: 0}, - result_shapes=result_shapes - ).results - return out[1:5] + ], + operand_output_aliases={0: 0}, + result_shapes=result_shapes, + backend_config={ + "mode": mode_attr, + }, + api_version=4, + ).results[1:] # # syevd: Symmetric eigendecomposition -def syevd_hlo(dtype, a: ir.Value, +def syevd_hlo(ctx, dtype, a: ir.Value, a_shape_vals: tuple[DimensionSize, ...], lower=False): - _lapack.initialize() a_type = ir.RankedTensorType(a.type) assert len(a_shape_vals) >= 2 m, n = a_shape_vals[-2:] @@ -476,76 +556,110 @@ def syevd_hlo(dtype, a: ir.Value, batch_dims_vals = a_shape_vals[:-2] num_bd = len(a_shape_vals) - 2 + mode = _enum_to_char_attr(eig.ComputationMode.kComputeEigenvectors) i32_type = ir.IntegerType.get_signless(32) workspace: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_ssyevd" - eigvals_type = ir.F32Type.get() - workspace = [ 
- ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.float64: - fn = "lapack_dsyevd" - eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.syevd_work_size(n)], a_type.element_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex64: - fn = "lapack_cheevd" + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + # Hermitian is for complex square matrices, symmetric otherwise. + fn_base = "he" if dtype == np.complex64 or dtype == np.complex128 else "sy" + fn_base = prepare_lapack_call(fn_base=fn_base + "evd", dtype=dtype) + if ctx.is_forward_compat(): + fn = fn_base + if dtype == np.float32: + eigvals_type = ir.F32Type.get() + workspace = [ + ([_lapack.syevd_work_size(n)], a_type.element_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.float64: + eigvals_type = ir.F64Type.get() + workspace = [ + ([_lapack.syevd_work_size(n)], a_type.element_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.complex64: + eigvals_type = ir.F32Type.get() + workspace = [ + ([_lapack.heevd_work_size(n)], a_type.element_type), + ([_lapack.heevd_rwork_size(n)], eigvals_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + elif dtype == np.complex128: + eigvals_type = ir.F64Type.get() + workspace = [ + ([_lapack.heevd_work_size(n)], a_type.element_type), + ([_lapack.heevd_rwork_size(n)], eigvals_type), + ([_lapack.syevd_iwork_size(n)], i32_type), + ] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") + + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + + scalar_layout = [] + shape_layout = [0] + workspace_layouts = [shape_layout] * len(workspace) + layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) + + result_types, result_shapes = mk_result_types_and_shapes( + [(a_shape_vals, a_type.element_type), + (batch_dims_vals + (n,), eigvals_type), + (batch_dims_vals, i32_type)] + workspace + ) + + return custom_call( + fn, + result_types=result_types, + operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], + operand_layouts=[scalar_layout] * 3 + [layout], + result_layouts=[ + layout, + tuple(range(num_bd, -1, -1)), + tuple(range(num_bd - 1, -1, -1)), + ] + workspace_layouts, + operand_output_aliases={3: 0}, + result_shapes=result_shapes, + ).results[:3] + fn = fn_base + "_ffi" + if dtype == np.float32 or dtype == np.complex64: eigvals_type = ir.F32Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] - elif dtype == np.complex128: - fn = "lapack_zheevd" + elif dtype == np.float64 or dtype == np.complex128: eigvals_type = ir.F64Type.get() - workspace = [ - ([_lapack.heevd_work_size(n)], a_type.element_type), - ([_lapack.heevd_rwork_size(n)], eigvals_type), - ([_lapack.syevd_iwork_size(n)], i32_type), - ] else: raise NotImplementedError(f"Unsupported dtype {dtype}") - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - - scalar_layout = [] - shape_layout = [0] - workspace_layouts = [shape_layout] * len(workspace) - layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - - result_types, result_shapes = mk_result_types_and_shapes( - [(a_shape_vals, a_type.element_type), - (batch_dims_vals + (n,), eigvals_type), - 
(batch_dims_vals, i32_type)] + workspace - ) + result_types, result_shapes = mk_result_types_and_shapes([ + (a_shape_vals, a_type.element_type), + (batch_dims_vals + (n,), eigvals_type), + (batch_dims_vals, i32_type), + ]) - out = custom_call( + return custom_call( fn, result_types=result_types, - operands=[hlo_s32(1 if lower else 0), batch_size_val, ensure_hlo_s32(n), a], - operand_layouts=[scalar_layout] * 3 + [layout], + operands=[a], + operand_layouts=[layout], result_layouts=[ layout, tuple(range(num_bd, -1, -1)), tuple(range(num_bd - 1, -1, -1)), - ] + workspace_layouts, - operand_output_aliases={3: 0}, + ], + operand_output_aliases={0: 0}, result_shapes=result_shapes, + backend_config={ + "uplo": _matrix_uplo_attr(lower=lower), + "mode": mode, + }, + api_version=4, ).results - return out[:3] # # geev: Nonsymmetric eigendecomposition (eig) -def geev_hlo(dtype, input, *, +def geev_hlo(ctx, dtype, input, *, input_shape_vals: tuple[DimensionSize, ...], # input.shape as ir.Values jobvl=True, jobvr=True): # input_shape_vals are used for when input has dynamic shapes. @@ -558,80 +672,128 @@ def geev_hlo(dtype, input, *, layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1)) - jobvl_c = ord('V' if jobvl else 'N') - jobvr_c = ord('V' if jobvr else 'N') + compute_left = ( + eig.ComputationMode.kComputeEigenvectors + if jobvl + else eig.ComputationMode.kNoEigenvectors + ) + + compute_right = ( + eig.ComputationMode.kComputeEigenvectors + if jobvr + else eig.ComputationMode.kNoEigenvectors + ) + fn_base = build_lapack_fn_target(fn_base="geev", dtype=dtype) i32_type = ir.IntegerType.get_signless(32) f32_type = ir.F32Type.get() f64_type = ir.F64Type.get() c64_type = ir.ComplexType.get(ir.F32Type.get()) c128_type = ir.ComplexType.get(ir.F64Type.get()) + if ctx.is_forward_compat(): + fn = fn_base + workspaces: list[ShapeTypePair] + eigvals: list[ShapeTypePair] + if dtype == np.float32: + real = True + eigvecs_type = c64_type + workspaces = [([n, n], f32_type)] * 3 + workspace_layouts = [[0, 1]] * 3 + eigvals = [(batch_dims_vals + (n,), f32_type)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + elif dtype == np.float64: + real = True + eigvecs_type = c128_type + workspaces = [([n, n], f64_type)] * 3 + workspace_layouts = [[0, 1]] * 3 + eigvals = [(batch_dims_vals + (n,), f64_type)] * 2 + eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 + elif dtype == np.complex64: + real = False + eigvecs_type = c64_type + workspaces = [([n, n], c64_type), ([hlo_add(n, n)], f32_type)] + workspace_layouts = [[0, 1], [0]] + eigvals = [(batch_dims_vals + (n,), c64_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + elif dtype == np.complex128: + real = False + eigvecs_type = c128_type + workspaces = [([n, n], c128_type), ([hlo_add(n, n)], f64_type)] + workspace_layouts = [[0, 1], [0]] + eigvals = [(batch_dims_vals + (n,), c128_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + else: + raise NotImplementedError(f"Unsupported dtype {dtype}") - workspaces: list[ShapeTypePair] - eigvals: list[ShapeTypePair] - if dtype == np.float32: - fn = "lapack_sgeev" - real = True - eigvecs_type = c64_type - workspaces = [([n, n], f32_type)] * 3 - workspace_layouts = [[0, 1]] * 3 - eigvals = [(batch_dims_vals + (n,), f32_type)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - elif dtype == np.float64: - fn = "lapack_dgeev" - real = True - eigvecs_type = c128_type - workspaces = [([n, n], f64_type)] * 3 - workspace_layouts = [[0, 1]] * 3 - eigvals = [(batch_dims_vals + (n,), 
f64_type)] * 2 - eigvals_layouts = [tuple(range(num_bd, -1, -1))] * 2 - elif dtype == np.complex64: - fn = "lapack_cgeev" - real = False - eigvecs_type = c64_type - workspaces = [([n, n], c64_type), ([hlo_add(n, n)], f32_type)] - workspace_layouts = [[0, 1], [0]] - eigvals = [(batch_dims_vals + (n,), c64_type)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] - elif dtype == np.complex128: - fn = "lapack_zgeev" - real = False - eigvecs_type = c128_type - workspaces = [([n, n], c128_type), ([hlo_add(n, n)], f64_type)] - workspace_layouts = [[0, 1], [0]] - eigvals = [(batch_dims_vals + (n,), c128_type)] - eigvals_layouts = [tuple(range(num_bd, -1, -1))] - else: - raise NotImplementedError(f"Unsupported dtype {dtype}") - - scalar_layout = [] - info_layout = tuple(range(num_bd - 1, -1, -1)) + scalar_layout = [] + info_layout = tuple(range(num_bd - 1, -1, -1)) - batch_size_val = hlo_s32(1) - for b_v in batch_dims_vals: - batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) + batch_size_val = hlo_s32(1) + for b_v in batch_dims_vals: + batch_size_val = hlo.multiply(batch_size_val, ensure_hlo_s32(b_v)) - shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ + shape_type_pairs: Sequence[ShapeTypePair] = workspaces + eigvals + [ + (input_shape_vals, eigvecs_type), + (input_shape_vals, eigvecs_type), + (batch_dims_vals, i32_type)] + result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) + out = custom_call( + fn, + result_types=result_types, + operands=[batch_size_val, ensure_hlo_s32(n), + hlo_u8(compute_left.value), + hlo_u8(compute_right.value), + input], + operand_layouts=[scalar_layout] * 4 + [layout], + result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + + [info_layout]), + result_shapes=result_shapes, + ).results + if real: + return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) + else: + return out[2:6] + fn = fn_base + "_ffi" + real = dtype == np.float32 or dtype == np.float64 + eigvecs_type = ( + c64_type if dtype == np.float32 or dtype == np.complex64 else c128_type + ) + input_type = ir.RankedTensorType(input.type) + eigvals = [(batch_dims_vals + (n,), input_type.element_type)] + eigvals_layouts = [tuple(range(num_bd, -1, -1))] + if real: + eigvals = eigvals * 2 + eigvals_layouts = eigvals_layouts * 2 + info_layout = tuple(range(num_bd - 1, -1, -1)) + shape_type_pairs: Sequence[ShapeTypePair] = [ + *eigvals, (input_shape_vals, eigvecs_type), (input_shape_vals, eigvecs_type), - (batch_dims_vals, i32_type)] + (batch_dims_vals, i32_type), + ] result_types, result_shapes = mk_result_types_and_shapes(shape_type_pairs) out = custom_call( fn, result_types=result_types, - operands=[batch_size_val, ensure_hlo_s32(n), - hlo_u8(jobvl_c), - hlo_u8(jobvr_c), - input], - operand_layouts=[scalar_layout] * 4 + [layout], - result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + - [info_layout]), + operands=[input], + operand_layouts=[layout], + result_layouts=( + *eigvals_layouts, + layout, + layout, + info_layout, + ), result_shapes=result_shapes, + backend_config={ + "compute_left": _enum_to_char_attr(compute_left), + "compute_right": _enum_to_char_attr(compute_right), + }, + api_version=4, ).results if real: - return (hlo.complex(out[3], out[4]), out[5], out[6], out[7]) + return (hlo.complex(out[0], out[1]), out[2], out[3], out[4]) else: - return out[2:6] + return out[:4] # # gees : Schur factorization diff --git a/jaxlib/mlir/_mlir_libs/BUILD.bazel b/jaxlib/mlir/_mlir_libs/BUILD.bazel index 
06fa9e760a70..db02eb8bbff1 100644 --- a/jaxlib/mlir/_mlir_libs/BUILD.bazel +++ b/jaxlib/mlir/_mlir_libs/BUILD.bazel @@ -310,7 +310,7 @@ py_extension( py_extension( name = "_chlo", srcs = [ - "@stablehlo//:stablehlo/integrations/python/ChloModule.cpp", + "@stablehlo//:chlo_py_api_files", ], copts = COPTS, linkopts = LINKOPTS, @@ -327,23 +327,18 @@ py_extension( py_extension( name = "_stablehlo", srcs = [ - "@stablehlo//:stablehlo/integrations/python/PortableApi.cpp", - "@stablehlo//:stablehlo/integrations/python/PortableApi.h", - "@stablehlo//:stablehlo/integrations/python/StablehloModule.cpp", + "@stablehlo//:stablehlo_py_api_files", ], copts = COPTS, linkopts = LINKOPTS, deps = [ ":jaxlib_mlir_capi_shared_library", + "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIRHeaders", - "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeaders", "@local_config_python//:headers", "@pybind11", - "@stablehlo//:reference_api", "@stablehlo//:stablehlo_capi_headers", - "@stablehlo//:stablehlo_portable_api", - "@stablehlo//:stablehlo_serialization", ], ) diff --git a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc index e1958c211b33..2e10062945b5 100644 --- a/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc +++ b/jaxlib/mlir/_mlir_libs/register_jax_dialects.cc @@ -1,5 +1,6 @@ // Registers MLIR dialects used by JAX. // This module is called by mlir/__init__.py during initialization. +#include #include "mlir-c/Dialect/Arith.h" #include "mlir-c/Dialect/Func.h" @@ -14,11 +15,13 @@ #include "mlir-c/Transforms.h" #include "mlir/Bindings/Python/PybindAdaptors.h" +namespace py = pybind11; + #define REGISTER_DIALECT(name) \ MlirDialectHandle name##_dialect = mlirGetDialectHandle__##name##__(); \ mlirDialectHandleInsertDialect(name##_dialect, registry) -PYBIND11_MODULE(register_jax_dialects, m) { +PYBIND11_MODULE(register_jax_dialects, m, py::mod_gil_not_used()) { m.doc() = "Registers upstream MLIR dialects used by JAX."; m.def("register_dialects", [](MlirDialectRegistry registry) { diff --git a/jaxlib/mlir/_mlir_libs/tpu_ext.cc b/jaxlib/mlir/_mlir_libs/tpu_ext.cc index 1d024a8b77a4..a50aef1ca6d4 100644 --- a/jaxlib/mlir/_mlir_libs/tpu_ext.cc +++ b/jaxlib/mlir/_mlir_libs/tpu_ext.cc @@ -314,7 +314,7 @@ MlirContext getDefaultContext() { } // namespace -PYBIND11_MODULE(_tpu_ext, m) { +PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) { mlirRegisterTPUPasses(); // Register all passes on load. 
py::class_(m, "ApplyVectorLayoutCtx", @@ -374,14 +374,21 @@ PYBIND11_MODULE(_tpu_ext, m) { .def(py::init([](int bitwidth, py::tuple offsets, py::tuple tiling, MlirTpuImplicitDim implicit_dim) { if (offsets.size() != 2) { - throw py::value_error("offsets should be of length 2"); + throw py::value_error("Offsets should be of length 2"); } - return mlirTpuVectorLayoutCreate( + if (tiling.size() != 2) { + throw py::value_error("Tiling should be of length 2"); + } + MlirTpuVectorLayout layout = mlirTpuVectorLayoutCreate( bitwidth, {offsetFromPyOffset(offsets[0]), offsetFromPyOffset(offsets[1])}, {tiling[0].cast(), tiling[1].cast()}, implicit_dim); + if (!mlirTpuVectorLayoutIsValid(layout, TARGET_SHAPE)) { + throw py::value_error("Layout not valid for target shape"); + } + return layout; }), py::arg("bitwidth"), py::arg("offsets"), py::arg("tiling"), py::arg("implicit_dim")) diff --git a/jaxlib/mlir/_mlir_libs/triton_ext.cc b/jaxlib/mlir/_mlir_libs/triton_ext.cc index 4b900b7c1cbf..e02e4f3d86e4 100644 --- a/jaxlib/mlir/_mlir_libs/triton_ext.cc +++ b/jaxlib/mlir/_mlir_libs/triton_ext.cc @@ -22,7 +22,7 @@ limitations under the License. namespace py = pybind11; -PYBIND11_MODULE(_triton_ext, m) { +PYBIND11_MODULE(_triton_ext, m, py::mod_gil_not_used()) { // // Dialects. // diff --git a/jaxlib/mosaic/BUILD b/jaxlib/mosaic/BUILD index 10acec815475..5452520204b8 100644 --- a/jaxlib/mosaic/BUILD +++ b/jaxlib/mosaic/BUILD @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("@rules_python//python:defs.bzl", "py_library") licenses(["notice"]) package( default_applicable_licenses = [], default_visibility = [ - "//:__subpackages__", + "//jax:mosaic_users", ], ) @@ -54,6 +54,14 @@ cc_library( # compatible with libtpu deps = [ ":tpu_inc_gen", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ControlFlowDialect", @@ -71,18 +79,10 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:VectorDialect", "@llvm-project//mlir:VectorTransforms", + "@tsl//tsl/platform:statusor", "@xla//xla:array", "@xla//xla:shape_util", "@xla//xla:util", - "@tsl//tsl/platform:statusor", - "@com_google_absl//absl/algorithm:container", - "@com_google_absl//absl/container:flat_hash_set", - "@com_google_absl//absl/hash", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/strings:string_view", - "@com_google_absl//absl/types:span", ], ) @@ -192,14 +192,14 @@ cc_library( deps = [ ":tpu_dialect", ":tpu_inc_gen", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@xla//xla:array", - "@com_google_absl//absl/log", - "@com_google_absl//absl/log:check", ], ) diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc index ef7d3fecfb22..3cc9b36972d6 100644 
--- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc @@ -312,6 +312,11 @@ void mlirTpuVectorLayoutPrint( unwrap(layout)->print(stream); } +bool mlirTpuVectorLayoutIsValid(MlirTpuVectorLayout layout, + MlirTpuI64TargetTuple target_shape) { + return unwrap(layout)->isValid(unwrap(target_shape)); +} + void mlirTpuVregDataBoundsDestroy(MlirTpuVregDataBounds data_bounds) { delete unwrap(data_bounds); } diff --git a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h index 42c974b3a961..5b2a7009e9e6 100644 --- a/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h @@ -191,6 +191,9 @@ MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutEquivalentTo( MLIR_CAPI_EXPORTED void mlirTpuVectorLayoutPrint( MlirTpuVectorLayout layout, MlirStringCallback callback, void* user_data); +MLIR_CAPI_EXPORTED bool mlirTpuVectorLayoutIsValid( + MlirTpuVectorLayout layout, MlirTpuI64TargetTuple target_shape); + MLIR_CAPI_EXPORTED void mlirTpuVregDataBoundsDestroy( MlirTpuVregDataBounds data_bounds); diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index 04709690e7d7..ffcc8d52cd05 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -233,6 +233,37 @@ def TPU_StridedStoreOp : TPU_Op<"strided_store"> { let hasVerifier = 1; } +def TPU_ShuffledLoadOp : TPU_Op<"shuffled_load"> { + let arguments = (ins + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs TPU_Vreg:$result); + let assemblyFormat = [{ + $base `[` $indices `]` attr-dict `:` type($base) `,` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + +def TPU_ShuffledStoreOp : TPU_Op<"shuffled_store"> { + let arguments = (ins + TPU_Vreg:$valueToStore, + AnyMemRef:$base, + Variadic:$indices, + DenseBoolArrayAttr:$sublane_mask, + DenseI32ArrayAttr:$sublane_offsets + ); + let results = (outs); + let assemblyFormat = [{ + $base `[` $indices `]` `,` $valueToStore attr-dict `:` type($base) `,` type($valueToStore) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + // TODO(jevinjiang): deprecate to use dynamic_rotate. 
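The shuffled_load and shuffled_store ops defined just above carry a per-sublane mask plus per-sublane source offsets; the verifiers added further down in tpu_ops.cc require both attribute lengths to match the vreg's leading (sublane) dimension, and the canonicalizers rewrite the ops into plain loads and stores when the offsets are the identity permutation. A small Python restatement of that canonicalization condition, for reference:

def is_identity_shuffle(sublane_offsets):
    # Offsets (0, 1, 2, ...) request no sublane reordering, so the shuffled op
    # can be replaced with an ordinary load or store.
    return all(offset == i for i, offset in enumerate(sublane_offsets))

assert is_identity_shuffle([0, 1, 2, 3])
assert not is_identity_shuffle([2, 3, 0, 1])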
def TPU_RotateOp : TPU_Op<"rotate", [Pure, SameOperandsAndResultType]> { let arguments = (ins @@ -495,6 +526,16 @@ def TPU_MemRefReshapeOp : TPU_Op<"memref_reshape", [Pure]> { let hasCanonicalizeMethod = 1; } +def TPU_MemRefBitcastOp : TPU_Op<"memref_bitcast", [Pure]> { + let arguments = (ins AnyMemRef:$input); + let results = (outs AnyMemRef:$result); + let assemblyFormat = [{ + $input attr-dict `:` type($input) `->` type($result) + }]; + let hasVerifier = 1; + let hasCanonicalizeMethod = 1; +} + def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> { let arguments = (ins AnyMemRef:$input); let results = (outs AnyMemRef:$result); @@ -749,6 +790,7 @@ def ApplyVectorLayoutPass : Pass<"tpu-apply-vector-layout", "::mlir::func::FuncO Option<"mxu_contracting_size", "mxu-contracting-size", "int", /*default=*/"128", "">, Option<"mxu_noncontracting_size", "mxu-noncontracting-size", "int", /*default=*/"128", "">, Option<"max_sublanes_in_scratch", "max-sublanes-in-scratch", "int", /*default=*/"0", "">, + Option<"vmem_banks", "vmem-banks", "int", /*default=*/"-1", "">, ]; } diff --git a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h index 510bd384d656..00bd15b57153 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_dialect.h +++ b/jaxlib/mosaic/dialect/tpu/tpu_dialect.h @@ -62,6 +62,7 @@ struct ApplyVectorLayoutContext { // mxu_shape = {contracting_size, non_contracting_size} std::array mxu_shape = {128, 128}; int64_t max_sublanes_in_scratch = 0; + int64_t vmem_banks = -1; // -1 means "unspecified". }; std::pair mightCommunicateBetweenChips(Operation* op); diff --git a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc index 0202fbb3b7f7..d80db4e1394e 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu_ops.cc +++ b/jaxlib/mosaic/dialect/tpu/tpu_ops.cc @@ -303,6 +303,101 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op, return success(); } +LogicalResult MemRefBitcastOp::verify() { + auto src_ty = getMemRefType(getInput()); + auto tgt_ty = getType(); + if (tgt_ty.getMemorySpace() != nullptr && + tgt_ty.getMemorySpace() != src_ty.getMemorySpace()) { + return emitOpError("Memory spaces do not match."); + } + if (src_ty.getRank() != tgt_ty.getRank()) { + return emitOpError("Ranks do not match."); + } + if (src_ty.getRank() <= 1) { + return emitOpError("Not implemented: 1d memref bitcast."); + } + auto src_bitwidth = src_ty.getElementTypeBitWidth(); + auto tgt_bitwidth = tgt_ty.getElementTypeBitWidth(); + for (int i = 0; i < src_ty.getRank(); ++i) { + auto src_dim_size = src_ty.getDimSize(i); + auto tgt_dim_size = tgt_ty.getDimSize(i); + if (i == src_ty.getRank() - 2) { + auto src_bits = src_dim_size * src_bitwidth; + auto tgt_bits = tgt_dim_size * tgt_bitwidth; + if (src_bits != tgt_bits) { + return emitOpError( + "Expected the same number of bits on the 2nd minormost " + "dim: (") + << src_dim_size << " * " << src_bitwidth << ") vs (" + << tgt_dim_size << " * " << tgt_bitwidth << ")"; + ; + } + } else { + if (src_dim_size != tgt_dim_size) { + return emitOpError("Expected the same dim size on dim ") + << i << ": " << src_dim_size << " vs " << tgt_dim_size; + } + } + } + // Source and target attributes may be different before propagation is done by + // the canonicalizer, so we allow this when attributes are "unset" in the + // target type. 
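The MemRefBitcastOp verifier above only lets a bitcast reinterpret the element type: memory spaces and ranks must match, the rank must be at least 2, and every dimension must be unchanged except the second-minormost one, which has to cover the same number of bits before and after. A rough Python equivalent of the dimension check (the tiled-layout check that continues below is not modeled here):

def bitcast_dims_ok(src_shape, src_bits, tgt_shape, tgt_bits):
    if len(src_shape) != len(tgt_shape) or len(src_shape) <= 1:
        return False
    for i, (s, t) in enumerate(zip(src_shape, tgt_shape)):
        if i == len(src_shape) - 2:
            # Second-minormost dim: the total bit count must be preserved.
            if s * src_bits != t * tgt_bits:
                return False
        elif s != t:
            return False
    return True

assert bitcast_dims_ok((8, 128), 32, (16, 128), 16)      # i32 -> i16 bitcast
assert not bitcast_dims_ok((8, 128), 32, (8, 128), 16)   # would drop half the bits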
+ auto tgt_layout = dyn_cast(tgt_ty.getLayout()); + if (!tgt_layout) { + return success(); + } + auto src_layout = dyn_cast(src_ty.getLayout()); + if (!src_layout) { + return emitOpError("Expected a tiled layout for the input memref."); + } + // TODO(jevinjiang): verify memref tiling is valid. Here we just assume the + // source and target tilings are valid. + auto src_tile = src_layout.getTiles().front().dimensions(); + auto tgt_tile = tgt_layout.getTiles().front().dimensions(); + if (src_tile[0] * src_bitwidth != tgt_tile[0] * tgt_bitwidth) { + return emitOpError("Invalid memref bitcast."); + } + return success(); +} + +LogicalResult MemRefBitcastOp::canonicalize(MemRefBitcastOp op, + PatternRewriter &rewriter) { + auto src_ty = op.getInput().getType(); + auto dst_ty = op.getType(); + if (src_ty == dst_ty) { + rewriter.replaceOp(op, op.getInput()); + return success(); + } + auto erase_layout_op = op.getInput().getDefiningOp(); + if (!erase_layout_op) { + return failure(); + } + auto src_bitwidth = src_ty.getElementTypeBitWidth(); + auto tgt_bitwidth = dst_ty.getElementTypeBitWidth(); + auto layout_ref = erase_layout_op.getOperand(); + auto layout_ty = layout_ref.getType(); + auto layout = cast(layout_ty.getLayout()); + CHECK(!layout.getTiles().empty()); + auto tile = layout.getTiles().front().dimensions(); + if (tile[0] * src_bitwidth % tgt_bitwidth != 0) { + return failure(); + } + SmallVector new_tiles = + {xla::Tile({tile[0] * src_bitwidth / tgt_bitwidth, 128})}; + if (tgt_bitwidth < 32) { + new_tiles.push_back(xla::Tile({32 / tgt_bitwidth, 1})); + } + auto new_layout = tpu::TiledLayoutAttr::get(src_ty.getContext(), new_tiles, + layout.getTileStrides()); + auto new_result_ty = + MemRefType::get(dst_ty.getShape(), dst_ty.getElementType(), new_layout, + layout_ty.getMemorySpace()); + auto bitcast = + rewriter.create(op.getLoc(), new_result_ty, layout_ref); + rewriter.replaceOpWithNewOp(op, op.getType(), bitcast); + return success(); +} + template LogicalResult verifyStridedOp(Op op, MemRefType memref_ty, VectorType vector_ty) { @@ -406,8 +501,9 @@ class CanonicalizeAddOfMatmul : public OpRewritePattern { } return failure(); }; - return success(succeeded(try_canonicalize(op.getLhs(), op.getRhs())) || - succeeded(try_canonicalize(op.getLhs(), op.getRhs()))); + // We tried try_canonicalize(op.getRhs(), op.getLhs()) and it caused + // worrying numerical differences in some of the kernels.
+ return try_canonicalize(op.getLhs(), op.getRhs()); } }; @@ -488,6 +584,81 @@ LogicalResult RegionOp::verify() { return success(); } +LogicalResult ShuffledLoadOp::verify() { + if (getBase().getType().getRank() != getIndices().size()) { + return emitOpError("Base memref's rank and indices size do not match: ") + << getBase().getType().getRank() << " vs " << getIndices().size(); + } + if (getSublaneMask().size() != getType().getShape()[0]) { + return emitOpError("Expected sublane mask size equals to ") + << getType().getShape()[0] << " but got " << getSublaneMask().size(); + } + if (getSublaneOffsets().size() != getType().getShape()[0]) { + return emitOpError("Expected sublane offsets size equals to ") + << getType().getShape()[0] << " but got " + << getSublaneOffsets().size(); + } + return success(); +} + +LogicalResult ShuffledLoadOp::canonicalize(ShuffledLoadOp op, + PatternRewriter &rewriter) { + bool can_convert_to_simple_load = true; + for (int i = 0; i < op.getSublaneOffsets().size(); ++i) { + if (op.getSublaneOffsets()[i] != i) { + can_convert_to_simple_load = false; + break; + }; + } + if (can_convert_to_simple_load) { + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getBase(), op.getIndices(), op.getSublaneMask(), + /*sublane_stride=*/nullptr); + } + return success(); +} + +LogicalResult ShuffledStoreOp::verify() { + if (getBase().getType().getRank() != getIndices().size()) { + return emitOpError("Base memref's rank and indices size do not match: ") + << getBase().getType().getRank() << " vs " << getIndices().size(); + } + if (getValueToStore().getType().getRank() != getIndices().size()) { + return emitOpError( + "The rank of value to store and indices size do not match: ") + << getBase().getType().getRank() << " vs " << getIndices().size(); + } + if (getSublaneMask().size() != getValueToStore().getType().getShape()[0]) { + return emitOpError("Expected sublane mask size equals to ") + << getValueToStore().getType().getShape()[0] << " but got " + << getSublaneMask().size(); + } + if (getSublaneOffsets().size() != getValueToStore().getType().getShape()[0]) { + return emitOpError("Expected sublane offsets size equals to ") + << getValueToStore().getType().getShape()[0] << " but got " + << getSublaneOffsets().size(); + } + return success(); +} + +LogicalResult ShuffledStoreOp::canonicalize(ShuffledStoreOp op, + PatternRewriter &rewriter) { + bool can_convert_to_simple_store = true; + for (int i = 0; i < op.getSublaneOffsets().size(); ++i) { + if (op.getSublaneOffsets()[i] != i) { + can_convert_to_simple_store = false; + break; + }; + } + if (can_convert_to_simple_store) { + rewriter.replaceOpWithNewOp(op, op.getValueToStore(), + op.getBase(), op.getIndices(), + op.getSublaneMask(), + /*mask=*/nullptr, + /*sublane_stride=*/nullptr); + } + return success(); +} } // namespace tpu } // namespace mlir diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index 0e5dfa7b51f4..f6e1c7918646 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -11,6 +11,7 @@ #include #include #include +#include #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" @@ -46,6 +47,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/LogicalResult.h" +#include "absl/algorithm/container.h" #include "absl/log/check.h" #include "absl/log/log.h" #include "absl/status/status.h" @@ -139,18 +141,21 
@@ void moveAllRegions(Operation &src, Operation &dst) { // // Returns: // A memref of the requested shape and type. -FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, - Location loc, ArrayRef shape, - Type elem_ty) { +FailureOr> getInternalScratch( + RewriteContext &ctx, OpBuilder &builder, Location loc, + ArrayRef shape, Type elem_ty, int64_t sublane_tiling = 0) { if (shape.empty()) { return failure(); } if (shape.back() % ctx.target_shape[1] != 0) { return failure(); } - int sublane_count = + int packing = 32 / elem_ty.getIntOrFloatBitWidth(); + int sublane_count = llvm::divideCeil( std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()) / - ctx.target_shape[1]; + ctx.target_shape[1], + packing); + if (sublane_count > ctx.max_sublanes_in_scratch) { return failure(); } @@ -159,7 +164,7 @@ FailureOr getInternalScratch(RewriteContext &ctx, OpBuilder &builder, FAILUREOR_ASSIGN_OR_RETURN( MemRefType scratch_ref_ty, inferMemref(MemRefType::get(shape, elem_ty), ctx.hardware_generation, - /*tpu_tiling_flags=*/{})); + /*tpu_tiling_flags=*/{}, sublane_tiling)); return builder.create(loc, scratch_ref_ty) .getResult(); } @@ -526,8 +531,7 @@ FailureOr appendConstant(RewriteContext &ctx, func::FuncOp func, return argument; } -// TODO(tlongeri): This function and others below never fail, remove FailureOr -FailureOr getNativeVregOrVmaskTypeImpl( +VectorType getNativeVregOrVmaskTypeImpl( Type elem_ty, const int8_t bitwidth, const std::array target_shape) { if (bitwidth == 32) { @@ -537,9 +541,8 @@ FailureOr getNativeVregOrVmaskTypeImpl( elem_ty); } -FailureOr getNativeVregOrVmaskType( - Type elem_ty, const int8_t layout_bitwidth, - const std::array target_shape) { +VectorType getNativeVregOrVmaskType(Type elem_ty, const int8_t layout_bitwidth, + const std::array target_shape) { int8_t bitwidth = elem_ty.getIntOrFloatBitWidth(); if (bitwidth == 1) { bitwidth = layout_bitwidth; @@ -549,8 +552,8 @@ FailureOr getNativeVregOrVmaskType( return getNativeVregOrVmaskTypeImpl(elem_ty, bitwidth, target_shape); } -FailureOr getNativeVregType( - Type elem_ty, const std::array target_shape) { +VectorType getNativeVregType(Type elem_ty, + const std::array target_shape) { return getNativeVregOrVmaskTypeImpl(elem_ty, elem_ty.getIntOrFloatBitWidth(), target_shape); } @@ -572,7 +575,7 @@ FailureOr maskOOB(RewriteContext &ctx, OpBuilder &builder, const VRegDataBounds &bounds, const TypedAttr neutral) { auto native_vreg_ty = - *getNativeVregType(value.getType().getElementType(), ctx.target_shape); + getNativeVregType(value.getType().getElementType(), ctx.target_shape); TPU_ASSERT_LOC(value.getLoc(), llvm::equal(value.getType().getShape(), native_vreg_ty.getShape())); if (bounds.isComplete(ctx.target_shape)) { @@ -709,10 +712,8 @@ LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op, in_vreg_arrays.emplace_back(std::move(tile_array)); } - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType out_vreg_ty, - getNativeVregOrVmaskType(out_ty.getElementType(), layout_out.bitwidth(), - ctx.target_shape)); + const VectorType out_vreg_ty = getNativeVregOrVmaskType( + out_ty.getElementType(), layout_out.bitwidth(), ctx.target_shape); NamedAttrList attributes(op.getAttrDictionary()); attributes.erase("in_layout"); @@ -763,10 +764,10 @@ using rule_type = std::function, ArrayRef)>; template -LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, - const VectorLayout &layout_in, - const VectorLayout &layout_out) { - ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); +FailureOr> 
ext_op_rule_impl(RewriteContext &ctx, + OpBuilder &builder, OpTy op, + const VectorLayout &layout_in, + const VectorLayout &layout_out) { const auto result_ty = cast(op.getResult().getType()); auto source = cast>(op.getIn()); const auto source_ty = source.getType(); @@ -783,9 +784,8 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape)); } - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType res_vreg_ty, - getNativeVregType(result_ty.getElementType(), ctx.target_shape)); + const VectorType res_vreg_ty = + getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_in.implicit_dim() != layout_out.implicit_dim()) { return op.emitOpError( "Not implemented: Change of implicit dim during the cast"); @@ -801,7 +801,7 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing; *(input_vreg_idxs.end() - 2) /= packing; *v = builder.create( - res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); }); } else { if (layout_in.tiling() != layout_out.tiling()) { @@ -817,17 +817,13 @@ LogicalResult ext_op_rule_impl(RewriteContext &ctx, OpTy op, input_vreg_idxs.back() /= packing; const int64_t vreg_part = idxs.back() % packing; *v = builder.create( - res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); + op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part); }); } if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { output_vregs.Reshape(output_vregs_shape); } - op.replaceAllUsesWith(assemble(builder, result_ty, layout_out, - std::move(output_vregs), ctx.target_shape) - .getResult()); - op.erase(); - return success(); + return output_vregs; } LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, @@ -842,8 +838,17 @@ LogicalResult arith_extf_rule(RewriteContext &ctx, Operation &op, return op.emitOpError( "Not implemented: Only 16-bit to 32-bit conversion supported"); } - return ext_op_rule_impl(ctx, extf_op, *layouts_in.front(), - *layouts_out.front()); + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array output_vregs, + ext_op_rule_impl(ctx, builder, extf_op, *layouts_in.front(), + *layouts_out.front())); + const auto result_ty = cast(extf_op.getResult().getType()); + extf_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), + std::move(output_vregs), ctx.target_shape) + .getResult()); + extf_op.erase(); + return success(); } LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, @@ -854,8 +859,69 @@ LogicalResult arith_extsi_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(layouts_out.size(), 1); TPU_ASSERT_OP(layouts_out.front().has_value()); auto extsi_op = cast(op); - return ext_op_rule_impl(ctx, extsi_op, *layouts_in.front(), - *layouts_out.front()); + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array output_vregs, + ext_op_rule_impl(ctx, builder, extsi_op, *layouts_in.front(), + *layouts_out.front())); + const auto result_ty = cast(extsi_op.getResult().getType()); + extsi_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), + std::move(output_vregs), ctx.target_shape) + .getResult()); + extsi_op.erase(); + return success(); +} + +LogicalResult arith_extui_rule(RewriteContext &ctx, Operation &op, + const ArrayRef layouts_in, + const ArrayRef layouts_out) { + 
TPU_ASSERT_EQ_OP(layouts_in.size(), 1); + TPU_ASSERT_OP(layouts_in.front().has_value()); + TPU_ASSERT_EQ_OP(layouts_out.size(), 1); + TPU_ASSERT_OP(layouts_out.front().has_value()); + auto extui_op = cast(op); + auto in_ty = dyn_cast(extui_op.getIn().getType()); + auto out_ty = dyn_cast(extui_op.getType()); + CHECK(in_ty && out_ty); + auto in_bitwidth = in_ty ? in_ty.getElementTypeBitWidth() + : extui_op.getIn().getType().getIntOrFloatBitWidth(); + if (in_bitwidth == 1) { + return elementwise_op_rule(ctx, op, layouts_in, layouts_out); + } + ImplicitLocOpBuilder builder(op.getLoc(), &op); + FAILUREOR_ASSIGN_OR_RETURN( + xla::Array output_vregs, + ext_op_rule_impl(ctx, builder, extui_op, *layouts_in.front(), + *layouts_out.front())); + const auto source_ty = cast(extui_op.getIn().getType()); + const auto result_ty = cast(extui_op.getResult().getType()); + auto src_bitwidth = source_ty.getElementTypeBitWidth(); + auto dst_bitwidth = result_ty.getElementTypeBitWidth(); + // Generate a mask to mask out the sign extension. e.g., for u8 -> u16, + // the mask is 0x00ff00ff. + unsigned mask = (1 << src_bitwidth) - 1; + while (dst_bitwidth < 32) { + mask = (mask << dst_bitwidth) | mask; + dst_bitwidth *= 2; + } + const VectorType i32_vreg_ty = + getNativeVregType(builder.getI32Type(), ctx.target_shape); + auto mask_const = builder.create( + op.getLoc(), i32_vreg_ty, DenseIntElementsAttr::get(i32_vreg_ty, {mask})); + const VectorType res_vreg_ty = + getNativeVregType(result_ty.getElementType(), ctx.target_shape); + output_vregs.Each([&](absl::Span _, Value *v) { + Value unpacked = + builder.create(op.getLoc(), i32_vreg_ty, *v); + unpacked = builder.create(op.getLoc(), i32_vreg_ty, unpacked, + mask_const); + *v = builder.create(op.getLoc(), res_vreg_ty, unpacked); + }); + extui_op.replaceAllUsesWith(assemble(builder, result_ty, *layouts_out.front(), + std::move(output_vregs), ctx.target_shape) + .getResult()); + extui_op.erase(); + return success(); } template @@ -891,9 +957,8 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op, output_vregs.Reshape(layout_out.tileArrayImplicitShape(result_ty.getShape(), ctx.target_shape)); } - FAILUREOR_ASSIGN_OR_RETURN( - VectorType res_vreg_ty, - getNativeVregType(result_ty.getElementType(), ctx.target_shape)); + VectorType res_vreg_ty = + getNativeVregType(result_ty.getElementType(), ctx.target_shape); if (layout_out.tiling() == ctx.target_shape) { const int packing = layout_out.packing(); output_vregs.Each([&](absl::Span idxs, Value *v) { @@ -1558,9 +1623,8 @@ LogicalResult strided_op_rule_impl(RewriteContext &ctx, Operation &op, } ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + VectorType vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); bool is_load_op = true; xla::Array tiles( @@ -1738,9 +1802,8 @@ LogicalResult tpu_matmul_rule(RewriteContext &ctx, Operation &op, TPU_ASSERT_EQ_OP(padded_lhs_rows, acc_vregs.dim(0) * layout_acc.tiling()[0]); TPU_ASSERT_EQ_OP(padded_rhs_rows, rhs_vregs.dim(0) * layout_rhs.tiling()[0]); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType i32_vreg, - getNativeVregType(builder.getI32Type(), ctx.target_shape)); + const VectorType i32_vreg = + getNativeVregType(builder.getI32Type(), ctx.target_shape); auto getVmaskByPaddingEnd = [&](int64_t dim, int64_t padding, VectorType vreg_ty) { CHECK(dim == 0 || dim == 1); @@ -2012,9 +2075,8 @@ LogicalResult tpu_bitcast_rule(RewriteContext &ctx, 
Operation &op, } } ImplicitLocOpBuilder builder(op.getLoc(), &op); - FAILUREOR_ASSIGN_OR_RETURN( - const auto native_vreg_ty, - getNativeVregType(out_ty.getElementType(), ctx.target_shape)); + const auto native_vreg_ty = + getNativeVregType(out_ty.getElementType(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array in_tiles, disassemble(builder, in_layout, bitcast_op.getInput(), ctx.target_shape)); @@ -2064,9 +2126,8 @@ LogicalResult tpu_assume_layout_rule(RewriteContext &ctx, Operation &op, SmallVector layout_shape = layout->tileArrayShape(vty.getShape(), ctx.target_shape); const int64_t num_vectors = ShapedType::getNumElements(layout_shape); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + VectorType vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); // We can not use disassemble here because the val is block argument. auto unrolled_op = builder.create( val.getLoc(), SmallVector(num_vectors, vreg_ty), val); @@ -2104,16 +2165,15 @@ LogicalResult rotate_rule_impl(RewriteContext &ctx, OpTy op, Value amount, } ImplicitLocOpBuilder builder(op.getLoc(), op.getOperation()); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType res_vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + + VectorType res_vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array in_tiles, disassemble(builder, layout_in, op.getValue(), ctx.target_shape)); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType i32_vreg, - getNativeVregType(builder.getI32Type(), ctx.target_shape)); + const VectorType i32_vreg = + getNativeVregType(builder.getI32Type(), ctx.target_shape); // Some helper functions for math ops. auto mlirI32Const = [&](int d) { @@ -2466,11 +2526,11 @@ LogicalResult tpu_concatenate_rule(RewriteContext &ctx, Operation &op, "Not implemented: Only native tiling with offset (0, 0) is supported " "when concatenation along tiling dims."); } - // Check if shapes of src and res are aligned to native tiling. + // Check if the concat dim size of src and res is aligned to native tiling. 
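The relaxed check that follows only requires the concatenation dimension itself to be a multiple of the matching native tile size, rather than requiring both minor dimensions to be aligned. In Python terms (illustrative; the negative index mirrors "dimension - res_ty.getRank()" and assumes the concat dim is one of the two minor, i.e. tiling, dims):

def concat_dim_aligned(shape, tiling, dimension, rank):
    i = dimension - rank  # negative index into the trailing dims
    return len(shape) >= 2 and shape[i] % tiling[i] == 0

# With (8, 128) native tiling, concatenating along the sublane dim only needs
# the sublane extent to be a multiple of 8, even if lanes are not 128-aligned.
assert concat_dim_aligned((16, 100), (8, 128), dimension=0, rank=2)
assert not concat_dim_aligned((16, 100), (8, 128), dimension=1, rank=2)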
auto check_aligned = [&](const VectorType &vty) { + auto i = dimension - res_ty.getRank(); return vty.getRank() >= 2 && - *(vty.getShape().end() - 2) % *(layout.tiling().end() - 2) == 0 && - *(vty.getShape().end() - 1) % *(layout.tiling().end() - 1) == 0; + *(vty.getShape().end() + i) % *(layout.tiling().end() + i) == 0; }; bool is_aligned = check_aligned(res_ty); int op_idx = 0; @@ -2518,9 +2578,9 @@ LogicalResult tpu_iota_rule(RewriteContext &ctx, Operation &op, if (!layout_out.hasNativeTiling(ctx.target_shape)) { return iota_op.emitOpError("Not implemented: Only native tiling supported"); } - FAILUREOR_ASSIGN_OR_RETURN( - const auto native_vreg_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + + const auto native_vreg_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); if (layout_out.implicit_dim() != VectorLayout::ImplicitDim::kNone) { return op.emitOpError("Not implemented: Only 2D layouts supported"); } @@ -2807,9 +2867,8 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op, auto load_op = cast(op); const auto memref_ty = getMemRefType(load_op.getBase()); const auto vty = cast(load_op.getResult().getType()); - FAILUREOR_ASSIGN_OR_RETURN( - VectorType target_ty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + VectorType target_ty = + getNativeVregType(vty.getElementType(), ctx.target_shape); if (vty.getRank() == 0) { op.emitOpError("Not implemented: scalar loads from vmem"); } @@ -3017,9 +3076,8 @@ LogicalResult arith_constant_rule(RewriteContext &ctx, Operation &op, } const VectorLayout &layout_out = *layouts_out.front(); DenseElementsAttr value = cast(constant_op.getValue()); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType target_vty, - getNativeVregType(vty.getElementType(), ctx.target_shape)); + const VectorType target_vty = + getNativeVregType(vty.getElementType(), ctx.target_shape); if (value.isSplat()) { if (layout_out.offsets() != LayoutOffsets{std::nullopt, std::nullopt}) { return op.emitOpError( @@ -3270,9 +3328,9 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, // yields the vmask. 
auto src_i32 = builder.create( broadcast_op.getLoc(), builder.getI32Type(), broadcast_op.getSource()); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType native_vreg_ty, - getNativeVregType(src_i32.getType(), ctx.target_shape)); + + const VectorType native_vreg_ty = + getNativeVregType(src_i32.getType(), ctx.target_shape); auto tile_i32 = builder.create(native_vreg_ty, src_i32); auto zeros = builder.create( @@ -3313,13 +3371,13 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, loc, src_i32, builder.create(loc, src_i32, shift_width)); } - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType i32_vreg_ty, - getNativeVregType(src_i32.getType(), ctx.target_shape)); + + const VectorType i32_vreg_ty = + getNativeVregType(src_i32.getType(), ctx.target_shape); auto tile_i32 = builder.create(i32_vreg_ty, src_i32); - FAILUREOR_ASSIGN_OR_RETURN(const VectorType native_vreg_ty, - getNativeVregType(src_ty, ctx.target_shape)); + const VectorType native_vreg_ty = + getNativeVregType(src_ty, ctx.target_shape); auto tile = builder.create(native_vreg_ty, tile_i32); const xla::Array dst_tiles(dst_tiles_shape, tile); @@ -3329,9 +3387,8 @@ LogicalResult vector_broadcast_rule(RewriteContext &ctx, Operation &op, broadcast_op.erase(); return success(); } else { - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType native_vreg_ty, - getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape)); + const VectorType native_vreg_ty = + getNativeVregType(broadcast_op.getSourceType(), ctx.target_shape); auto tile = builder.create(native_vreg_ty, broadcast_op.getSource()); const xla::Array dst_tiles(dst_tiles_shape, tile); @@ -3498,24 +3555,53 @@ LogicalResult vector_extract_rule(RewriteContext &ctx, Operation &op, op.erase(); return success(); } else { - for (int64_t i : extract_op.getStaticPosition()) { - if (i != 0) { - return op.emitOpError( - "Not implemented: Only 0 indices supported for scalar results"); - } - } + // TODO(b/367459476): Support non-zero offsets. if (layout_in.offsets() != LayoutOffsets{0, 0}) { return op.emitOpError("Not implemented: Unsupported layout"); } + auto [sub_tile, lane_tile] = layout_in.tiling(); FAILUREOR_ASSIGN_OR_RETURN( const xla::Array vregs, disassemble(builder, layout_in, extract_op.getVector(), ctx.target_shape)); TPU_ASSERT_GT_OP(vregs.num_elements(), 0); + + SmallVector indices(extract_op.getStaticPosition()); + auto vreg_slice = layout_in.vregSlice(ctx.target_shape); + std::array position = {0, 0}; + SmallVector vreg_index(indices); + // TODO(b/367459476): Support non-VREG-aligned tiling. + CHECK_EQ(lane_tile, ctx.target_shape[1]); + layout_in.insertImplicit(indices, static_cast(0)); + layout_in.insertImplicit(vreg_index, static_cast(0)); + int i = *(indices.end()-2); + int j = *(indices.end()-1); + *(vreg_index.end() -2) = i / vreg_slice[0]; + *(vreg_index.end() -1) = j / vreg_slice[1]; + layout_in.eraseImplicit(vreg_index); + position[0] = ((j % vreg_slice[1]) / lane_tile * sub_tile + ) + i % sub_tile; + position[1] = j % lane_tile; + + TPU_ASSERT_LT_OP(vreg_index, vregs.dimensions()); + Value extracted_vreg = vregs(vreg_index); + + // Invert the offsets to get the rotation amount. 
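The scalar vector.extract path above first locates the requested element: the static (row, column) index is split into the index of the vreg that holds it and the element's (sublane, lane) position inside that vreg. The code resuming below then negates that position modulo the target shape and issues two dynamic rotates, so the element lands at (0, 0) and can be extracted there. A Python sketch of the locating step (assuming the lane tile equals the target lane count, as the CHECK above requires):

def locate_scalar(i, j, sub_tile, lane_tile, vreg_slice):
    vreg_index = (i // vreg_slice[0], j // vreg_slice[1])
    sublane = (j % vreg_slice[1]) // lane_tile * sub_tile + i % sub_tile
    lane = j % lane_tile
    return vreg_index, (sublane, lane)

# 32-bit data, (8, 128) tiling on an (8, 128) target, so vreg_slice == (8, 128):
# element (10, 200) lives in vreg (1, 1) at sublane 2, lane 72.
assert locate_scalar(10, 200, 8, 128, (8, 128)) == ((1, 1), (2, 72))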
+ position[0] = (ctx.target_shape[0] - position[0]) % ctx.target_shape[0]; + position[1] = (ctx.target_shape[1] - position[1]) % ctx.target_shape[1]; + auto res_vreg_ty = extracted_vreg.getType(); + Value shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), position[0])); + Value rotated_vreg = builder.create( + res_vreg_ty, extracted_vreg, shift, 0, /*stride*/nullptr, nullptr); + shift = builder.create( + builder.getIntegerAttr(builder.getI32Type(), position[1])); + rotated_vreg = builder.create( + res_vreg_ty, rotated_vreg, shift, 1, /*stride*/nullptr, nullptr); extract_op.replaceAllUsesWith( - builder - .create(op.getLoc(), *vregs.data(), - ArrayRef{0, 0}) + builder.create( + op.getLoc(), rotated_vreg, + ArrayRef{0, 0}) .getResult()); } extract_op.erase(); @@ -3578,12 +3664,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op, auto acc = cast>(multi_reduction_op.getAcc()); TPU_ASSERT_OP(layouts_out.front().has_value()); - const ArrayAttr dim_attrs = multi_reduction_op.getReductionDims(); - SmallVector dims; - dims.reserve(dim_attrs.size()); - for (const Attribute dim_attr : dim_attrs) { - dims.push_back(cast(dim_attr).getValue().getSExtValue()); - } + SmallVector dims(multi_reduction_op.getReductionDims()); std::sort(dims.begin(), dims.end()); // Make sure that the accumulator is a splat of the neutral value @@ -4366,6 +4447,7 @@ const llvm::StringMap &rules() { {arith::ConstantOp::getOperationName(), arith_constant_rule}, {arith::ExtFOp::getOperationName(), arith_extf_rule}, {arith::ExtSIOp::getOperationName(), arith_extsi_rule}, + {arith::ExtUIOp::getOperationName(), arith_extui_rule}, {arith::TruncFOp::getOperationName(), arith_truncf_rule}, {arith::TruncIOp::getOperationName(), arith_trunci_rule}, {func::ReturnOp::getOperationName(), func_return_rule}, @@ -4771,30 +4853,6 @@ xla::Array retileToReducedSublanes( return dst_vreg_array; } -// Returns true iff the layout changes involve reduced sublanes per tile. -// -// Arguments: -// src: The existing layout. -// dst: The new layout based on which the retiling is to be carried out. -bool isSupportedReducedSublanesRetile( - const VectorLayout &src, const VectorLayout &dst, - const std::array target_shape) { - return src.implicit_dim() == dst.implicit_dim() && - llvm::all_of(llvm::zip_equal(src.offsets(), dst.offsets()), - [](auto tup) { - auto [lhs, rhs] = tup; - return lhs.value_or(0) == rhs.value_or(0); - }) - // TODO (kumudbhandari): We have not tested any tile size where - // tile[-1] != TARGET_SHAPE.lanes. It should work but needs to be - // tested. - && src.tiling()[1] == target_shape[1] && - dst.tiling()[1] == target_shape[1] && - dst.tiling()[0] < src.tiling()[0] && - src.bitwidth() == dst.bitwidth() && - llvm::isPowerOf2_64(src.tiling()[0]) && - llvm::isPowerOf2_64(dst.tiling()[0]); -} // Copy one sublane from a vreg to another vreg. 
// @@ -4821,7 +4879,7 @@ Value copy_one_sublane(OpBuilder &builder, Value src_vreg, int src_sl_idx, CHECK_EQ(bitwidth, cast(dst_vreg.getType()).getElementTypeBitWidth()); const VectorType vmask_ty = - *getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); + getNativeVregOrVmaskType(builder.getI1Type(), bitwidth, target_shape); auto sublanes_mask = builder.create( src_vreg.getLoc(), vmask_ty, ValueRange{boundIdxConst(dst_sl_idx), boundIdxConst(0)}, @@ -4869,9 +4927,8 @@ FailureOr> tpu_rotate_with_overflow( SmallVector dst_tiles_shape = layout_out.tileArrayImplicitShape(vty.getShape(), target_shape); - FAILUREOR_ASSIGN_OR_RETURN( - const VectorType res_vreg_ty, - getNativeVregType(vty.getElementType(), target_shape)); + const VectorType res_vreg_ty = + getNativeVregType(vty.getElementType(), target_shape); xla::Array out_tiles(dst_tiles_shape); @@ -5008,13 +5065,291 @@ FailureOr> tpu_rotate_with_overflow( return out_tiles; } +void rotateVregs(OpBuilder &builder, xla::Array &vregs, + const int64_t amount, const int dimension) { + if (amount != 0) { + vregs.Each([&](absl::Span idx, Value *vreg) { + CHECK(vreg); + *vreg = builder + .create(vreg->getLoc(), *vreg, + /*amount=*/amount, + /*dimension=*/dimension, + /*stride=*/nullptr, + /*stride_dimension=*/nullptr) + .getResult(); + }); + } +}; + +void rotateSublanes(OpBuilder &builder, xla::Array &vregs, + const int64_t amount) { + rotateVregs(builder, vregs, amount, 0); +} + +void rotateLanes(OpBuilder &builder, xla::Array &vregs, + const int64_t amount) { + rotateVregs(builder, vregs, amount, 1); +} + +// Relayout src_vregs from layout src to layout dst, where dst is the same as +// src except that the column offset is dst_col_offset. +FailureOr> doColumnShiftRelayout( + OpBuilder &builder, const ArrayRef shape, + xla::Array src_vregs, const VectorLayout &src, + const int64_t dst_col_offset, const std::array target_shape) { + CHECK(src.offsets()[1]); + const std::array tiled_ishape = + src.getImplicitTiledDims(shape, 1); + const Location loc = src_vregs.begin()->getLoc(); + const std::array tiling = src.tiling(); + const std::array vreg_slice = src.vregSlice(target_shape); + const int bitwidth = src.bitwidth(); + const int packing = src.packing(); + const VectorLayout dst(bitwidth, {src.offsets()[0], dst_col_offset}, tiling, + src.implicit_dim()); + const int64_t col_diff = dst_col_offset - *src.offsets()[1]; + if (tiling[0] % packing != 0 || tiling[1] != target_shape[1]) { + return emitError(loc, + "Not implemented: Unsupported tiling for column shift"); + } + // When shifting columns with multiple tiles per vreg, the overflowing + // columns of a tile move to the next tile, and they have to be shifted + // down. For example, for a 32-bit layout with (2, 128 tiling), when shifting + // a vreg right by 138 (128 + 10): + // + // +---------------+---------+ +---------+---------------+ + // | 0:118 | 118:128 | |-138:-128| -128:-10 | + // +---------------+---------+ +---------+---------------+ + // | 128:246 | 246:256 | | -10:0 | 0:118 | + // +---------------+---------+ -> +---------+---------------+ + // | 256:382 | 382:392 | | 118:128 | 128:246 | + // +---------------+---------+ +---------+---------------+ + // | 392:502 | 502:512 | | 246:256 | 256:382 | + // +---------------+---------+ +---------+---------------+ + // + // The negative numbers above are used for column intervals coming from the + // previous vreg (if there is one). 
+ // + // We can break the result vreg down into four parts: + // + // +---------+---------------+ + // | UL | UR | + // + +---------------+ + // | | LR | + // +---------+ + + // | LL | | + // + + + + // | | | + // +---------+---------------+ + // + // Our example shifts right, which causes the upper parts to come from the + // previous (along the minor dim) vreg of the array (if it exists) and the + // lower parts to come from the original "current" vreg. + // + // - LR (Lower Right) comes from the current vreg lane-rotated by 10, and + // sublane-rotated down by 2 (1 tile). + // - LL (Lower Left) comes from the current vreg lane-rotated by 10, and + // sublane-rotated down by 4 (2 tiles). + // - UR (Upper Right) comes from the previous vreg lane-shifted by 10, and + // sublane-rotated down by 2 (1 tile). + // - UL (Upper Left) comes from the previous vreg lane-shifted by 10, and + // sublane-rotated down by 4 (2 tiles). + // + // This partitioning also works similarly for left shifts, except that the + // upper parts come from the current vreg, and the lower parts come from the + // next vreg. + // + // In general, for any tiling and shift amount, we will partition the result + // vreg into four like we did here. However, for some tilings and shift + // amounts, some of the partitions may be empty. There are some notable cases: + // + // - Tile-aligned shifts result in empty left parts. + // - Native tiling (a single tile per vreg) results in empty upper right and + // lower left parts. + // - Shifts right by less than 1 tile result in empty upper right parts, and + // shifts left by less than 1 tile result in empty lower left parts. + + const int64_t sublanes_per_tile = src.sublanesPerTile(target_shape); + const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape); + + int64_t split_offset = col_diff; + int64_t upper_idx_delta = -1; + int64_t lower_idx_delta = 0; + if (col_diff < 0) { + split_offset += vreg_slice[1]; + ++upper_idx_delta; + ++lower_idx_delta; + } + const int64_t left_tile_split = llvm::divideCeil(split_offset, tiling[1]); + const int64_t right_tile_split = split_offset / tiling[1]; + const int64_t left_right_split = split_offset % tiling[1]; + + rotateLanes(builder, src_vregs, left_right_split); + // TODO(tlongeri): Clean up. Some of these rotations may end up unused: + // - The left part of the first vreg and the right part of the last vreg + // may be entirely padding. + // - The entire left part may be unused if the shift is tile-aligned. + // They will be removed as dead code anyway, but it would be nicer to not + // generate them in the first place. + // Also, sometimes the rotation amount is 0, so we don't need to allocate + // another array (and we should steal the allocation for src_tiles, too). + xla::Array left_part = src_vregs; + xla::Array right_part = src_vregs; + rotateSublanes(builder, left_part, + left_tile_split * sublanes_per_tile % target_shape[0]); + rotateSublanes(builder, right_part, + right_tile_split * sublanes_per_tile % target_shape[0]); + // We assemble left and right, and then put them together. + // TODO(tlongeri): Lower and upper first is probably better, it can be + // reused for consecutive vregs. We can assemble lower_left+lower_right + // for one vreg and upper_left+upper_right for the next one in the same + // vselect. But the mask for assembling upper+lower is not as simple, so + // it might be a bit more expensive to generate. Worth it for large vreg + // arrays, I'm not sure about small ones (especially in older TPU gens). 
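Reviewer note: to make the partitioning above concrete, here is a minimal standalone sketch (plain C++, independent of MLIR/Mosaic types; all names are illustrative) that reproduces the split computation for the example in the comment: a 32-bit layout with (2, 128) tiling on (8, 128) vregs, shifted right by 138 columns. It is a mimic of the arithmetic only, not the Mosaic implementation.

    #include <cassert>
    #include <cstdint>

    // Standalone mimic of the split computation in doColumnShiftRelayout,
    // assuming 32-bit elements, (2, 128) tiling and an (8, 128) target shape.
    int main() {
      const int64_t target_sublanes = 8, target_lanes = 128;
      const int64_t tiling[2] = {2, 128};
      const int64_t tiles_per_vreg =
          (target_sublanes * target_lanes) / (tiling[0] * tiling[1]);      // 4
      const int64_t sublanes_per_tile = target_sublanes / tiles_per_vreg;  // 2
      const int64_t vreg_slice_cols = tiling[1] * tiles_per_vreg;          // 512

      const int64_t col_diff = 138;  // shift right by 138 columns
      int64_t split_offset = col_diff;
      if (col_diff < 0) split_offset += vreg_slice_cols;

      const int64_t left_tile_split = (split_offset + tiling[1] - 1) / tiling[1];
      const int64_t right_tile_split = split_offset / tiling[1];
      const int64_t left_right_split = split_offset % tiling[1];
      assert(left_tile_split == 2);
      assert(right_tile_split == 1);
      assert(left_right_split == 10);  // lane rotation applied to every vreg

      // Per-part sublane rotations, matching the comment: LL/UL rotate down by
      // 2 tiles (4 sublanes), LR/UR rotate down by 1 tile (2 sublanes).
      assert(left_tile_split * sublanes_per_tile % target_sublanes == 4);
      assert(right_tile_split * sublanes_per_tile % target_sublanes == 2);
      return 0;
    }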
+ const auto mask_vreg_ty = VectorType::get( + packing == 1 + ? target_shape + : ArrayRef{target_shape[0], target_shape[1], packing}, + builder.getI1Type()); + Value left_mask = nullptr; + Value right_mask = nullptr; + Value left_right_mask = nullptr; + auto get_left_mask = [&]() { + if (left_mask == nullptr) { + left_mask = builder.create( + loc, mask_vreg_ty, + ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, + ArrayRef{ + IdxConst(left_tile_split * sublanes_per_tile, builder, loc), + IdxConst(target_shape[1], builder, loc)}); + } + return left_mask; + }; + auto get_right_mask = [&]() { + if (right_mask == nullptr) { + right_mask = builder.create( + loc, mask_vreg_ty, + ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, + ArrayRef{ + IdxConst(right_tile_split * sublanes_per_tile, builder, loc), + IdxConst(target_shape[1], builder, loc)}); + } + return right_mask; + }; + auto get_left_right_mask = [&]() { + if (left_right_mask == nullptr) { + left_right_mask = builder.create( + loc, mask_vreg_ty, + ArrayRef{IdxConst(0, builder, loc), IdxConst(0, builder, loc)}, + ArrayRef{IdxConst(target_shape[0], builder, loc), + IdxConst(left_right_split, builder, loc)}); + } + return left_right_mask; + }; + xla::Array dst_vregs(VectorLayout(bitwidth, + {src.offsets()[0], dst_col_offset}, + tiling, src.implicit_dim()) + .tileArrayImplicitShape(shape, target_shape)); + dst_vregs.Each([&](absl::Span dst_idx, Value *dst_vreg) { + SmallVector dst_idx_local(toArrayRef(dst_idx)); + Value lower_left = nullptr; + Value lower_right = nullptr; + Value upper_left = nullptr; + Value upper_right = nullptr; + // Set parts if their size is non-empty and the source vreg exists. + *(dst_idx_local.end() - 1) += lower_idx_delta; + if (*(dst_idx_local.end() - 1) < *(src_vregs.dimensions().end() - 1)) { + if (left_tile_split < tiles_per_vreg && 0 < left_right_split) { + lower_left = left_part(dst_idx_local); + } + if (right_tile_split < tiles_per_vreg) { + lower_right = right_part(dst_idx_local); + } + } + *(dst_idx_local.end() - 1) -= lower_idx_delta; + *(dst_idx_local.end() - 1) += upper_idx_delta; + if (*(dst_idx_local.end() - 1) >= 0) { + if (0 < left_tile_split && 0 < left_right_split) { + upper_left = left_part(dst_idx_local); + } + if (0 < right_tile_split) { + upper_right = right_part(dst_idx_local); + } + } + *(dst_idx_local.end() - 1) -= upper_idx_delta; + + // For the first and last vregs, some parts may be all padding, so + // unset them if this is the case. Note that the first and last vreg + // are the same when there is only one. + if (*(dst_idx_local.end() - 1) == 0) { + // We check the final offset (note that this is different from the rotate + // amount) against the thresholds of the last columns of vreg parts. + if (right_tile_split * tiling[1] <= dst_col_offset) { + // Note: When shifting right, UR is always all-padding. + upper_right = nullptr; + } + if (split_offset <= dst_col_offset) { + // Note: When shifting right, UL is always all-padding. When shifting + // left, UL is never all-padding (unless this is also the last vreg, + // possibly). + upper_left = nullptr; + } + if (vreg_slice[1] - tiling[1] + left_right_split <= dst_col_offset) { + // Note: When shifting right, LL is only all-padding if the source + // offset is in the last tile. When shifting left, LL is never + // all-padding (unless this is also the last vreg, possibly). 
+ lower_left = nullptr; + } + } + if (*(dst_idx_local.end() - 1) == *(dst_vregs.dimensions().end() - 1) - 1) { + // We check the final end offset against the thresholds of the first + // columns of vreg parts. + const uint64_t end_offset = + (dst_col_offset + tiled_ishape[1] - 1) % vreg_slice[1] + 1; + if (end_offset <= left_tile_split * tiling[1]) { + // Note: When shifting left, LL is always all-padding. + lower_left = nullptr; + } + if (end_offset <= split_offset) { + // Note: When shifting left, LR is always all-padding. When shifting + // right, LR is never all-padding (unless this is also the first vreg, + // possibly). + lower_right = nullptr; + } + if (end_offset <= left_right_split) { + // Note: When shifting left, UR is only all-padding if the original + // end offset is in the first tile. When shifting right, UR is never + // all-padding (unless this is also the last vreg, possibly). + upper_right = nullptr; + } + } + // Combine parts into the final vreg (see comment in mask definitions). + auto combine_parts = [&builder](Value part1, Value part2, + auto get_mask_fn) -> Value { + if (part1 && part2) { + return builder.create(part1.getLoc(), get_mask_fn(), + part1, part2); + } else if (part1) { + return part1; + } else { + return part2; + } + }; + Value left = combine_parts(upper_left, lower_left, get_left_mask); + Value right = combine_parts(upper_right, lower_right, get_right_mask); + *dst_vreg = combine_parts(left, right, get_left_right_mask); + CHECK(*dst_vreg); + }); + return dst_vregs; +} + FailureOr>> changeOffsets( - OpBuilder &builder, const std::array target_shape, - const Location loc, const VectorType vty, const VectorLayout src, - xla::Array vregs, const LayoutOffsets dst_offsets) { + RewriteContext &ctx, OpBuilder &builder, const Location loc, + const VectorType vty, const VectorLayout src, xla::Array vregs, + const LayoutOffsets dst_offsets) { + const auto &target_shape = ctx.target_shape; const VectorLayout dst(src.bitwidth(), dst_offsets, src.tiling(), src.implicit_dim()); - const auto &tiling = src.tiling(); const int packing = src.packing(); const int8_t bitwidth = src.bitwidth(); @@ -5061,15 +5396,7 @@ FailureOr>> changeOffsets( if (sublane_diff < 0) { sublane_diff += target_shape[0]; } - vregs.Each([&](absl::Span idx, Value *tile) { - *tile = - builder - .create(loc, *tile, - /*amount=*/sublane_diff, - /*dimension=*/0, /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - }); + rotateSublanes(builder, vregs, sublane_diff); } const int src_subelem = *src.offsets()[0] % packing; const int dst_subelem = *dst.offsets()[0] % packing; @@ -5108,79 +5435,363 @@ FailureOr>> changeOffsets( SmallVector dst_tiles_shape = dst.tileArrayImplicitShape(vty.getShape(), target_shape); CHECK_EQ(*(dst_tiles_shape.end() - 2), *(vregs.dimensions().end() - 2)); - if (dst_tiles_shape.back() != vregs.dimensions().back()) { - return emitError(loc, - "Not implemented: Offsets changing the vreg array shape"); - } + // TODO(tlongeri): Clean up col_diff and pass the dst offset directly. 
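Reviewer note: the per-vreg assembly above hinges on a small selection pattern: each of the four parts may be absent (null), and a mask is only materialized when two parts actually have to be blended (the real get_*_mask helpers additionally memoize the mask). A minimal standalone sketch of that pattern, with std::optional standing in for possibly-null Values and strings standing in for emitted ops; the names are illustrative, not Mosaic APIs.

    #include <cassert>
    #include <functional>
    #include <optional>
    #include <string>

    // Mimics combine_parts in doColumnShiftRelayout: blend two optional parts,
    // building the mask lazily and only when both parts are present.
    static int masks_created = 0;

    std::optional<std::string> CombineParts(
        std::optional<std::string> a, std::optional<std::string> b,
        const std::function<std::string()>& get_mask) {
      if (a && b) return "select(" + get_mask() + ", " + *a + ", " + *b + ")";
      return a ? a : b;  // pass through whichever part exists (possibly none)
    }

    int main() {
      auto get_left_mask = [] {
        ++masks_created;
        return std::string("left_mask");
      };
      // Both parts present: a mask is built and a select is emitted.
      assert(*CombineParts("upper_left", "lower_left", get_left_mask) ==
             "select(left_mask, upper_left, lower_left)");
      // Only one part present: no mask is built, the part is passed through.
      assert(*CombineParts("upper_right", std::nullopt, get_left_mask) ==
             "upper_right");
      assert(masks_created == 1);
      return 0;
    }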
if (col_diff != 0) { - if (bitwidth != 32 || tiling != target_shape) { + FAILUREOR_ASSIGN_OR_RETURN( + vregs, doColumnShiftRelayout(builder, vty.getShape(), std::move(vregs), + src, *dst.offsets()[1], target_shape)); + } + return std::make_pair(dst, std::move(vregs)); +} + +LogicalResult retileToLargeTileWithScratch( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + xla::Array &dst_tiles, const std::array &dst_tile, + const xla::Array &src_tiles, const std::array &src_tile, + TypedValue scratch_ref) { + if (dst_tile[0] % src_tile[0] != 0) { + return failure(); + } + // Number of src vregs needed to assemble one dst vreg. + int vregs_per_group = dst_tile[0] / src_tile[0]; + // Number of sublanes needed per src vreg to assemble one dst vreg. + int sl_per_vreg = ctx.target_shape[0] / vregs_per_group; + int stride = vregs_per_group; + + xla::Array sublane_offsets( + {ctx.target_shape[0] / dst_tile[0], src_tile[0], vregs_per_group}, 0); + absl::c_iota(sublane_offsets, 0); + // The older hardware has limited support for shuffles so even if we have bank + // conflicts, we just accept them and will have the lowering unroll the + // loads/stores. + bool should_handle_bank_confict = + ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && + ctx.vmem_banks < stride * ctx.target_shape[0]; + // Add one extra sublane to stride to avoid bank conflict. + if (should_handle_bank_confict) { + // Adjust sublane offsets to match the stride. + for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { + *(sublane_offsets.begin() + i) += i / stride; + } + stride += 1; + } + sublane_offsets.TransposeDimensions({0, 2, 1}); + + auto mlirIndexConst = [&](int d) { + return builder.create( + src_tiles.begin()->getLoc(), + builder.getIntegerAttr(builder.getIndexType(), d)); + }; + auto cst_0 = mlirIndexConst(0); + // Each group has exact number of src vregs needed to assemble one dst vreg. + // We can not use circular buffer here because we need to have enough space to + // strided load/store. + int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group; + int64_t max_groups_in_scratch = + ctx.max_sublanes_in_scratch / sublanes_per_group; + if (max_groups_in_scratch < 1) { + return emitError(loc, + "scratch space is not enough for retiling to large tile"); + } + int64_t stored_group_cnt = 0; + auto dst_vreg_ty = src_tiles.begin()->getType(); + // Create a new vreg type that can be stored in scratch memref. + auto temp_vreg_ty = + VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType()); + SmallVector sublane_mask(ctx.target_shape[0], true); + // (dst_vreg, load_offset) + std::vector> delayed_loads; + delayed_loads.reserve(max_groups_in_scratch * vregs_per_group); + // We only emit the loads when we run out of scratch space or we are at the + // last vreg of the batch to help bundle scheduling. 
+ auto emit_all_delayed_loads = [&]() { + for (auto [dst_vreg, load_offset] : delayed_loads) { + Value load_op = builder.create( + loc, temp_vreg_ty, scratch_ref, ArrayRef({load_offset, cst_0}), + ArrayRef(sublane_mask), + ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); + *dst_vreg = builder.create(loc, dst_vreg_ty, load_op); + } + delayed_loads.clear(); + }; + + int rank = src_tiles.dimensions().size(); + if (rank != dst_tiles.dimensions().size()) { + return emitError(loc, "src and dst tiles have different ranks"); + } + for (int i = 0; i < rank - 2; ++i) { + if (src_tiles.dim(i) != dst_tiles.dim(i)) { return emitError(loc, - "Not implemented: Only 32-bit column shifts for " - "native layouts supported"); - } - TPU_ASSERT_GE_LOC(loc, vregs.num_dimensions(), 1); - std::optional maybe_create_mask; - if (*(vregs.dimensions().end() - 1) > 1) { - int64_t lane_start, lane_end; - if (col_diff > 0) { - lane_start = 0; - lane_end = col_diff; - } else { // col_diff < 0 - lane_start = target_shape[1] + col_diff; - lane_end = target_shape[1]; - } - auto boundIdxConst = - std::bind(IdxConst, std::placeholders::_1, builder, loc); - maybe_create_mask = builder.create( - loc, VectorType::get(target_shape, builder.getI1Type()), - ValueRange{boundIdxConst(0), boundIdxConst(lane_start)}, - ValueRange{boundIdxConst(target_shape[0]), boundIdxConst(lane_end)}); - } - auto rotated_vregs = vregs; - rotated_vregs.Each([&](absl::Span idx, Value *tile) { - *tile = builder - .create(loc, *tile, - /*amount=*/col_diff < 0 - ? target_shape[1] + col_diff - : col_diff, - /*dimension=*/1, /*stride=*/nullptr, - /*stride_dimension=*/nullptr) - .getResult(); - }); - vregs.Each([&](absl::Span idx, Value *result) { - Value rot_tile = rotated_vregs(idx); - Value prev_rot_tile; - if (col_diff > 0) { - if (*(idx.end() - 1) != 0) { - SmallVector prev_idx(idx.begin(), idx.end()); - --*(prev_idx.end() - 1); - prev_rot_tile = rotated_vregs(prev_idx); + "Expected src and dst tiles have same dimension " + "sizes on dim") + << i << ", but got " << src_tiles.dim(i) << " vs " + << dst_tiles.dim(i); + } + } + SmallVector src_idx(rank); + dst_tiles.Each([&](absl::Span dst_idx, Value *dst_vreg) { + int64_t dst_row_idx = *(dst_idx.end() - 2); + int64_t dst_col_idx = *(dst_idx.end() - 1); + int64_t vreg_idx_in_group = dst_col_idx % vregs_per_group; + int64_t load_offset = sublanes_per_group * stored_group_cnt + + vreg_idx_in_group * sl_per_vreg * stride; + delayed_loads.push_back( + std::make_pair(dst_vreg, mlirIndexConst(load_offset))); + // When dst vreg is at the last vreg of the group or the current dst + // vregs' row, this indicates we have scheduled delayed loads for all + // the vregs from current group and now we need to store corresponding + // group of src vregs before actually emitting the loads. 
+ if (vreg_idx_in_group == vregs_per_group - 1 || + dst_col_idx == dst_tiles.dimensions().back() - 1) { + auto src_row_idx = dst_row_idx * vregs_per_group; + auto src_col_idx = dst_col_idx / vregs_per_group; + std::copy(dst_idx.begin(), dst_idx.end(), src_idx.begin()); + for (int vi = 0; vi < vregs_per_group; ++vi) { + if (src_row_idx + vi >= src_tiles.dim(rank - 2) || + src_col_idx >= src_tiles.dim(rank - 1)) { + break; } - } else { // col_diff < 0 - if (*(idx.end() - 1) != *(rotated_vregs.dimensions().end() - 1) - 1) { - SmallVector prev_idx(idx.begin(), idx.end()); - ++*(prev_idx.end() - 1); - prev_rot_tile = rotated_vregs(prev_idx); + *(src_idx.end() - 2) = src_row_idx + vi; + *(src_idx.end() - 1) = src_col_idx; + Value src_vreg = src_tiles(src_idx); + src_vreg = + builder.create(loc, temp_vreg_ty, src_vreg); + Value store_offset = + mlirIndexConst(sublanes_per_group * stored_group_cnt + vi); + builder.create( + loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), + ArrayRef(sublane_mask), + /*mask=*/nullptr, builder.getI32IntegerAttr(stride)); + } + stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch; + // We emit loads when we run out of scratch space or we are at the + // last vreg of the batch. + if (stored_group_cnt == 0 || + (*(dst_idx.end() - 2) == dst_tiles.dim(rank - 2) - 1 && + *(dst_idx.end() - 1) == dst_tiles.dim(rank - 1) - 1)) { + emit_all_delayed_loads(); + } + } + }); + return success(); +} + +LogicalResult retileToSmallTileWithScratch( + RewriteContext &ctx, OpBuilder &builder, const Location loc, + xla::Array &dst_tiles, const std::array &dst_tile, + const xla::Array &src_tiles, const std::array &src_tile, + TypedValue scratch_ref) { + if (src_tile[0] % dst_tile[0] != 0) { + return failure(); + } + // Number of src vregs needed to assemble one dst vreg. + int vregs_per_group = src_tile[0] / dst_tile[0]; + // Number of sublanes needed per src vreg to assemble one dst vreg. + int sl_per_vreg = ctx.target_shape[0] / vregs_per_group; + int stride = vregs_per_group; + + xla::Array sublane_offsets( + {ctx.target_shape[0] / src_tile[0], dst_tile[0], vregs_per_group}, 0); + absl::c_iota(sublane_offsets, 0); + // The older hardware has limited support for shuffles so even if we have + // bank conflicts, we just accept them and will have the lowering unroll the + // loads/stores. + bool should_handle_bank_confict = + ctx.hardware_generation >= 4 && ctx.vmem_banks > 0 && + ctx.vmem_banks < stride * ctx.target_shape[0]; + bool use_shuffled_load = false; + if (ctx.hardware_generation <= 4) { + if (src_tile[0] == 8) { + // The older hardware does not support shuffled store. However, if the src + // tile is (8, 128), we can convert (shuffled store + strided load) to + // (strided store + shuffled load). + use_shuffled_load = true; + } else if (src_tile[0] == 4) { + // In this case, the trick of replacing a shuffled store with a shuffled + // load does not work. Handling bank conflicts will cause the sublane + // offsets to increase which might make emulation harder, so we avoid + // doing so. + should_handle_bank_confict = false; + } + } + + // Add one extra sublane to stride to avoid bank conflict. + if (should_handle_bank_confict) { + // Adjust sublane offsets to match the stride. 
+ for (int i = 0; i < sublane_offsets.num_elements(); i += 1) { + *(sublane_offsets.begin() + i) += i / stride; + } + stride += 1; + } + sublane_offsets.TransposeDimensions({0, 2, 1}); + auto mlirIndexConst = [&](int d) { + return builder.create( + src_tiles.begin()->getLoc(), + builder.getIntegerAttr(builder.getIndexType(), d)); + }; + auto cst_0 = mlirIndexConst(0); + // Each group has exact number of src vregs needed to assemble one dst vreg. + // We can not use circular buffer here because we need to have enough space + // to strided load/store. + int64_t sublanes_per_group = stride * sl_per_vreg * vregs_per_group; + int64_t max_groups_in_scratch = + ctx.max_sublanes_in_scratch / sublanes_per_group; + if (max_groups_in_scratch < 1) { + return emitError(loc, + "scratch space is not enough for retiling to small tile"); + } + int64_t stored_group_cnt = 0; + auto dst_vreg_ty = src_tiles.begin()->getType(); + // Create a new vreg type that can be stored in scratch memref. + auto temp_vreg_ty = + VectorType::get(ctx.target_shape, scratch_ref.getType().getElementType()); + SmallVector sublane_mask(ctx.target_shape[0], true); + // (dst_vreg, load_offset) + std::vector> delayed_loads; + delayed_loads.reserve(max_groups_in_scratch * vregs_per_group); + // We only emit the loads when we run out of scratch space or we are at the + // last vreg of the batch to help bundle scheduling. + auto emit_all_delayed_loads = [&]() { + for (auto [dst_vreg, load_offset] : delayed_loads) { + Value load_op; + if (use_shuffled_load) { + load_op = builder.create( + loc, temp_vreg_ty, scratch_ref, + ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask), + ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); + } else { + load_op = builder.create( + loc, temp_vreg_ty, scratch_ref, + ArrayRef({load_offset, cst_0}), ArrayRef(sublane_mask), + builder.getI32IntegerAttr(stride)); + } + *dst_vreg = builder.create(loc, dst_vreg_ty, load_op); + } + delayed_loads.clear(); + }; + int rank = src_tiles.dimensions().size(); + if (rank != dst_tiles.dimensions().size()) { + return emitError(loc, "src and dst tiles have different ranks"); + } + for (int i = 0; i < rank - 2; ++i) { + if (src_tiles.dim(i) != dst_tiles.dim(i)) { + return emitError(loc, + "Expected src and dst tiles have same dimension " + "sizes on dim") + << i << ", but got " << src_tiles.dim(i) << " vs " + << dst_tiles.dim(i); + } + } + SmallVector dst_idx(rank); + src_tiles.Each([&](absl::Span src_idx, Value src_vreg) { + int64_t src_row_idx = *(src_idx.end() - 2); + int64_t src_col_idx = *(src_idx.end() - 1); + int64_t vreg_idx_in_group = src_col_idx % vregs_per_group; + src_vreg = builder.create(loc, temp_vreg_ty, src_vreg); + if (use_shuffled_load) { + Value store_offset = mlirIndexConst( + sublanes_per_group * stored_group_cnt + vreg_idx_in_group); + builder.create( + loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), + ArrayRef(sublane_mask), + /*mask=*/nullptr, builder.getI32IntegerAttr(stride)); + } else { + Value store_offset = + mlirIndexConst(sublanes_per_group * stored_group_cnt + + vreg_idx_in_group * sl_per_vreg * stride); + builder.create( + loc, src_vreg, scratch_ref, ArrayRef({store_offset, cst_0}), + ArrayRef(sublane_mask), + ArrayRef(sublane_offsets.begin(), sublane_offsets.end())); + } + // When src vreg is at the last vreg of the group or the current src + // vregs' row, this indicates we have stored all the vregs needed to + // assemble a new group of dst vreg. 
+ if (vreg_idx_in_group == vregs_per_group - 1 || + src_col_idx == src_tiles.dimensions().back() - 1) { + auto dst_row_idx = src_row_idx * vregs_per_group; + auto dst_col_idx = src_col_idx / vregs_per_group; + std::copy(src_idx.begin(), src_idx.end(), dst_idx.begin()); + for (int vi = 0; vi < vregs_per_group; ++vi) { + if (dst_row_idx + vi >= dst_tiles.dim(rank - 2) || + dst_col_idx >= dst_tiles.dim(rank - 1)) { + break; } + *(dst_idx.end() - 2) = dst_row_idx + vi; + *(dst_idx.end() - 1) = dst_col_idx; + Value *dst_vreg = &dst_tiles(dst_idx); + int64_t load_offset = + use_shuffled_load ? (sublanes_per_group * stored_group_cnt + + vi * sl_per_vreg * stride) + : (sublanes_per_group * stored_group_cnt + vi); + delayed_loads.push_back( + std::make_pair(dst_vreg, mlirIndexConst(load_offset))); } - if (prev_rot_tile != nullptr) { - rot_tile = builder.create( - loc, maybe_create_mask->getResult(), prev_rot_tile, rot_tile); + stored_group_cnt = (stored_group_cnt + 1) % max_groups_in_scratch; + // We emit loads when we run out of scratch space or we are at the + // last vreg of the batch. + if (stored_group_cnt == 0 || + (*(src_idx.end() - 2) == src_tiles.dim(rank - 2) - 1 && + *(src_idx.end() - 1) == src_tiles.dim(rank - 1) - 1)) { + emit_all_delayed_loads(); } - *result = rot_tile; - }); + } + }); + return success(); +} + +// go/mosaic-retiling-in-scratch is the full internal documentation that +// includes more details about the TPU generations. +LogicalResult retileWithScratch(RewriteContext &ctx, OpBuilder &builder, + const Location loc, + xla::Array &dst_tiles, + const std::array &dst_tiling, + const xla::Array &src_tiles, + const std::array &src_tiling, + int packing) { + if (!(src_tiling[1] == ctx.target_shape[1] && + dst_tiling[1] == ctx.target_shape[1] && src_tiling[0] % packing == 0 && + dst_tiling[0] % packing == 0)) { + return failure(); } - return std::make_pair(dst, std::move(vregs)); + // Try to get i32 vector scratch space. Because we will bitcast vregs to + // i32 vregs before using scratch for retiling. Through this way we can + // handle packed types as well. + auto vi32_scratch_ref = getInternalScratch( + ctx, builder, loc, {ctx.max_sublanes_in_scratch, ctx.target_shape[1]}, + builder.getI32Type(), /*sublane_tiling=*/1); + if (failed(vi32_scratch_ref)) { + return emitError(loc, "Failed to get scratch ref for retiling"); + } + auto ref = vi32_scratch_ref.value(); + std::array vi32_dst_tiling = {dst_tiling[0] / packing, + dst_tiling[1]}; + std::array vi32_src_tiling = {src_tiling[0] / packing, + src_tiling[1]}; + if (src_tiling[0] > dst_tiling[0]) { + return retileToSmallTileWithScratch(ctx, builder, loc, dst_tiles, + vi32_dst_tiling, src_tiles, + vi32_src_tiling, ref); + } + if (src_tiling[0] < dst_tiling[0]) { + return retileToLargeTileWithScratch(ctx, builder, loc, dst_tiles, + vi32_dst_tiling, src_tiles, + vi32_src_tiling, ref); + } + dst_tiles = std::move(src_tiles); + return success(); } -// TODO(b/265133506): Generalize retiling. 
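Reviewer note: retileWithScratch only accepts tilings whose lane count matches the target and whose sublane counts are multiples of the packing, and it bitcasts everything to i32 vregs so packed types reuse the same two paths. The sketch below (plain C++, illustrative names, assuming an (8, 128) target vreg and no bank-conflict stride padding) mirrors the dispatch on the i32 tilings and the sublane-offset interleaving built by the large-tile path; it is a mimic of the index arithmetic, not of the emitted loads/stores.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main() {
      const int64_t target_sublanes = 8;
      const int64_t packing = 2;           // e.g. 16-bit data
      const int64_t src_tiling_rows = 8;   // (8, 128) source tiling
      const int64_t dst_tiling_rows = 16;  // (16, 128) destination tiling
      const int64_t vi32_src_rows = src_tiling_rows / packing;  // 4
      const int64_t vi32_dst_rows = dst_tiling_rows / packing;  // 8
      // src tiling smaller than dst tiling -> retileToLargeTileWithScratch.
      assert(vi32_src_rows < vi32_dst_rows);

      const int64_t vregs_per_group = vi32_dst_rows / vi32_src_rows;  // 2
      // sublane_offsets starts as an iota over shape
      // {target_sublanes / vi32_dst_rows, vi32_src_rows, vregs_per_group}
      // = {1, 4, 2}, then TransposeDimensions({0, 2, 1}) is applied.
      std::vector<int64_t> offsets;
      for (int64_t k = 0; k < vregs_per_group; ++k) {
        for (int64_t j = 0; j < vi32_src_rows; ++j) {
          offsets.push_back(j * vregs_per_group + k);  // iota value at (0, j, k)
        }
      }
      const std::vector<int64_t> expected = {0, 2, 4, 6, 1, 3, 5, 7};
      assert(offsets == expected);
      // With a store stride of 2, src vreg 0 occupies scratch rows 0,2,...,14
      // and src vreg 1 rows 1,3,...,15, so the shuffled load above hands the
      // first dst vreg sublanes 0-3 of src vreg 0 followed by sublanes 0-3 of
      // src vreg 1.
      return 0;
    }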
FailureOr>> changeTiling( - OpBuilder &builder, const std::array target_shape, - const Location loc, VectorType vty, const VectorLayout src, - xla::Array vregs, const std::array dst_tiling, - bool try_replicate_rows) { - if (src.tiling() == dst_tiling) { + RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, + const VectorLayout src, xla::Array vregs, + const std::array dst_tiling, bool try_replicate_rows) { + bool has_enough_scratch = ctx.max_sublanes_in_scratch >= + ctx.target_shape[0] * (ctx.target_shape[0] + 1); + const auto &target_shape = ctx.target_shape; + const std::array src_tiling = src.tiling(); + if (src_tiling == dst_tiling) { return std::pair(src, std::move(vregs)); } const int packing = src.packing(); @@ -5190,106 +5801,62 @@ FailureOr>> changeTiling( if (!dst.isValid(target_shape)) { return emitError(loc, "Not implemented: invalid offsets in tiling target"); } - // Handle retiling from (packing, 128) to (8 * packing, 128). - if (src.offsets() == LayoutOffsets{0, 0} && - src.tiling() == std::array{packing, 128} && - dst_tiling == std::array{8 * packing, 128}) { - bool replicate_sublanes = try_replicate_rows && packing == 1 && - *(vregs.dimensions().end() - 2) == 1; - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + auto dst_tiles_shape = + dst.tileArrayImplicitShape(vty.getShape(), target_shape); + // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating + // sublanes. + if (try_replicate_rows && packing == 1 && + *(vregs.dimensions().end() - 2) == 1 && + src.offsets() == LayoutOffsets{0, 0} && + src.tiling() == std::array{1, 128} && + dst_tiling == std::array{8, 128}) { + xla::Array retiled(dst_tiles_shape); retiled.Each([&](absl::Span idx, Value *tile) { SmallVector src_idx(idx.begin(), idx.end()); *(src_idx.end() - 2) *= target_shape[0]; *(src_idx.end() - 1) /= target_shape[0]; const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0]; - if (replicate_sublanes) { - CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1); - *tile = - broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); - } else { - for (int dst_sl_idx = 0; - dst_sl_idx < target_shape[0] && - *(src_idx.end() - 2) < *(vregs.dimensions().end() - 2); - ++dst_sl_idx, ++*(src_idx.end() - 2)) { - *tile = copy_one_sublane(builder, vregs(src_idx), src_sl_idx, *tile, - dst_sl_idx, target_shape); - } - } + CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1); + *tile = + broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape); }); // We have successfully replicated sublanes. - if (replicate_sublanes) { - dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling, - dst.implicit_dim()); - } - return std::pair(dst, std::move(retiled)); - } - // Handle retiling from (m, 128) to (8, 128) for 32-bit data - // where m < 8 and m is a power of 2. - // TODO(b/306692696): Handle any vregs.dimensions(). 
- if (bitwidth == 32 && src.offsets() == LayoutOffsets{0, 0} && - target_shape[0] % src.tiling()[0] == 0 && - src.tiling()[1] == target_shape[1] && dst.tiling() == target_shape && - *(vregs.dimensions().end() - 2) == 1) { - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - retiled.Each([&](const absl::Span idx, - Value *const new_src_tile) { - const int64_t tiles_per_vreg = src.tilesPerVreg(target_shape); - const int64_t dst_col = idx.back(); - const int64_t src_col = dst_col / tiles_per_vreg; - const int64_t start_slane_idx = - src.tiling()[0] * (dst_col % tiles_per_vreg); - SmallVector src_idx(toArrayRef(idx)); - src_idx.back() = src_col; - Value src_tile = vregs(src_idx); - if (start_slane_idx) { - SmallVector slane_idxs; - slane_idxs.reserve(target_shape[0]); - for (int i = 0; i < target_shape[0]; ++i) { - slane_idxs.push_back(start_slane_idx + (i % src.tiling()[0])); - } - const DenseI32ArrayAttr gather_indices = - builder.getDenseI32ArrayAttr(slane_idxs); - *new_src_tile = builder.create(loc, src_tile.getType(), - src_tile, gather_indices, - /*dimension=*/0); - } else { - *new_src_tile = src_tile; - } - }); + dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling, + dst.implicit_dim()); return std::pair(dst, std::move(retiled)); } // (8,128) -> (8 * packing,128) tiling change for packed type. if (bitwidth < 32 && 32 % bitwidth == 0 && - src.tiling() == std::array{8, 128} && - dst.tiling() == std::array{8 * dst.packing(), 128}) { - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - int vty_packing = dst.packing(); - VectorType vreg_x32 = - vty.getElementType().isSignlessInteger() - ? VectorType::get(target_shape, builder.getI32Type()) - : VectorType::get(target_shape, builder.getF32Type()); - retiled.Each([&](absl::Span idx, Value *tile) { - const int vreg_part = idx.back() % vty_packing; - SmallVector parts; - parts.reserve(vty_packing); - SmallVector src_idx(idx.begin(), idx.end()); - src_idx[src_idx.size() - 2] *= vty_packing; - src_idx[src_idx.size() - 1] /= vty_packing; - for (int i = 0; i < vty_packing; ++i) { - parts.push_back(builder.create( - loc, vreg_x32, vregs(src_idx), vreg_part)); - if (src_idx[src_idx.size() - 2] < - vregs.dim(vregs.num_dimensions() - 2) - 1) { - ++src_idx[src_idx.size() - 2]; + src_tiling == std::array{8, 128} && + dst_tiling == std::array{8 * dst.packing(), 128}) { + // Note: for int4, retiling with scratch is always faster. + if (bitwidth != 4 || !has_enough_scratch) { + xla::Array retiled(dst_tiles_shape); + int vty_packing = dst.packing(); + VectorType vreg_x32 = + vty.getElementType().isSignlessInteger() + ? 
VectorType::get(target_shape, builder.getI32Type()) + : VectorType::get(target_shape, builder.getF32Type()); + retiled.Each([&](absl::Span idx, Value *tile) { + const int vreg_part = idx.back() % vty_packing; + SmallVector parts; + parts.reserve(vty_packing); + SmallVector src_idx(idx.begin(), idx.end()); + src_idx[src_idx.size() - 2] *= vty_packing; + src_idx[src_idx.size() - 1] /= vty_packing; + for (int i = 0; i < vty_packing; ++i) { + parts.push_back(builder.create( + loc, vreg_x32, vregs(src_idx), vreg_part)); + if (src_idx[src_idx.size() - 2] < + vregs.dim(vregs.num_dimensions() - 2) - 1) { + ++src_idx[src_idx.size() - 2]; + } } - } - *tile = builder.create( - loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed); - }); - return std::pair(dst, std::move(retiled)); + *tile = builder.create( + loc, vregs.begin()->getType(), parts, tpu::PackFormat::kCompressed); + }); + return std::pair(dst, std::move(retiled)); + } } // Handle retiling from (1, 128 * packing) to (packing, 128) for // packed data. @@ -5303,8 +5870,8 @@ FailureOr>> changeTiling( // match corresponding elements without shifting. It's just that // the tiles are not adjacent (no contiguous vreg slice). if (bitwidth < 32 && 32 % bitwidth == 0 && - src.tiling() == std::array{1, 128 * packing} && - dst.tiling() == std::array{packing, 128}) { + src_tiling == std::array{1, 128 * packing} && + dst_tiling == std::array{packing, 128}) { // To illustrate, consider a 2 x 16 16-bit shape laid out in vregs of // 4 sublanes and 2 lanes (this is convenient for to keep the example small // yet non-trivial) with (1, 4) tiling. We will relayout to (2, 2) tiling. @@ -5345,8 +5912,7 @@ FailureOr>> changeTiling( // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before // moving to the next one. This is exactly an interleaving of the sublanes // of the vreg parts. - xla::Array retiled( - dst.tileArrayImplicitShape(vty.getShape(), target_shape)); + xla::Array retiled(dst_tiles_shape); const VectorType vreg_x32 = vty.getElementType().isSignlessInteger() ? VectorType::get(target_shape, builder.getI32Type()) @@ -5371,20 +5937,49 @@ FailureOr>> changeTiling( }); return std::pair(dst, std::move(retiled)); } - if (isSupportedReducedSublanesRetile(src, dst, target_shape)) { - return std::pair(dst, retileToReducedSublanes(builder, vty.getShape(), src, - vregs, dst, target_shape)); + if (src_tiling[1] == target_shape[1] && dst_tiling[1] == target_shape[1]) { + // TODO(b/368088671): When sublane tiling changes, we should be able to + // preserve some replications from the source layout. But we need to + // make sure they are implemented efficiently and well-tested. For now, we + // just simply use 0 for the replicated offset after retiling. + dst = VectorLayout( + bitwidth, {src.offsets()[0].value_or(0), src.offsets()[1].value_or(0)}, + dst_tiling, dst.implicit_dim()); + + // All clauses in the and expression are based on performance benchmarking. + bool use_alu = !has_enough_scratch || + (ctx.hardware_generation >= 5 && src_tiling[0] != packing && + dst_tiling[0] != packing); + + if (use_alu) { + if (src_tiling[0] > dst_tiling[0]) { + return std::pair( + dst, retileToReducedSublanes(builder, vty.getShape(), src, vregs, + dst, target_shape)); + } else if (!has_enough_scratch) { + // TODO(b/357538782): Implement retileToIncreasedSublanes with ALU ops. 
+ return emitError( + loc, + "Not implemented: retiling to increase sublane tiling with ALU"); + } + } + xla::Array retiled(dst_tiles_shape); + if (failed(retileWithScratch(ctx, builder, loc, retiled, dst_tiling, vregs, + src_tiling, packing))) { + return failure(); + } + return std::pair(dst, std::move(retiled)); } return emitError(loc, "Not implemented: Unsupported tiling change for ") - << vty << ": from " << src << " to tiling (" << dst_tiling[0] << ", " - << dst_tiling[1] << ")"; + << vty << ": from " << src << " to " << dst; } FailureOr>> changeImplicitDim( - OpBuilder &builder, const std::array target_shape, - const Location loc, VectorType vty, const VectorLayout src, - xla::Array vregs, const VectorLayout::ImplicitDim dst_implicit_dim, + RewriteContext &ctx, OpBuilder &builder, const Location loc, VectorType vty, + const VectorLayout src, xla::Array vregs, + const VectorLayout::ImplicitDim dst_implicit_dim, const LayoutOffsets dst_offset_hints) { + const auto &target_shape = ctx.target_shape; if (src.implicit_dim() == dst_implicit_dim) { return std::make_pair(src, std::move(vregs)); } @@ -5396,33 +5991,47 @@ FailureOr>> changeImplicitDim( src_candidate.tileArrayImplicitShape(vty.getShape(), target_shape)); return std::make_pair(src_candidate, vregs); } - // Remove second minor implicit dim, for values that have (8, 128) tiling. - // TODO(apaszke): We should allow replicated dst_offset_hints[0]. + // Remove second minor implicit dim, for values that have (m, 128) tiling (for + // m that is a power of 2). if (src.implicit_dim() == VectorLayout::ImplicitDim::kSecondMinor && dst_implicit_dim == VectorLayout::ImplicitDim::kNone && - src.bitwidth() == 32 && src.tiling() == std::array{8, 128} && - dst_offset_hints[0]) { + src.bitwidth() == 32 && src.tiling()[1] == target_shape[1] && + llvm::isPowerOf2_32(src.tiling()[0])) { + // We should never see a replicated offset here. We're removing the implicit + // dim so the only case when this can happen is when its size is 1 (or else + // we can't prove replication in the logical value). But in that case, the + // equivalentTo case above triggers and we never reach this branch. + CHECK(dst_offset_hints[0].has_value()); int64_t dst_sublane_offset = *dst_offset_hints[0]; VectorLayout dst(src.bitwidth(), {dst_sublane_offset, src.offsets()[1]}, src.tiling(), dst_implicit_dim); xla::Array new_vregs( dst.tileArrayImplicitShape(vty.getShape(), target_shape)); - new_vregs.Each([&](const absl::Span idx, - Value *tile) { + new_vregs.Each([&](const absl::Span idx, Value *tile) { const int64_t dst_2nd_minor_idx = idx.size() - 2; SmallVector src_idx(idx.begin(), idx.end()); src.insertImplicit(src_idx, 0); const int dst_sl_start = idx[dst_2nd_minor_idx] == 0 ? dst_sublane_offset : 0; - src_idx[dst_2nd_minor_idx] = target_shape[0] * idx[dst_2nd_minor_idx] + + // This could be optimized further to take offsets[1] into account. + // For example, extended offsets allow us to skip copies of low sublanes + // in tiles with idx.back() == 0. 
+ const int tiles_per_vreg = src.tilesPerVreg(target_shape); + const int sublanes_per_tile = src.sublanesPerTile(target_shape); + src_idx[dst_2nd_minor_idx] = src.tiling()[0] * idx[dst_2nd_minor_idx] + dst_sl_start - dst_sublane_offset; for (int dst_sl_idx = dst_sl_start; - dst_sl_idx < target_shape[0] && + dst_sl_idx < src.tiling()[0] && src_idx[dst_2nd_minor_idx] < vregs.dim(dst_2nd_minor_idx); ++dst_sl_idx, ++src_idx[dst_2nd_minor_idx]) { - *tile = copy_one_sublane(builder, vregs(src_idx), - src.offsets()[0].value_or(dst_sl_idx), *tile, - dst_sl_idx, target_shape); + // This could be optimized further by copying multiple sublanes at once. + for (int tile_idx = 0; tile_idx < tiles_per_vreg; ++tile_idx) { + int tile_off = tile_idx * sublanes_per_tile; + *tile = + copy_one_sublane(builder, vregs(src_idx), + tile_off + src.offsets()[0].value_or(dst_sl_idx), + *tile, tile_off + dst_sl_idx, target_shape); + } } }); return std::make_pair(dst, new_vregs); @@ -5523,21 +6132,21 @@ FailureOr> relayout(RewriteContext &ctx, FAILUREOR_ASSIGN_OR_RETURN( std::tie(src, src_tiles), - changeTiling(builder, ctx.target_shape, v.getLoc(), vty, src, - std::move(src_tiles), dst.tiling(), + changeTiling(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), + dst.tiling(), dst.offsets()[0] == std::nullopt && src.offsets()[0] != std::nullopt)); FAILUREOR_ASSIGN_OR_RETURN( std::tie(src, src_tiles), - changeImplicitDim(builder, ctx.target_shape, v.getLoc(), vty, src, + changeImplicitDim(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), dst.implicit_dim(), dst.offsets())); FAILUREOR_ASSIGN_OR_RETURN( std::tie(src, src_tiles), - changeOffsets(builder, ctx.target_shape, v.getLoc(), vty, src, - std::move(src_tiles), dst.offsets())); + changeOffsets(ctx, builder, v.getLoc(), vty, src, std::move(src_tiles), + dst.offsets())); CHECK_EQ(src, dst); // At this point we've should be done. return assemble(builder, vty, dst, std::move(src_tiles), target_shape, @@ -5598,7 +6207,8 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) { // TODO: b/342235360 - This check is temporary while we increase and test // support for offsets outside of the first tile. When support is more broad, // any op without support should check it within their own rule. - if (!isa(op)) { + if (!isa(op)) { for (const Layout &layout : layouts_in) { if (layout && layout->offsets()[1].has_value() && layout->offsets()[1].value() >= layout->tiling()[1]) { @@ -5668,6 +6278,7 @@ struct ApplyVectorLayoutPass mxu_contracting_size = ctx.mxu_shape[0]; mxu_noncontracting_size = ctx.mxu_shape[1]; max_sublanes_in_scratch = ctx.max_sublanes_in_scratch; + vmem_banks = ctx.vmem_banks; } void runOnOperation() override { // Fail if hardware_generation has not been set from the default value. 
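Reviewer note: relayout now threads the RewriteContext through three staged transformations, each responsible for one property of the layout: tiling, then implicit dimension, then offsets, followed by CHECK_EQ(src, dst) before assembling. A toy sketch of that staging is below; the types and field granularity are deliberately simplified (the real stages return FailureOr, operate on vreg arrays, and changeTiling may also adjust offsets, e.g. when replicating sublanes).

    #include <cassert>
    #include <string>

    // Toy layout with the three properties relayout adjusts in sequence.
    struct Layout {
      std::string tiling;
      std::string implicit_dim;
      std::string offsets;
      bool operator==(const Layout& o) const {
        return tiling == o.tiling && implicit_dim == o.implicit_dim &&
               offsets == o.offsets;
      }
    };

    // In this toy version each stage only rewrites its own field, mirroring
    // the order changeTiling -> changeImplicitDim -> changeOffsets.
    Layout ChangeTiling(Layout s, const Layout& d) { s.tiling = d.tiling; return s; }
    Layout ChangeImplicitDim(Layout s, const Layout& d) { s.implicit_dim = d.implicit_dim; return s; }
    Layout ChangeOffsets(Layout s, const Layout& d) { s.offsets = d.offsets; return s; }

    int main() {
      Layout src{"(1,128)", "none", "{0,0}"};
      const Layout dst{"(8,128)", "2nd minor", "{3,0}"};
      src = ChangeTiling(src, dst);
      src = ChangeImplicitDim(src, dst);
      src = ChangeOffsets(src, dst);
      assert(src == dst);  // mirrors CHECK_EQ(src, dst) before assemble()
      return 0;
    }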
@@ -5679,7 +6290,8 @@ struct ApplyVectorLayoutPass .hardware_generation = hardware_generation, .target_shape = {sublane_count, lane_count}, .mxu_shape = {mxu_contracting_size, mxu_noncontracting_size}, - .max_sublanes_in_scratch = max_sublanes_in_scratch}; + .max_sublanes_in_scratch = max_sublanes_in_scratch, + .vmem_banks = vmem_banks}; if (failed(applyLayoutFunc(ctx, getOperation()))) { signalPassFailure(); return; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc index 54c0776514df..e70e01dfbce7 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc @@ -203,8 +203,8 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation, return success(); } else if (element_type.isBF16()) { bool reduces_sublanes = false; - for (Attribute dim : op.getReductionDims()) { - if (cast(dim).getInt() == source_ty.getRank() - 2) { + for (int64_t dim : op.getReductionDims()) { + if (dim == source_ty.getRank() - 2) { reduces_sublanes = true; } } @@ -230,7 +230,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation, } auto new_op = builder.create( op.getLoc(), new_acc.getType(), op.getKindAttr(), new_source, new_acc, - op.getReductionDims()); + DenseI64ArrayAttr::get(builder.getContext(), op.getReductionDims())); auto new_result = builder.create(op.getLoc(), result_ty, new_op.getResult()); op.replaceAllUsesWith(new_result.getResult()); @@ -317,6 +317,21 @@ LogicalResult canonicalize_contraction(int hardware_generation, Operation &op) { return result; } +LogicalResult canonicalize_extract(int hardware_generation, Operation &raw_op) { + auto op = dyn_cast(raw_op); + Type result_ty = op.getResult().getType(); + if (!isa(result_ty)) { + bool is_supported = result_ty.isSignlessIntOrFloat() && + result_ty.getIntOrFloatBitWidth() == 32; + if (!is_supported) { + return op.emitOpError( + "Only 32-bit scalar vector.extracts supported. Cast your input to a " + "32-bit type first."); + } + } + return success(); +} + using canonicalize_rule_type = std::function; @@ -324,6 +339,7 @@ const llvm::StringMap &rules() { static auto rules = new llvm::StringMap{ {tpu::MatmulOp::getOperationName(), canonicalize_matmul}, {vector::ContractionOp::getOperationName(), canonicalize_contraction}, + {vector::ContractionOp::getOperationName(), canonicalize_extract}, {vector::MultiDimReductionOp::getOperationName(), canonicalize_multi_dim_reduction}}; return *rules; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc index 8f702432d397..2894b0797e7b 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc @@ -196,12 +196,16 @@ class VectorLayoutInferer { auto out_ty = dyn_cast(op.getType()); TPU_CHECK_OP(static_cast(in_ty) == static_cast(out_ty), "Input and output are not both vectors?"); - if (in_ty) { - TPU_CHECK_OP(in_ty.getElementTypeBitWidth() == 1, - "Only extending i1 is supported"); - } - if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { - return failure(); + auto in_bitwidth = in_ty ? 
in_ty.getElementTypeBitWidth() + : op.getIn().getType().getIntOrFloatBitWidth(); + if (in_bitwidth == 1) { + if (inferElementwise(&any_op, /*check_bitwidth=*/false).failed()) { + return failure(); + } + } else { + if (inferExt(&any_op).failed()) { + return failure(); + } } } else if (isa(any_op) || isa(any_op)) { Operation *op = &any_op; // For TPU_CHECK_OP macros, which use the `op` @@ -1277,11 +1281,7 @@ class VectorLayoutInferer { auto src_ty = op.getSourceVectorType(); auto dst_ty = dyn_cast(op.getDestType()); TPU_CHECK_OP(dst_ty, "only reductions with vector results supported"); - SmallVector dims; - dims.reserve(op.getReductionDims().size()); - for (Attribute dim_attr : op.getReductionDims()) { - dims.push_back(cast(dim_attr).getInt()); - } + llvm::ArrayRef dims = op.getReductionDims(); int64_t src_rank = src_ty.getRank(); auto acc_layout = getLayout(op.getAcc()); TPU_CHECK_OP(is_fully_replicated(acc_layout), @@ -1358,6 +1358,9 @@ class VectorLayoutInferer { // TODO(tlongeri): Be smarter about trying implicit dims. We should probably // only add them when folding dimensions, and remove them when unfolding. + // The ordering of candidate implicit dims is important! Inserting an + // implicit second minor can make a reshape possible, but also very + // inefficient. We should always prefer to try with None first. SmallVector candidate_implicit_dims; if (res_shape.size() >= 2) { candidate_implicit_dims.push_back(ImplicitDim::kNone); @@ -1386,21 +1389,28 @@ class VectorLayoutInferer { for (const ImplicitDim implicit_dim : candidate_implicit_dims) { const std::array res_tiled_ishape = VectorLayout::getImplicitTiledDims(implicit_dim, res_shape, 1); - // Sublane (un)folding. - if (src_tiled_ishape[1] == res_tiled_ishape[1] && - src_tiled_ishape[0] % vreg_slice[0] == 0 && - res_tiled_ishape[0] % vreg_slice[0] == 0) { - // TODO(b/343808585): We shouldn't force second minor offset to 0 when - // unfolding, it's still a no-op, but we need to add - // support in apply-vector-layout. - const LayoutOffsets offsets = {0, layout.offsets()[1]}; - setLayout(op, - VectorLayout(layout.bitwidth(), offsets, layout.tiling(), - layout.implicit_dim()), - VectorLayout(layout.bitwidth(), offsets, layout.tiling(), - implicit_dim)); - return success(); - } + // Sublane (un)folding. We attempt to reduce the sublane tiling, which + // might make this reshape a no-op. We use do-while to handle the packed + // 1D tilings that use 1 in the sublane dimension. + int64_t sublane_tiling = vreg_slice[0]; + do { + if (src_tiled_ishape[1] == res_tiled_ishape[1] && + src_tiled_ishape[0] % sublane_tiling == 0 && + res_tiled_ishape[0] % sublane_tiling == 0) { + std::array tiling = {sublane_tiling, target_shape_[1]}; + // TODO(b/343808585): We shouldn't force second minor offset to 0 when + // unfolding, it's still a no-op, but we need to + // add support in apply-vector-layout. + LayoutOffsets offsets = {0, layout.offsets()[1]}; + setLayout(op, + VectorLayout(layout.bitwidth(), offsets, tiling, + layout.implicit_dim()), + VectorLayout(layout.bitwidth(), offsets, tiling, + implicit_dim)); + return success(); + } + sublane_tiling /= 2; + } while (sublane_tiling >= layout.packing()); // Lane (un)folding. 
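Reviewer note: the do-while above walks candidate sublane tilings from the full vreg slice down to the packing, taking the first one that evenly divides the second-minor sizes of both shapes (equal minor sizes are assumed). A standalone sketch of that search in plain C++, with illustrative names and without the layout/offset bookkeeping of the real rule:

    #include <cassert>
    #include <cstdint>

    // Mimics the sublane-(un)folding candidate search: halve the sublane
    // tiling from vreg_slice[0] down to the packing until both second-minor
    // sizes are divisible by it. Returns 0 if no candidate works.
    int64_t PickSublaneTiling(int64_t vreg_slice_rows, int64_t packing,
                              int64_t src_rows, int64_t res_rows) {
      int64_t sublane_tiling = vreg_slice_rows;
      do {
        if (src_rows % sublane_tiling == 0 && res_rows % sublane_tiling == 0) {
          return sublane_tiling;
        }
        sublane_tiling /= 2;
      } while (sublane_tiling >= packing);
      return 0;  // fall through to the other reshape cases
    }

    int main() {
      // 32-bit data (packing 1), 8-sublane vregs: 4 rows vs. 12 rows cannot
      // use an (8, 128) tiling, but a (4, 128) tiling makes the reshape work.
      assert(PickSublaneTiling(/*vreg_slice_rows=*/8, /*packing=*/1,
                               /*src_rows=*/4, /*res_rows=*/12) == 4);
      // 16-bit data (packing 2): the search never drops below 2 sublanes.
      assert(PickSublaneTiling(8, 2, 3, 9) == 0);
      return 0;
    }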
if (src_tiled_ishape[1] != res_tiled_ishape[1] && src_tiled_ishape[1] % layout.tiling()[1] == 0 && @@ -1770,9 +1780,8 @@ class VectorLayoutInferer { if (auto reduce = dyn_cast(operand.getOwner())) { bool reduces_tiled_dims = false; - for (Attribute dim : reduce.getReductionDims()) { - if (cast(dim).getInt() >= - reduce.getSourceVectorType().getRank() - 2) { + for (int64_t dim : reduce.getReductionDims()) { + if (dim >= reduce.getSourceVectorType().getRank() - 2) { reduces_tiled_dims = true; break; } diff --git a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc index 4d5e62049098..7006c1c2402a 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/linalg_vectorization.cc @@ -477,6 +477,7 @@ struct LinalgVectorizationPass // contract ops will help to sustain the structure through various // transformations. vector::populateVectorReductionToContractPatterns(patterns); + vector::populateSinkVectorOpsPatterns(patterns); // Pull in patterns to canonicalize transfer ops. vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); diff --git a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc index 37112666f542..569038500067 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc @@ -78,6 +78,14 @@ LogicalResult specializeMemorySpace(TypedValue value, updateResultFrom(op, op.getInput().getType()); continue; } + if (auto op = dyn_cast(some_op)) { + updateResultFrom(op, op.getInput().getType()); + continue; + } + if (auto op = dyn_cast(some_op)) { + updateResultFrom(op, op.getInput().getType()); + continue; + } if (auto op = dyn_cast(some_op)) { updateResultFrom(op, op.getOperand().getType()); continue; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc index ac2389d6c238..3f6050f31dab 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/serde.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/serde.cc @@ -22,11 +22,13 @@ limitations under the License. #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/OperationSupport.h" #include "mlir/IR/Value.h" #include "mlir/IR/Visitors.h" #include "mlir/Pass/Pass.h" // IWYU pragma: keep #include "mlir/Support/LLVM.h" +#include "mlir/include/mlir/IR/BuiltinAttributes.h" #include "mlir/include/mlir/IR/OpDefinition.h" #include "mlir/include/mlir/IR/OperationSupport.h" #include "mlir/include/mlir/Support/LogicalResult.h" @@ -41,7 +43,7 @@ namespace { constexpr std::string_view kMangledDialect = "stable_mosaic."; constexpr StringRef kVersionAttrName = "stable_mosaic.version"; -constexpr int kVersion = 2; +constexpr int kVersion = 3; StringRef mangle(StringRef name, std::string* storage) { storage->clear(); @@ -100,10 +102,40 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) { return success(); } +LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) { + // Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr + // in version 3. 
+ if (version < 3) { + Attribute reduction_dims_attr = op->getAttr("reduction_dims"); + if (!reduction_dims_attr) { + return op->emitError("Missing reduction_dims attribute"); + } + ArrayAttr reduction_dims_array = dyn_cast(reduction_dims_attr); + if (!reduction_dims_array) { + return op->emitOpError("reduction_dims attribute is not an ArrayAttr"); + } + std::vector reduction_dims; + reduction_dims.reserve(reduction_dims_array.size()); + for (Attribute reduction_dim : reduction_dims_array) { + IntegerAttr reduction_dim_attr = dyn_cast(reduction_dim); + if (!reduction_dim_attr) { + return op->emitOpError( + "reduction_dims attribute contains a non-IntegerAttr"); + } + reduction_dims.push_back(reduction_dim_attr.getInt()); + } + op->setAttr("reduction_dims", + DenseI64ArrayAttr::get(op->getContext(), reduction_dims)); + } + return success(); +} + const llvm::StringMap& upgrade_rules() { static auto rules = new llvm::StringMap{ {EnqueueDMAOp::getOperationName(), enqueue_dma_rule}, {SemaphoreSignalOp::getOperationName(), semaphore_signal_rule}, + {vector::MultiDimReductionOp::getOperationName(), + vector_multi_dim_reduce_rule} }; return *rules; } @@ -134,8 +166,8 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { return; } if (version_attr.getInt() > kVersion) { - module->emitError("Unsupported Mosaic version: ") - << version_attr.getInt(); + module->emitError("Unsupported Mosaic version: expected <= ") + << kVersion << " but got " << version_attr.getInt(); signalPassFailure(); return; } @@ -189,4 +221,4 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase { } // namespace -} // namespace mlir::tpu +} // namespace mlir::tpu \ No newline at end of file diff --git a/jaxlib/mosaic/gpu/BUILD b/jaxlib/mosaic/gpu/BUILD index 20fcf2b4ce74..e5eaeb347137 100644 --- a/jaxlib/mosaic/gpu/BUILD +++ b/jaxlib/mosaic/gpu/BUILD @@ -17,7 +17,7 @@ load("//jaxlib:jax.bzl", "pybind_extension") package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:mosaic_gpu_users"], ) py_library( @@ -105,6 +105,12 @@ cc_library( deps = [ ":passes", "//jaxlib/cuda:cuda_vendor", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/synchronization", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithToLLVM", @@ -142,12 +148,6 @@ cc_library( "@llvm-project//mlir:VectorDialect", "@xla//xla/service:custom_call_status", "@xla//xla/service:custom_call_target_registry", - "@com_google_absl//absl/base:core_headers", - "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:check", - "@com_google_absl//absl/status", - "@com_google_absl//absl/status:statusor", - "@com_google_absl//absl/synchronization", ], alwayslink = True, ) @@ -168,11 +168,11 @@ pybind_extension( deps = [ "//jaxlib:kernel_nanobind_helpers", "//jaxlib/cuda:cuda_vendor", - "@xla//xla/service:custom_call_status", - "@xla//xla/tsl/cuda:cudart", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/synchronization", "@nanobind", + "@xla//xla/service:custom_call_status", + "@xla//xla/tsl/cuda:cudart", ], ) @@ -192,7 +192,7 @@ cc_binary( "notap", ], deps = [ - "@xla//xla/tsl/cuda:cudart", "@local_config_cuda//cuda:cuda_headers", + "@xla//xla/tsl/cuda:cudart", ], ) diff --git a/jaxlib/mosaic/gpu/custom_call.cc 
b/jaxlib/mosaic/gpu/custom_call.cc index 47ad893eaa05..103f9f78c32f 100644 --- a/jaxlib/mosaic/gpu/custom_call.cc +++ b/jaxlib/mosaic/gpu/custom_call.cc @@ -18,7 +18,9 @@ limitations under the License. #include #include +#include #include +#include #include #include #include @@ -335,8 +337,10 @@ absl::StatusOr> Compile( } // Create a transformer to run all LLVM optimization passes at the // specified optimization level. + auto transformer = mlir::makeOptimizingTransformer( + /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); mlir::ExecutionEngineOptions options; - options.transformer = mlir::makeOptimizingTransformer(3, 0, nullptr); + options.transformer = transformer; options.jitCodeGenOptLevel = llvm::CodeGenOptLevel::Aggressive; options.sharedLibPaths = runtime_lib; auto maybe_execution_engine = mlir::ExecutionEngine::create(module, options); @@ -349,36 +353,65 @@ absl::StatusOr> Compile( class CompiledKernel { public: CompiledKernel(std::unique_ptr engine, void* ctx, - void* scratch_addr, MosaicHostFunc* host_launch) - : engine_(std::move(engine)), - ctx_(ctx), - scratch_addr_(scratch_addr), - host_launch_(host_launch) {} - - std::tuple GetHostLaunch() { - return std::make_tuple(ctx_, scratch_addr_, host_launch_); + MosaicHostFunc* host_launch) + : engine_(std::move(engine)), ctx_(ctx), host_launch_(host_launch) {} + + std::tuple GetHostLaunch() { + return std::make_tuple(ctx_, host_launch_); } private: std::unique_ptr engine_; void* ctx_; // TODO(apaszke): Destroy this properly - void* scratch_addr_; MosaicHostFunc* host_launch_; }; -std::pair*, absl::Mutex*> +using KernelHash = std::array; +using CacheKey = std::pair; + +std::pair*, absl::Mutex*> GetKernelCache() { static absl::Mutex mutex; static auto& context_cache = - *new absl::flat_hash_map; + *new absl::flat_hash_map; return std::make_pair(&context_cache, &mutex); } + +absl::StatusOr CompileAndInit(const char* module) { + mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); + InitContext(&context); + mlir::ParserConfig parse_config(&context); + auto module_op = + mlir::parseSourceString(module, parse_config); + if (!module_op) { + return absl::InternalError("Failed to parse module"); + } + auto maybe_engine = Compile(*module_op); + if (!maybe_engine.ok()) { + return maybe_engine.status(); + } + mlir::ExecutionEngine* execution_engine = maybe_engine->get(); + auto main = execution_engine->lookupPacked("_mlir_ciface_main"); + auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); + if (!init || !main) { + return absl::InternalError("Failed to retrieve kernel function"); + } + void* module_ptr = nullptr; + void* kernel_ptr = nullptr; + void** module_ptr_ptr = &module_ptr; + void** kernel_ptr_ptr = &kernel_ptr; + void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; + reinterpret_cast(*init)(init_args); + return CompiledKernel(std::move(*maybe_engine), kernel_ptr, + reinterpret_cast(*main)); +} + // Each compiled kernel has a unique init func, and each kernel is used from // a single HLO module. So it should be safe to not include the CUDA context // in the key. -absl::StatusOr> CompileAndInit( - uint64_t kernel_id, const char* module) { +absl::StatusOr> CachedCompileAndInit( + CacheKey key, const char* module) { auto cache_and_mutex = GetKernelCache(); auto* cache = cache_and_mutex.first; auto* mutex = cache_and_mutex.second; @@ -386,66 +419,78 @@ absl::StatusOr> CompileAndInit( { // Fast path uses reader lock (as hash map look-up is relatively slow). 
absl::ReaderMutexLock lock(mutex); - auto it = cache->find(kernel_id); + auto it = cache->find(key); if (ABSL_PREDICT_TRUE(it != cache->end())) return it->second.GetHostLaunch(); } absl::MutexLock lock(mutex); // We released the reader lock, another thread might have initialized it. - if (cache->find(kernel_id) == cache->end()) { - mlir::MLIRContext context(mlir::MLIRContext::Threading::DISABLED); - InitContext(&context); - mlir::ParserConfig parse_config(&context); - auto module_op = - mlir::parseSourceString(module, parse_config); - if (!module_op) { - return absl::InternalError("Failed to parse module"); - } - auto maybe_engine = Compile(*module_op); - if (!maybe_engine.ok()) { - return maybe_engine.status(); + if (cache->find(key) == cache->end()) { + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return compiled.status(); } - mlir::ExecutionEngine* execution_engine = maybe_engine->get(); - auto main = execution_engine->lookupPacked("_mlir_ciface_main"); - auto init = execution_engine->lookupPacked("_mlir_ciface_main_init"); - if (!init || !main) { - return absl::InternalError("Failed to retrieve kernel function"); - } - void* module_ptr = nullptr; - void* kernel_ptr = nullptr; - void** module_ptr_ptr = &module_ptr; - void** kernel_ptr_ptr = &kernel_ptr; - void*** init_args[2] = {&module_ptr_ptr, &kernel_ptr_ptr}; - reinterpret_cast(*init)(init_args); - CUmodule module = static_cast(module_ptr); - CUdeviceptr scratch_addr; - cuModuleGetGlobal(&scratch_addr, nullptr, module, "global_scratch"); - cache->insert_or_assign( - kernel_id, - CompiledKernel(std::move(*maybe_engine), kernel_ptr, - reinterpret_cast(scratch_addr), - reinterpret_cast(*main))); + cache->insert_or_assign(key, std::move(*compiled)); } - return cache->at(kernel_id).GetHostLaunch(); + return cache->at(key).GetHostLaunch(); } void MosaicGPUCustomCall(void* stream, void** buffers, char* opaque, size_t opaque_len, XlaCustomCallStatus* status) { - uint64_t kernel_id = *reinterpret_cast(opaque); - auto ctx_and_kernel = CompileAndInit(kernel_id, opaque + sizeof(uint64_t)); + if (reinterpret_cast(opaque) % alignof(KernelHash)) { + fprintf(stderr, "Misaligned opaque pointer\n"); + abort(); + } + auto hash = *reinterpret_cast(opaque); + CUcontext ctx; + if (cuCtxGetCurrent(&ctx) != CUDA_SUCCESS) { + fprintf(stderr, "Failed to get current CUDA context\n"); + abort(); + } + CacheKey key(hash, reinterpret_cast(ctx)); + auto ctx_and_kernel = CachedCompileAndInit(key, opaque + sizeof(KernelHash)); if (!ctx_and_kernel.ok()) { XlaCustomCallStatusSetFailure(status, ctx_and_kernel.status().message().data(), ctx_and_kernel.status().message().size()); return; } - void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers, - &std::get<1>(*ctx_and_kernel)}; - std::get<2>(*ctx_and_kernel)(args); + void* args[4] = {&std::get<0>(*ctx_and_kernel), &stream, &buffers}; + std::get<1>(*ctx_and_kernel)(args); } XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("mosaic_gpu", &MosaicGPUCustomCall, "CUDA"); } // namespace + +extern "C" { + +__attribute__((visibility("default"))) +void** MosaicGpuCompile(const char* module) { + auto compiled = CompileAndInit(module); + if (!compiled.ok()) { + return nullptr; + } + auto [ctx, launch] = compiled->GetHostLaunch(); + auto tuple_ptr = std::unique_ptr(new void*[3]); + if (!tuple_ptr) { + return nullptr; + } + tuple_ptr.get()[0] = ctx; + tuple_ptr.get()[1] = reinterpret_cast(launch); + tuple_ptr.get()[2] = new CompiledKernel(std::move(*compiled)); + if (!tuple_ptr.get()[2]) { + return nullptr; 
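  // (Slot layout handed back to the caller: [0] the kernel context, [1] the
  // host launch function, [2] the owning CompiledKernel, which
  // MosaicGpuUnload later deletes; releasing the array below transfers
  // ownership of all three slots to the caller.)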
+ } + return tuple_ptr.release(); +} + +__attribute__((visibility("default"))) +void MosaicGpuUnload(void** tuple_ptr) { + delete reinterpret_cast(tuple_ptr[2]); + delete[] tuple_ptr; +} + +} // extern "C" diff --git a/jaxlib/mosaic/gpu/runtime.cc b/jaxlib/mosaic/gpu/runtime.cc index 4acb9c3dbf83..82659e45bef1 100644 --- a/jaxlib/mosaic/gpu/runtime.cc +++ b/jaxlib/mosaic/gpu/runtime.cc @@ -88,6 +88,13 @@ void mosaic_gpu_init_tma_desc(CUtensorMap *tma_desc, void *base_addr, tma_window_shape_i, rank - i - 1); abort(); } + if (i == 0 && (tma_window_shape_i * elem_bytewidth) % 16 != 0) { + fprintf(stderr, + "The last dimension of window shape must have a bytewidth " + "divisible by 16, but got %d*%ld at index %ld\n", + tma_window_shape_i, elem_bytewidth, rank - i - 1); + abort(); + } tma_window_shape[i] = tma_window_shape_i; } cuuint32_t element_strides[5] = {1, 1, 1, 1, 1}; diff --git a/jaxlib/mosaic/python/BUILD b/jaxlib/mosaic/python/BUILD index 639e61a89062..48268bfcf30a 100644 --- a/jaxlib/mosaic/python/BUILD +++ b/jaxlib/mosaic/python/BUILD @@ -14,8 +14,8 @@ # Mosaic Python bindings -load("@rules_python//python:defs.bzl", "py_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("@rules_python//python:defs.bzl", "py_library") gentbl_filegroup( name = "tpu_python_gen_raw", diff --git a/jaxlib/rocm/BUILD.bazel b/jaxlib/rocm/BUILD similarity index 92% rename from jaxlib/rocm/BUILD.bazel rename to jaxlib/rocm/BUILD index 1ec36fd30c8e..5987415224c7 100644 --- a/jaxlib/rocm/BUILD.bazel +++ b/jaxlib/rocm/BUILD @@ -14,6 +14,7 @@ # AMD HIP kernels +load("@rules_python//python:defs.bzl", "py_library") load( "//jaxlib:jax.bzl", "if_rocm_is_configured", @@ -23,7 +24,10 @@ load( licenses(["notice"]) -package(default_visibility = ["//:__subpackages__"]) +package( + default_applicable_licenses = [], + default_visibility = ["//:__subpackages__"], +) cc_library( name = "hip_vendor", @@ -58,6 +62,16 @@ cc_library( ]), ) +rocm_library( + name = "hip_make_batch_pointers", + srcs = ["//jaxlib/gpu:make_batch_pointers.cu.cc"], + hdrs = ["//jaxlib/gpu:make_batch_pointers.h"], + deps = [ + ":hip_vendor", + "@local_config_rocm//rocm:rocm_headers", + ], +) + cc_library( name = "hip_blas_handle_pool", srcs = ["//jaxlib/gpu:blas_handle_pool.cc"], @@ -80,6 +94,7 @@ cc_library( deps = [ ":hip_blas_handle_pool", ":hip_gpu_kernel_helpers", + ":hip_make_batch_pointers", ":hip_vendor", "//jaxlib:kernel_helpers", "@com_google_absl//absl/algorithm:container", @@ -98,23 +113,6 @@ cc_library( ], ) -cc_library( - name = "hipblas_kernels_ffi", - srcs = ["//jaxlib/gpu:blas_kernels_ffi.cc"], - hdrs = ["//jaxlib/gpu:blas_kernels_ffi.h"], - deps = [ - ":hip_blas_handle_pool", - ":hip_gpu_kernel_helpers", - ":hip_vendor", - "//jaxlib:ffi_helpers", - "@com_google_absl//absl/status", - "@local_config_rocm//rocm:hipblas", - "@local_config_rocm//rocm:rocm_headers", - "@xla//xla/ffi/api:ffi", - "@xla//xla/service:custom_call_status", - ], -) - pybind_extension( name = "_blas", srcs = ["//jaxlib/gpu:blas.cc"], @@ -127,7 +125,6 @@ pybind_extension( deps = [ ":hip_vendor", ":hipblas_kernels", - ":hipblas_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/strings:str_format", @@ -171,17 +168,37 @@ cc_library( ], ) +cc_library( + name = "hipsolver_interface", + srcs = ["//jaxlib/gpu:solver_interface.cc"], + hdrs = ["//jaxlib/gpu:solver_interface.h"], + deps = [ + ":hip_gpu_kernel_helpers", + ":hip_vendor", + "@com_google_absl//absl/status", + 
"@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hipblas", + "@local_config_rocm//rocm:hipsolver", + ], +) + cc_library( name = "hipsolver_kernels_ffi", srcs = ["//jaxlib/gpu:solver_kernels_ffi.cc"], hdrs = ["//jaxlib/gpu:solver_kernels_ffi.h"], deps = [ + ":hip_blas_handle_pool", ":hip_gpu_kernel_helpers", + ":hip_make_batch_pointers", ":hip_solver_handle_pool", ":hip_vendor", + ":hipsolver_interface", "//jaxlib:ffi_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", @@ -199,14 +216,16 @@ pybind_extension( features = ["-use_header_modules"], module_name = "_solver", deps = [ - ":hip_solver_handle_pool", ":hip_gpu_kernel_helpers", + ":hip_solver_handle_pool", ":hip_vendor", ":hipsolver_kernels", ":hipsolver_kernels_ffi", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@local_config_rocm//rocm:hipblas", "@local_config_rocm//rocm:hipsolver", "@local_config_rocm//rocm:rocm_headers", "@nanobind", @@ -225,6 +244,7 @@ cc_library( "//jaxlib:kernel_helpers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/synchronization", "@local_config_rocm//rocm:hipsparse", "@local_config_rocm//rocm:rocm_headers", @@ -245,6 +265,7 @@ pybind_extension( ":hip_gpu_kernel_helpers", ":hip_vendor", ":hipsparse_kernels", + "//jaxlib:absl_status_casters", "//jaxlib:kernel_nanobind_helpers", "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/base", @@ -273,13 +294,13 @@ cc_library( ":hip_vendor", "//jaxlib:ffi_helpers", "//jaxlib:kernel_helpers", - "@com_google_absl//absl/algorithm:container", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:c_api", "@xla//xla/ffi/api:ffi", + "@xla//xla/service:custom_call_status", ], ) @@ -290,7 +311,6 @@ rocm_library( deps = [ ":hip_gpu_kernel_helpers", ":hip_vendor", - "@local_config_rocm//rocm:rocm_headers", "@xla//xla/ffi/api:ffi", "@xla//xla/service:custom_call_status", ], @@ -380,14 +400,15 @@ cc_library( "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/cleanup", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", + "@tsl//tsl/platform:env", "@xla//xla/service:custom_call_status", - "@xla//xla/stream_executor/gpu:asm_compiler", "@xla//xla/tsl/util:env_var", ], ) diff --git a/jaxlib/rocm_plugin_extension.cc b/jaxlib/rocm_plugin_extension.cc index dde4e57a97cf..0100b37b22e9 100644 --- a/jaxlib/rocm_plugin_extension.cc +++ b/jaxlib/rocm_plugin_extension.cc @@ -35,8 +35,8 @@ namespace nb = nanobind; namespace xla { namespace { absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, - nb::capsule fn, int api_version, - XLA_FFI_Handler_Traits traits) { + nb::object fn, int api_version, + XLA_FFI_Handler_Traits traits) { if 
(c_api->extension_start == nullptr) { return Unimplemented("The plugin does not have extension."); } @@ -50,6 +50,8 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, if (next == nullptr) { return Unimplemented("The plugin does not have a custom call extension."); } + PJRT_Gpu_Register_Custom_Call* register_custom_call = + reinterpret_cast(next)->custom_call; if (traits != 0) { return Unimplemented("The plugin does not support custom call traits."); @@ -59,14 +61,73 @@ absl::Status RegisterCustomCallTarget(const PJRT_Api* c_api, nb::str fn_name, args.struct_size = PJRT_Gpu_Register_Custom_Call_Args_STRUCT_SIZE; args.function_name = fn_name.c_str(); args.function_name_size = nb::len(fn_name); + #if PJRT_API_GPU_EXTENSION_VERSION >= 1 args.api_version = api_version; #endif - args.custom_call_function = static_cast(fn.data()); - RETURN_STATUS_IF_PJRT_ERROR( - reinterpret_cast(next)->custom_call(&args), - c_api); + + auto as_capsule = [](nb::object obj) -> absl::StatusOr { + nb::capsule capsule; + if (!nb::try_cast(obj, capsule)) { + return absl::InvalidArgumentError( + "Custom call target registration requires handlers as PyCapsules"); + } + return capsule; + }; + +#if PJRT_API_GPU_EXTENSION_VERSION <= 1 + TF_ASSIGN_OR_RETURN(nb::capsule fn_execute, as_capsule(fn)); + args.custom_call_function = fn_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); return absl::OkStatus(); +#else + args.handler_instantiate = nullptr; + args.handler_prepare = nullptr; + args.handler_initialize = nullptr; + args.handler_execute = nullptr; + + // Register legacy custom call target (untyped void* API). + if (api_version == 0) { + TF_ASSIGN_OR_RETURN(nb::capsule capsule_execute, as_capsule(fn)); + args.handler_execute = capsule_execute.data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + // Register XLA FFI handler (typed API with explicit function signatures). + if (api_version == 1) { + auto capsule_execute = as_capsule(fn); + if (capsule_execute.ok()) { + args.handler_execute = capsule_execute->data(); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + nb::dict bundle; + if (nb::try_cast(fn, bundle)) { + auto handler = [&](const char* name) -> absl::StatusOr { + if (!bundle.contains(name)) return nullptr; + TF_ASSIGN_OR_RETURN(nb::capsule capsule, as_capsule(bundle[name])); + return capsule.data(); + }; + + TF_ASSIGN_OR_RETURN(args.handler_instantiate, handler("instantiate")); + TF_ASSIGN_OR_RETURN(args.handler_prepare, handler("prepare")); + TF_ASSIGN_OR_RETURN(args.handler_initialize, handler("initialize")); + TF_ASSIGN_OR_RETURN(args.handler_execute, handler("execute")); + RETURN_STATUS_IF_PJRT_ERROR(register_custom_call(&args), c_api); + return absl::OkStatus(); + } + + return absl::InvalidArgumentError( + "Unsupported custom call target type for api_version=1"); + } + + return absl::UnimplementedError(absl::StrFormat( + "API version %d is not supported by RegisterCustomCallTarget. 
" + "Supported versions are 0 and 1.", + api_version)); +#endif } nb::dict Registrations() { @@ -118,7 +179,7 @@ NB_MODULE(rocm_plugin_extension, m) { tsl::ImportNumpy(); m.def( "register_custom_call_target", - [](nb::capsule c_api, nb::str fn_name, nb::capsule fn, + [](nb::capsule c_api, nb::str fn_name, nb::object fn, nb::str xla_platform_name, int api_version, XLA_FFI_Handler_Traits traits) { xla::ThrowIfError(RegisterCustomCallTarget( @@ -139,11 +200,11 @@ NB_MODULE(rocm_plugin_extension, m) { void* data_ptr = reinterpret_cast(data_value); hipError_t result = hipPointerGetAttribute(static_cast(&device_ordinal), - HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, - reinterpret_cast(data_ptr)); + HIP_POINTER_ATTRIBUTE_DEVICE_ORDINAL, + reinterpret_cast(data_ptr)); if (result != hipSuccess) { - LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << data_ptr - << ". Error: " << ToString(result); + LOG(FATAL) << "Not able to get the device_ordinal for ptr: " + << data_ptr << ". Error: " << ToString(result); } return device_ordinal; }, diff --git a/jaxlib/setup.py b/jaxlib/setup.py index 215313f9bb3a..dea9503c7c00 100644 --- a/jaxlib/setup.py +++ b/jaxlib/setup.py @@ -66,7 +66,7 @@ def has_ext_modules(self): 'numpy>=1.24', 'ml_dtypes>=0.2.0', ], - url='https://github.com/google/jax', + url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ "Programming Language :: Python :: 3.10", diff --git a/jaxlib/tools/BUILD.bazel b/jaxlib/tools/BUILD.bazel index 8463cba08c5f..4553dc1e3ea8 100644 --- a/jaxlib/tools/BUILD.bazel +++ b/jaxlib/tools/BUILD.bazel @@ -16,7 +16,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") -load("//jaxlib:jax.bzl", "if_windows") +load("//jaxlib:jax.bzl", "if_windows", "jax_py_test") licenses(["notice"]) # Apache 2 @@ -52,7 +52,7 @@ py_binary( ], ) -py_test( +jax_py_test( name = "build_wheel_test", srcs = ["build_wheel_test.py"], data = [":build_wheel"], @@ -64,11 +64,12 @@ py_test( cc_binary( name = "pjrt_c_api_gpu_plugin.so", linkopts = [ - "-Wl,--version-script,$(location @xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds)", + "-Wl,--version-script,$(location :gpu_version_script.lds)", "-Wl,--no-undefined", ], linkshared = True, deps = [ + ":gpu_version_script.lds", "@xla//xla/pjrt/c:pjrt_c_api_gpu", "@xla//xla/pjrt/c:pjrt_c_api_gpu_version_script.lds", "@xla//xla/service:gpu_plugin", diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index 28d2806a7da9..ced0b76c344c 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -74,7 +74,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} """) diff --git a/jaxlib/tools/build_gpu_plugin_wheel.py b/jaxlib/tools/build_gpu_plugin_wheel.py index 73cb8a9e020d..0e2bba0c74d0 100644 --- a/jaxlib/tools/build_gpu_plugin_wheel.py +++ b/jaxlib/tools/build_gpu_plugin_wheel.py @@ -80,7 +80,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} python-tag=py3 """ ) diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 48aab847f3fb..3c40c2d11fb5 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -135,7 +135,7 @@ def verify_mac_libraries_dont_reference_chkstack(): We don't entirely know why this happens, but in some build environments we seem to target the wrong Mac OS version. 
- https://github.com/google/jax/issues/3867 + https://github.com/jax-ml/jax/issues/3867 This check makes sure we don't release wheels that have this dependency. """ @@ -164,7 +164,7 @@ def write_setup_cfg(sources_path, cpu): license_files = LICENSE.txt [bdist_wheel] -plat-name={tag} +plat_name={tag} """ ) @@ -351,6 +351,7 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsSparseTensor.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirSparseTensorPasses.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_tpu_ext.{pyext}", + f"__main__/jaxlib/mlir/_mlir_libs/_sdy.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_stablehlo.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/register_jax_dialects.{pyext}", f"__main__/jaxlib/mlir/_mlir_libs/_mlirDialectsGPU.{pyext}", diff --git a/jaxlib/tools/gpu_version_script.lds b/jaxlib/tools/gpu_version_script.lds new file mode 100644 index 000000000000..8e46b2c590b2 --- /dev/null +++ b/jaxlib/tools/gpu_version_script.lds @@ -0,0 +1,11 @@ +VERS_1.0 { + global: + extern "C" { + GetPjrtApi; + MosaicGpuCompile; + MosaicGpuUnload; + }; + + local: + *; +}; diff --git a/jaxlib/triton/BUILD b/jaxlib/triton/BUILD index 1d994209ffcc..99cddd9e6381 100644 --- a/jaxlib/triton/BUILD +++ b/jaxlib/triton/BUILD @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") load("@llvm-project//mlir:tblgen.bzl", "gentbl_filegroup") +load("//jaxlib:jax.bzl", "if_windows", "pytype_strict_library") licenses(["notice"]) package( default_applicable_licenses = [], - default_visibility = ["//:__subpackages__"], + default_visibility = ["//jax:internal"], ) pytype_strict_library( @@ -56,8 +56,8 @@ genrule( out=$(RULEDIR)/$${base//_raw/} echo '# pytype: skip-file' > $${out} && \ cat $${src} | - sed -e 's/^from \\.\\./from jaxlib.mlir\\./g' | - sed -e 's/^from \\./from jaxlib.mlir\\.dialects\\./g' >> $${out} + sed -e 's/^from \\.\\./from jaxlib\\.mlir\\./g' | + sed -e 's/^from \\./from jaxlib\\.mlir\\.dialects\\./g' >> $${out} done """, ) @@ -116,7 +116,7 @@ cc_library( hdrs = ["triton_dialect_capi.h"], deps = [ "@llvm-project//llvm:Support", - "@llvm-project//mlir:CAPIIR", + "@llvm-project//mlir:CAPIIRObjects", "@llvm-project//mlir:IR", "@triton//:TritonDialects", ], diff --git a/jaxlib/triton/dialect.py b/jaxlib/triton/dialect.py index 1bbb565b69b2..0e3fb4d982cb 100644 --- a/jaxlib/triton/dialect.py +++ b/jaxlib/triton/dialect.py @@ -21,9 +21,9 @@ from collections.abc import Sequence from jaxlib.mlir._mlir_libs._triton_ext import ( - PointerType, - infer_reduce_op_encoding, - register_dialect, + PointerType as PointerType, + register_dialect as register_dialect, + infer_reduce_op_encoding as _infer_reduce_op_encoding, ) from jaxlib.mlir import ir @@ -86,7 +86,7 @@ def _infer_reduce_op_return_types( if not shape: return_types.append(op_type.element_type) elif op_encoding := op_type.encoding: - encoding = infer_reduce_op_encoding(op_encoding, axis) + encoding = _infer_reduce_op_encoding(op_encoding, axis) if encoding is not None: raise RuntimeError("Failed to infer return type encoding for ReduceOp") return_types.append( diff --git a/pyproject.toml b/pyproject.toml index bc424a13e14b..3423783e2407 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ module = [ "tensorstore.*", "web_pdb.*", "zstandard.*", + "kubernetes.*" ] ignore_missing_imports = true @@ -53,25 +54,23 @@ markers = [ ] 
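Aside on the MosaicGpuCompile/MosaicGpuUnload exports and the new gpu_version_script.lds above: the version script keeps only GetPjrtApi and the two Mosaic entry points visible in the GPU plugin shared object. The sketch below shows how an external caller could drive those two symbols through ctypes; it is illustrative only — the library filename and the MLIR payload are placeholders, and the Python-side integration is not part of this diff.

import ctypes

# Hypothetical library path; the real name and location come from the jaxlib build.
plugin = ctypes.CDLL("./pjrt_c_api_gpu_plugin.so")

plugin.MosaicGpuCompile.argtypes = [ctypes.c_char_p]
plugin.MosaicGpuCompile.restype = ctypes.POINTER(ctypes.c_void_p)
plugin.MosaicGpuUnload.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
plugin.MosaicGpuUnload.restype = None

mlir_module = b"module { ... }"  # placeholder for the serialized Mosaic GPU kernel
handle = plugin.MosaicGpuCompile(mlir_module)
if not handle:  # MosaicGpuCompile returns NULL on any compilation failure.
    raise RuntimeError("Mosaic GPU compilation failed")

# handle[0]: kernel context, handle[1]: host launch function pointer,
# handle[2]: opaque owner that MosaicGpuUnload frees.
plugin.MosaicGpuUnload(handle)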
filterwarnings = [ "error", - "default:Error (reading|writing) persistent compilation cache entry for 'jit_equal'", - "default:Error (reading|writing) persistent compilation cache entry for 'jit__lambda_'", - "default:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning", # TODO(jakevdp): remove when array_api_tests stabilize "default:.*not machine-readable.*:UserWarning", "default:Special cases found for .* but none were parsed.*:UserWarning", - "default:.*is not JSON-serializable. Using the repr instead.", + "default:.*is not JSON-serializable. Using the repr instead.*:UserWarning", + "default:The .* method is good for exploring strategies.*", - # These are transitive warnings coming from TensorFlow dependencies. - # TODO(slebedev): Remove once we bump the minimum TensorFlow version. - "default:The key path API is deprecated .*", - "default:jax.xla_computation is deprecated.*:DeprecationWarning", + # NOTE: this is probably not where you want to add code to suppress a + # warning. Only pytest tests look at this list, whereas Bazel tests also + # check for warnings and do not check this list. Most likely, you should + # add a @jtu.ignore_warning decorator to your test instead. ] doctest_optionflags = [ "NUMBER", "NORMALIZE_WHITESPACE" ] -addopts = "--doctest-glob='*.rst'" +addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'" [tool.pylint.master] extension-pkg-whitelist = "numpy" @@ -114,6 +113,8 @@ ignore = [ "C408", # Unnecessary map usage "C417", + # Unnecessary dict comprehension for iterable + "C420", # Object names too complex "C901", # Local variable is assigned to but never used @@ -140,7 +141,16 @@ max-complexity = 18 [tool.ruff.lint.per-file-ignores] # F811: Redefinition of unused name. +# F821: Undefined name. "docs/autodidax.py" = ["F811"] +"docs/pallas/tpu/matmul.ipynb" = ["F811"] +"docs/pallas/tpu/distributed.ipynb" = ["F811"] +"docs/pallas/quickstart.ipynb" = ["F811"] +"docs/notebooks/autodiff_cookbook.ipynb" = ["F811", "F821"] +"docs/notebooks/autodiff_remat.ipynb" = ["F811", "F821"] +"docs/notebooks/Custom_derivative_rules_for_Python_code.ipynb" = ["F811"] +"docs/jep/9407-type-promotion.ipynb" = ["F811"] +"docs/autodidax.ipynb" = ["F811"] # Note: we don't use jax/*.py because this matches contents of jax/_src "__init__.py" = ["F401"] "jax/abstract_arrays.py" = ["F401"] diff --git a/setup.py b/setup.py index 08ce8dbcb4ed..762b5ad7a281 100644 --- a/setup.py +++ b/setup.py @@ -19,10 +19,10 @@ project_name = 'jax' -_current_jaxlib_version = '0.4.31' +_current_jaxlib_version = '0.4.33' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.31' -_libtpu_version = '0.1.dev20240729' +_latest_jaxlib_version_on_pypi = '0.4.33' +_libtpu_version = '0.1.dev20240916' def load_version_module(pkg_path): spec = importlib.util.spec_from_file_location( @@ -103,8 +103,13 @@ def load_version_module(pkg_path): f"jaxlib=={_current_jaxlib_version}", f"jax-cuda12-plugin=={_current_jaxlib_version}", ], + + # For automatic bootstrapping distributed jobs in Kubernetes + 'k8s': [ + 'kubernetes', + ], }, - url='https://github.com/google/jax', + url='https://github.com/jax-ml/jax', license='Apache-2.0', classifiers=[ "Programming Language :: Python :: 3.10", diff --git a/tests/BUILD b/tests/BUILD index adea49cac293..df9a28236e6a 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
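The filterwarnings note above steers test authors toward per-test suppression rather than the global pytest list, since only pytest consults pyproject.toml while Bazel-run tests check warnings independently. A minimal sketch of that pattern, using the jtu.ignore_warning decorator that already appears elsewhere in this diff (the test name and warning text are illustrative):

import warnings

from absl.testing import absltest
from jax._src import test_util as jtu


class ExampleWarningTest(jtu.JaxTestCase):

  @jtu.ignore_warning(category=UserWarning,
                      message="some noisy library warning.*")
  def test_noisy_path(self):
    # The warning raised here is swallowed by the decorator instead of being
    # promoted to an error by the global "error" filter.
    warnings.warn("some noisy library warning", UserWarning)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())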
-load("@rules_python//python:defs.bzl", "py_test") load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", + "jax_py_test", "jax_test_file_visibility", "py_deps", "pytype_test", @@ -31,30 +31,29 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "api_test", srcs = ["api_test.py"], shard_count = 10, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "device_test", srcs = ["device_test.py"], ) -jax_test( +jax_multiplatform_test( name = "dynamic_api_test", srcs = ["dynamic_api_test.py"], shard_count = 2, ) -jax_test( +jax_multiplatform_test( name = "api_util_test", srcs = ["api_util_test.py"], ) -py_test( +jax_py_test( name = "array_api_test", srcs = ["array_api_test.py"], deps = [ @@ -64,15 +63,18 @@ py_test( ] + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "array_interoperability_test", srcs = ["array_interoperability_test.py"], - disable_backends = ["tpu"], + enable_backends = [ + "cpu", + "gpu", + ], tags = ["multiaccelerator"], deps = py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "batching_test", srcs = ["batching_test.py"], shard_count = { @@ -80,12 +82,12 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "config_test", srcs = ["config_test.py"], ) -jax_test( +jax_multiplatform_test( name = "core_test", srcs = ["core_test.py"], shard_count = { @@ -94,17 +96,17 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_object_test", srcs = ["custom_object_test.py"], ) -jax_test( +jax_multiplatform_test( name = "debug_nans_test", srcs = ["debug_nans_test.py"], ) -py_test( +jax_py_test( name = "multiprocess_gpu_test", srcs = ["multiprocess_gpu_test.py"], args = [ @@ -117,12 +119,12 @@ py_test( ] + py_deps("portpicker"), ) -jax_test( +jax_multiplatform_test( name = "dtypes_test", srcs = ["dtypes_test.py"], ) -jax_test( +jax_multiplatform_test( name = "errors_test", srcs = ["errors_test.py"], # No need to test all other configs. 
@@ -131,13 +133,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "extend_test", srcs = ["extend_test.py"], deps = ["//jax:extend"], ) -jax_test( +jax_multiplatform_test( name = "fft_test", srcs = ["fft_test.py"], backend_tags = { @@ -153,37 +155,31 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "generated_fun_test", srcs = ["generated_fun_test.py"], ) -jax_test( +jax_multiplatform_test( name = "gpu_memory_flags_test_no_preallocation", srcs = ["gpu_memory_flags_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "0", }, main = "gpu_memory_flags_test.py", ) -jax_test( +jax_multiplatform_test( name = "gpu_memory_flags_test", srcs = ["gpu_memory_flags_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = { "XLA_PYTHON_CLIENT_PREALLOCATE": "1", }, ) -jax_test( +jax_multiplatform_test( name = "lobpcg_test", srcs = ["lobpcg_test.py"], env = {"LOBPCG_EMIT_DEBUG_PLOTS": "1"}, @@ -197,7 +193,7 @@ jax_test( ] + py_deps("matplotlib"), ) -jax_test( +jax_multiplatform_test( name = "svd_test", srcs = ["svd_test.py"], shard_count = { @@ -207,7 +203,7 @@ jax_test( }, ) -py_test( +jax_py_test( name = "xla_interpreter_test", srcs = ["xla_interpreter_test.py"], deps = [ @@ -216,7 +212,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "memories_test", srcs = ["memories_test.py"], shard_count = { @@ -227,13 +223,19 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pjit_test", srcs = ["pjit_test.py"], backend_tags = { "tpu": ["notsan"], # Times out under tsan. "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, + enable_configs = [ + "cpu_shardy", + "gpu_2gpu_shardy", + "tpu_v3_2x2_shardy", + "tpu_v4_2x2_shardy", + ], shard_count = { "cpu": 5, "gpu": 5, @@ -245,7 +247,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "layout_test", srcs = ["layout_test.py"], backend_tags = { @@ -254,7 +256,7 @@ jax_test( tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "shard_alike_test", srcs = ["shard_alike_test.py"], deps = [ @@ -262,16 +264,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pgle_test", srcs = ["pgle_test.py"], backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], env = {"XLA_FLAGS": "--xla_dump_to=sponge --xla_gpu_enable_latency_hiding_scheduler=true"}, tags = [ "config-cuda-only", @@ -282,13 +281,10 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "mock_gpu_test", srcs = ["mock_gpu_test.py"], - disable_backends = [ - "cpu", - "tpu", - ], + enable_backends = ["gpu"], tags = [ "config-cuda-only", ], @@ -297,7 +293,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "array_test", srcs = ["array_test.py"], backend_tags = { @@ -310,7 +306,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "aot_test", srcs = ["aot_test.py"], tags = ["multiaccelerator"], @@ -319,7 +315,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "image_test", srcs = ["image_test.py"], shard_count = { @@ -331,23 +327,21 @@ jax_test( deps = py_deps("pil") + py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "infeed_test", srcs = ["infeed_test.py"], - tags = ["test_cpu_thunks"], deps = [ "//jax:experimental_host_callback", ], ) -jax_test( +jax_multiplatform_test( name = 
"jax_jit_test", srcs = ["jax_jit_test.py"], main = "jax_jit_test.py", - tags = ["test_cpu_thunks"], ) -py_test( +jax_py_test( name = "jax_to_ir_test", srcs = ["jax_to_ir_test.py"], deps = [ @@ -357,7 +351,7 @@ py_test( ] + py_deps("tensorflow_core"), ) -py_test( +jax_py_test( name = "jaxpr_util_test", srcs = ["jaxpr_util_test.py"], deps = [ @@ -367,7 +361,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "jet_test", srcs = ["jet_test.py"], shard_count = { @@ -380,7 +374,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "lax_control_flow_test", srcs = ["lax_control_flow_test.py"], shard_count = { @@ -390,7 +384,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], shard_count = { @@ -400,7 +394,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "custom_linear_solve_test", srcs = ["custom_linear_solve_test.py"], shard_count = { @@ -410,7 +404,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_test", srcs = ["lax_numpy_test.py"], backend_tags = { @@ -427,7 +421,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_operators_test", srcs = ["lax_numpy_operators_test.py"], shard_count = { @@ -435,10 +429,9 @@ jax_test( "gpu": 30, "tpu": 40, }, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_reducers_test", srcs = ["lax_numpy_reducers_test.py"], shard_count = { @@ -446,10 +439,9 @@ jax_test( "gpu": 20, "tpu": 20, }, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_indexing_test", srcs = ["lax_numpy_indexing_test.py"], shard_count = { @@ -459,7 +451,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_einsum_test", srcs = ["lax_numpy_einsum_test.py"], shard_count = { @@ -467,10 +459,9 @@ jax_test( "gpu": 10, "tpu": 10, }, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_ufuncs_test", srcs = ["lax_numpy_ufuncs_test.py"], shard_count = { @@ -478,16 +469,14 @@ jax_test( "gpu": 10, "tpu": 10, }, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "lax_numpy_vectorize_test", srcs = ["lax_numpy_vectorize_test.py"], - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_test", srcs = ["lax_scipy_test.py"], shard_count = { @@ -498,7 +487,7 @@ jax_test( deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_sparse_test", srcs = ["lax_scipy_sparse_test.py"], backend_tags = { @@ -511,7 +500,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_special_functions_test", srcs = ["lax_scipy_special_functions_test.py"], backend_tags = { @@ -525,7 +514,7 @@ jax_test( deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_scipy_spectral_dac_test", srcs = ["lax_scipy_spectral_dac_test.py"], shard_count = { @@ -538,7 +527,7 @@ jax_test( ] + py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_test", srcs = ["lax_test.py"], backend_tags = { @@ -549,21 +538,16 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], deps = [ "//jax:internal_test_util", "//jax:lax_reference", ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "lax_metal_test", srcs = ["lax_metal_test.py"], - disable_backends = [ - "cpu", - "gpu", - "tpu", - ], + enable_backends = ["metal"], 
tags = ["notap"], deps = [ "//jax:internal_test_util", @@ -571,7 +555,7 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "lax_autodiff_test", srcs = ["lax_autodiff_test.py"], shard_count = { @@ -579,10 +563,9 @@ jax_test( "gpu": 40, "tpu": 20, }, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "lax_vmap_test", srcs = ["lax_vmap_test.py"], shard_count = { @@ -590,11 +573,10 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) -jax_test( +jax_multiplatform_test( name = "lax_vmap_op_test", srcs = ["lax_vmap_op_test.py"], shard_count = { @@ -602,11 +584,10 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], deps = ["//jax:internal_test_util"] + py_deps("numpy") + py_deps("absl/testing"), ) -py_test( +jax_py_test( name = "lazy_loader_test", srcs = [ "lazy_loader_test.py", @@ -617,7 +598,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "deprecation_test", srcs = [ "deprecation_test.py", @@ -628,7 +609,7 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "linalg_test", srcs = ["linalg_test.py"], backend_tags = { @@ -645,24 +626,20 @@ jax_test( "gpu": 40, "tpu": 40, }, - tags = ["test_cpu_thunks"], ) -jax_test( +jax_multiplatform_test( name = "cholesky_update_test", srcs = ["cholesky_update_test.py"], ) -jax_test( +jax_multiplatform_test( name = "metadata_test", srcs = ["metadata_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) -py_test( +jax_py_test( name = "monitoring_test", srcs = ["monitoring_test.py"], deps = [ @@ -671,23 +648,28 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "multibackend_test", srcs = ["multibackend_test.py"], ) -jax_test( +jax_multiplatform_test( name = "multi_device_test", srcs = ["multi_device_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) -jax_test( +jax_multiplatform_test( name = "nn_test", srcs = ["nn_test.py"], + backend_tags = { + "gpu": [ + "noasan", # Times out under asan. + ], + "tpu": [ + "noasan", # Times out under asan. + ], + }, shard_count = { "cpu": 10, "tpu": 10, @@ -695,13 +677,13 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "optimizers_test", srcs = ["optimizers_test.py"], deps = ["//jax:optimizers"], ) -jax_test( +jax_multiplatform_test( name = "pickle_test", srcs = ["pickle_test.py"], deps = [ @@ -709,7 +691,7 @@ jax_test( ] + py_deps("cloudpickle") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pmap_test", srcs = ["pmap_test.py"], backend_tags = { @@ -729,14 +711,11 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "polynomial_test", srcs = ["polynomial_test.py"], # No implementation of nonsymmetric Eigendecomposition. 
- disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], shard_count = { "cpu": 10, }, @@ -749,28 +728,21 @@ jax_test( tags = ["nomsan"], ) -jax_test( +jax_multiplatform_test( name = "heap_profiler_test", srcs = ["heap_profiler_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) -jax_test( +jax_multiplatform_test( name = "profiler_test", srcs = ["profiler_test.py"], - disable_backends = [ - "gpu", - "tpu", - ], + enable_backends = ["cpu"], ) -jax_test( +jax_multiplatform_test( name = "pytorch_interoperability_test", srcs = ["pytorch_interoperability_test.py"], - disable_backends = ["tpu"], # The following cases are disabled because they time out in Google's CI, mostly because the # CUDA kernels in Torch take a very long time to compile. disable_configs = [ @@ -778,6 +750,10 @@ jax_test( "gpu_a100", # Pytorch A100 build times out in Google's CI. "gpu_h100", # Pytorch H100 build times out in Google's CI. ], + enable_backends = [ + "cpu", + "gpu", + ], tags = [ "not_build:arm", # TODO(b/355237462): Re-enable once MSAN issue is addressed. @@ -786,7 +762,7 @@ jax_test( deps = py_deps("torch"), ) -jax_test( +jax_multiplatform_test( name = "qdwh_test", srcs = ["qdwh_test.py"], backend_tags = { @@ -799,7 +775,7 @@ jax_test( shard_count = 10, ) -jax_test( +jax_multiplatform_test( name = "random_test", srcs = ["random_test.py"], backend_tags = { @@ -821,7 +797,7 @@ jax_test( tags = ["noasan"], # Times out ) -jax_test( +jax_multiplatform_test( name = "random_lax_test", srcs = ["random_lax_test.py"], backend_tags = { @@ -847,7 +823,7 @@ jax_test( ) # TODO(b/199564969): remove once we always enable_custom_prng -jax_test( +jax_multiplatform_test( name = "random_test_with_custom_prng", srcs = ["random_test.py"], args = ["--jax_enable_custom_prng=true"], @@ -872,7 +848,7 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "scipy_fft_test", srcs = ["scipy_fft_test.py"], backend_tags = { @@ -885,22 +861,22 @@ jax_test( shard_count = 4, ) -jax_test( +jax_multiplatform_test( name = "scipy_interpolate_test", srcs = ["scipy_interpolate_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_ndimage_test", srcs = ["scipy_ndimage_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_optimize_test", srcs = ["scipy_optimize_test.py"], ) -jax_test( +jax_multiplatform_test( name = "scipy_signal_test", srcs = ["scipy_signal_test.py"], backend_tags = { @@ -925,13 +901,13 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "scipy_spatial_test", srcs = ["scipy_spatial_test.py"], deps = py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "scipy_stats_test", srcs = ["scipy_stats_test.py"], backend_tags = { @@ -948,7 +924,7 @@ jax_test( ], # Times out ) -jax_test( +jax_multiplatform_test( name = "sparse_test", srcs = ["sparse_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -981,7 +957,7 @@ jax_test( ] + py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "sparse_bcoo_bcsr_test", srcs = ["sparse_bcoo_bcsr_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -997,6 +973,7 @@ jax_test( "cpu": ["--jax_num_generated_cases=40"], "cpu_x32": ["--jax_num_generated_cases=40"], "gpu": ["--jax_num_generated_cases=40"], + "tpu": ["--jax_num_generated_cases=40"], }, shard_count = { "cpu": 50, @@ -1014,19 +991,10 @@ jax_test( ] + py_deps("scipy"), ) -jax_test( +jax_multiplatform_test( name = "sparse_nm_test", srcs = ["sparse_nm_test.py"], - config_tags_overrides = { - "gpu_a100": { - "ondemand": False, # Include 
in presubmit. - }, - }, - disable_backends = [ - "cpu", - "gpu", - "tpu", - ], + enable_backends = [], enable_configs = [ "gpu_a100", "gpu_h100", @@ -1037,7 +1005,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "sparsify_test", srcs = ["sparsify_test.py"], args = ["--jax_bcoo_cusparse_lowering=true"], @@ -1061,21 +1029,21 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "stack_test", srcs = ["stack_test.py"], ) -jax_test( +jax_multiplatform_test( name = "checkify_test", srcs = ["checkify_test.py"], shard_count = { "gpu": 2, - "tpu": 2, + "tpu": 4, }, ) -jax_test( +jax_multiplatform_test( name = "stax_test", srcs = ["stax_test.py"], shard_count = { @@ -1085,18 +1053,18 @@ jax_test( deps = ["//jax:stax"], ) -jax_test( +jax_multiplatform_test( name = "linear_search_test", srcs = ["third_party/scipy/line_search_test.py"], main = "third_party/scipy/line_search_test.py", ) -jax_test( +jax_multiplatform_test( name = "blocked_sampler_test", srcs = ["blocked_sampler_test.py"], ) -py_test( +jax_py_test( name = "tree_util_test", srcs = ["tree_util_test.py"], deps = [ @@ -1114,7 +1082,7 @@ pytype_test( ], ) -py_test( +jax_py_test( name = "util_test", srcs = ["util_test.py"], deps = [ @@ -1123,7 +1091,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "version_test", srcs = ["version_test.py"], deps = [ @@ -1132,7 +1100,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "xla_bridge_test", srcs = ["xla_bridge_test.py"], data = ["testdata/example_pjrt_plugin_config.json"], @@ -1143,7 +1111,7 @@ py_test( ] + py_deps("absl/logging"), ) -py_test( +jax_py_test( name = "lru_cache_test", srcs = ["lru_cache_test.py"], deps = [ @@ -1153,17 +1121,16 @@ py_test( ] + py_deps("filelock"), ) -jax_test( +jax_multiplatform_test( name = "compilation_cache_test", srcs = ["compilation_cache_test.py"], - tags = ["test_cpu_thunks"], deps = [ "//jax:compilation_cache_internal", "//jax:compiler", ], ) -jax_test( +jax_multiplatform_test( name = "cache_key_test", srcs = ["cache_key_test.py"], deps = [ @@ -1172,7 +1139,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "ode_test", srcs = ["ode_test.py"], shard_count = { @@ -1181,13 +1148,16 @@ jax_test( deps = ["//jax:ode"], ) -jax_test( +jax_multiplatform_test( name = "host_callback_outfeed_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=true"], shard_count = { "tpu": 5, }, + tags = [ + "noasan", # Times out. + ], deps = [ "//jax:experimental", "//jax:experimental_host_callback", @@ -1195,7 +1165,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "host_callback_test", srcs = ["host_callback_test.py"], args = ["--jax_host_callback_outfeed=false"], @@ -1203,6 +1173,7 @@ jax_test( shard_count = { "gpu": 5, }, + tags = ["noasan"], # Times out deps = [ "//jax:experimental", "//jax:experimental_host_callback", @@ -1210,7 +1181,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "host_callback_to_tf_test", srcs = ["host_callback_to_tf_test.py"], tags = ["noasan"], # Linking TF causes a linker OOM. 
@@ -1220,12 +1191,12 @@ jax_test( ] + py_deps("tensorflow_core"), ) -jax_test( +jax_multiplatform_test( name = "key_reuse_test", srcs = ["key_reuse_test.py"], ) -jax_test( +jax_multiplatform_test( name = "x64_context_test", srcs = ["x64_context_test.py"], deps = [ @@ -1233,13 +1204,13 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "ann_test", srcs = ["ann_test.py"], shard_count = 10, ) -py_test( +jax_py_test( name = "mesh_utils_test", srcs = ["mesh_utils_test.py"], deps = [ @@ -1249,39 +1220,39 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "transfer_guard_test", srcs = ["transfer_guard_test.py"], ) -jax_test( +jax_multiplatform_test( name = "name_stack_test", srcs = ["name_stack_test.py"], ) -jax_test( +jax_multiplatform_test( name = "jaxpr_effects_test", srcs = ["jaxpr_effects_test.py"], backend_tags = { "gpu": ["noasan"], # Memory leaks in NCCL, see https://github.com/NVIDIA/nccl/pull/1143 }, enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], tags = ["multiaccelerator"], ) -jax_test( +jax_multiplatform_test( name = "debugging_primitives_test", srcs = ["debugging_primitives_test.py"], enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], ) -jax_test( +jax_multiplatform_test( name = "python_callback_test", srcs = ["python_callback_test.py"], backend_tags = { @@ -1293,16 +1264,16 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "debugger_test", srcs = ["debugger_test.py"], enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], ) -jax_test( +jax_multiplatform_test( name = "state_test", srcs = ["state_test.py"], # Use fewer cases to prevent timeouts. @@ -1313,7 +1284,7 @@ jax_test( "tpu_pjrt_c_api": ["--jax_num_generated_cases=1"], }, enable_configs = [ - "gpu", + "gpu_h100", "cpu", ], shard_count = { @@ -1324,12 +1295,12 @@ jax_test( deps = py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "mutable_array_test", srcs = ["mutable_array_test.py"], ) -jax_test( +jax_multiplatform_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { @@ -1339,9 +1310,15 @@ jax_test( }, ) -jax_test( +jax_multiplatform_test( name = "shard_map_test", srcs = ["shard_map_test.py"], + enable_configs = [ + "cpu_shardy", + "gpu_2gpu_shardy", + "tpu_v3_2x2_shardy", + "tpu_v4_2x2_shardy", + ], shard_count = { "cpu": 50, "gpu": 10, @@ -1359,12 +1336,12 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "clear_backends_test", srcs = ["clear_backends_test.py"], ) -jax_test( +jax_multiplatform_test( name = "attrs_test", srcs = ["attrs_test.py"], deps = [ @@ -1372,23 +1349,20 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "experimental_rnn_test", srcs = ["experimental_rnn_test.py"], - disable_backends = [ - "tpu", - "cpu", - ], disable_configs = [ "gpu_a100", # Numerical precision problems. 
], + enable_backends = ["gpu"], shard_count = 15, deps = [ "//jax:rnn", ], ) -py_test( +jax_py_test( name = "mosaic_test", srcs = ["mosaic_test.py"], deps = [ @@ -1398,7 +1372,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "source_info_test", srcs = ["source_info_test.py"], deps = [ @@ -1407,7 +1381,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "package_structure_test", srcs = ["package_structure_test.py"], deps = [ @@ -1416,16 +1390,16 @@ py_test( ], ) -jax_test( +jax_multiplatform_test( name = "logging_test", srcs = ["logging_test.py"], ) -jax_test( +jax_multiplatform_test( name = "export_test", srcs = ["export_test.py"], enable_configs = [ - "tpu_df_2x2", + "tpu_v3_2x2", ], tags = [], deps = [ @@ -1433,7 +1407,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "shape_poly_test", srcs = ["shape_poly_test.py"], disable_configs = [ @@ -1458,7 +1432,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_harnesses_multi_platform_test", srcs = ["export_harnesses_multi_platform_test.py"], disable_configs = [ @@ -1481,7 +1455,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "export_back_compat_test", srcs = ["export_back_compat_test.py"], tags = [], @@ -1491,20 +1465,23 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "fused_attention_stablehlo_test", srcs = ["fused_attention_stablehlo_test.py"], - disable_backends = [ - "tpu", - "cpu", - ], + enable_backends = ["gpu"], shard_count = { "gpu": 4, }, tags = ["multiaccelerator"], ) -py_test( +jax_multiplatform_test( + name = "xla_metadata_test", + srcs = ["xla_metadata_test.py"], + deps = ["//jax:experimental"], +) + +jax_py_test( name = "pretty_printer_test", srcs = ["pretty_printer_test.py"], deps = [ @@ -1513,7 +1490,7 @@ py_test( ], ) -py_test( +jax_py_test( name = "sourcemap_test", srcs = ["sourcemap_test.py"], deps = [ @@ -1522,6 +1499,20 @@ py_test( ], ) +jax_multiplatform_test( + name = "cudnn_fusion_test", + srcs = ["cudnn_fusion_test.py"], + enable_backends = ["gpu"], + enable_configs = [ + "gpu_a100", + "gpu_h100", + ], + tags = [ + "multiaccelerator", + "notap", # TODO(phawkins): this test fails in our internal CI. 
+ ], +) + exports_files( [ "api_test.py", @@ -1552,6 +1543,6 @@ filegroup( exclude = [], ) + ["BUILD"], visibility = [ - "//:__subpackages__", + "//jax:internal", ], ) diff --git a/tests/api_test.py b/tests/api_test.py index cb0d7c0d40c7..c73d5960f123 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -50,7 +50,6 @@ from jax._src import config from jax._src import core from jax._src import custom_derivatives -from jax._src import deprecations from jax._src import linear_util as lu from jax._src import test_util as jtu from jax._src import xla_bridge @@ -60,7 +59,6 @@ from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src.compilation_cache import is_persistent_cache_enabled -from jax._src.lib import xla_client from jax._src.lib import xla_extension import jax._src.util as jax_util from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint @@ -469,7 +467,7 @@ def test_jit_donate_weak_type(self, argnum_type, argnum_val): ("argnames", "donate_argnames", ('array',)), ) def test_jnp_array_copy(self, argnum_type, argnum_val): - # https://github.com/google/jax/issues/3412 + # https://github.com/jax-ml/jax/issues/3412 @partial(jit, **{argnum_type: argnum_val}) def _test(array): @@ -735,13 +733,12 @@ def test_jit_bad_input(self): def f(x): return x - with self.assertRaisesRegex( - TypeError, r".* 'foo' of type <.*'str'> is not a valid JAX type"): + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + "value is of type .* and was passed to the function at path x.") + with self.assertRaisesRegex(TypeError, err_str): jit(f)("foo") # Jax type objects aren't valid data arguments. - err_str = "JAX scalar type .*int32.* cannot be interpreted as a JAX array." 
- with self.assertRaisesRegex(TypeError, err_str): jit(f)(jnp.int32) @@ -907,7 +904,7 @@ def f(x): @jax.legacy_prng_key('allow') def test_omnistaging(self): - # See https://github.com/google/jax/issues/5206 + # See https://github.com/jax-ml/jax/issues/5206 # TODO(frostig): remove `wrap` once we always enable_custom_prng def wrap(arr): @@ -1411,7 +1408,7 @@ def f(d) -> float: f({E.A: 1.0, E.B: 2.0}) def test_jit_static_argnums_requires_type_equality(self): - # See: https://github.com/google/jax/pull/9311 + # See: https://github.com/jax-ml/jax/pull/9311 @partial(jit, static_argnums=(0,)) def f(k): assert python_should_be_executing @@ -1426,7 +1423,7 @@ def f(k): self.assertEqual(x, f(x)) def test_caches_depend_on_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 f = lambda: lax.psum(1, "i") g = jax.jit(f) expected = jax.vmap(f, axis_name="i", axis_size=2, out_axes=None)() @@ -1439,16 +1436,16 @@ def test_caches_depend_on_axis_env(self): self.assertEqual(ans, expected) def test_caches_dont_depend_on_unnamed_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 f = jax.jit(lambda: jnp.sin(1)) expected = f() - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = jax.vmap(f, axis_size=2, out_axes=None)() self.assertEqual(count[0], 0) # no compiles self.assertArraysAllClose(ans, expected, check_dtypes=True) def test_cache_key_defaults(self): - # https://github.com/google/jax/discussions/11875 + # https://github.com/jax-ml/jax/discussions/11875 f = jit(lambda x: (x ** 2).sum()) self.assertEqual(f._cache_size(), 0) x = jnp.arange(5.0) @@ -1457,12 +1454,26 @@ def test_cache_key_defaults(self): self.assertEqual(f._cache_size(), 1) def test_jit_nan_times_zero(self): - # https://github.com/google/jax/issues/4780 + # https://github.com/jax-ml/jax/issues/4780 def f(x): return 1 + x * 0 self.assertAllClose(f(np.nan), np.nan) self.assertAllClose(jit(f)(np.nan), np.nan) + def test_no_tracing(self): + @jax.jit + def f(x): + return x + + x = jnp.arange(3) + y = jnp.arange(4) + + _ = f(x) # no crash + + with self.assertRaisesRegex(RuntimeError, 'no_tracing'): + with jax.no_tracing(): + _ = f(y) # crash! + class APITest(jtu.JaxTestCase): @@ -1564,13 +1575,14 @@ def test_bad_input(self): def f(x): return x - self.assertRaisesRegex( - TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type", - lambda: grad(f)("foo")) + with self.assertRaisesRegex(TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type"): + grad(f)("foo") - self.assertRaisesRegex( - TypeError, ".* 'foo' of type <.*'str'> is not a valid JAX type", - lambda: jit(f)("foo")) + + err_str = ("Error interpreting argument to .* as an abstract array. 
The problematic " + "value is of type .* and was passed to the function at path x.") + with self.assertRaisesRegex(TypeError, err_str): + jit(f)("foo") def test_grad_tuple_output(self): jtu.check_raises(lambda: grad(lambda x: (x,x))(1.0), TypeError, @@ -2151,7 +2163,7 @@ def test_grad_and_aux_constant(self): self.assertEqual(aux, [4.**2, 4.]) def test_grad_and_aux_no_tracers(self): - # see https://github.com/google/jax/issues/1950 + # see https://github.com/jax-ml/jax/issues/1950 def f(x): aux = dict(identity=x, p1=x+1) return x ** 2, aux @@ -2310,7 +2322,7 @@ def test_linear_transpose_integer(self): self.assertEqual(actual, expected) def test_linear_transpose_dce(self): - # https://github.com/google/jax/issues/15660 + # https://github.com/jax-ml/jax/issues/15660 f = jit(lambda x: (2 * x, x > 0)) g = lambda x: f(x)[0] api.linear_transpose(g, 1.)(1.) @@ -2377,7 +2389,7 @@ def test_complex_output_jacrev_raises_error(self): self.assertRaises(TypeError, lambda: jacrev(lambda x: jnp.sin(x))(1 + 2j)) def test_nonholomorphic_jacrev(self): - # code based on https://github.com/google/jax/issues/603 + # code based on https://github.com/jax-ml/jax/issues/603 zs = 0.5j * np.arange(5) + np.arange(5) def f(z): @@ -2389,8 +2401,8 @@ def f(z): @jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion. def test_heterogeneous_jacfwd(self): - # See https://github.com/google/jax/issues/7157 - # See https://github.com/google/jax/issues/7780 + # See https://github.com/jax-ml/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7780 x = np.array([2.0], dtype=np.float16) y = np.array([3.0], dtype=np.float32) a = (x, y) @@ -2409,8 +2421,8 @@ def f(tup): @jax.numpy_dtype_promotion('standard') # Test explicitly exercises implicit dtype promotion. 
def test_heterogeneous_jacrev(self): - # See https://github.com/google/jax/issues/7157 - # See https://github.com/google/jax/issues/7780 + # See https://github.com/jax-ml/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7780 x = np.array([2.0], dtype=np.float16) y = np.array([3.0], dtype=np.float32) a = (x, y) @@ -2428,7 +2440,7 @@ def f(tup): jtu.check_eq(actual, desired) def test_heterogeneous_grad(self): - # See https://github.com/google/jax/issues/7157 + # See https://github.com/jax-ml/jax/issues/7157 x = np.array(1.0+1j) y = np.array(2.0) a = (x, y) @@ -2500,7 +2512,7 @@ def test_devicearray_weakref_friendly(self): self.assertIsNone(y()) def test_namedtuple_transparency(self): - # See https://github.com/google/jax/issues/446 + # See https://github.com/jax-ml/jax/issues/446 Point = collections.namedtuple("Point", ["x", "y"]) def f(pt): @@ -2516,7 +2528,7 @@ def f(pt): self.assertAllClose(f(pt), f_jit(pt), check_dtypes=False) def test_namedtuple_subclass_transparency(self): - # See https://github.com/google/jax/issues/806 + # See https://github.com/jax-ml/jax/issues/806 Point = collections.namedtuple("Point", ["x", "y"]) class ZeroPoint(Point): @@ -2693,7 +2705,7 @@ def __init__(self, shape, dtype): self.assertEqual(out_shape.shape, (3, 5)) def test_eval_shape_duck_typing2(self): - # https://github.com/google/jax/issues/5683 + # https://github.com/jax-ml/jax/issues/5683 class EasyDict(dict): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -2703,26 +2715,6 @@ def __init__(self, *args, **kwargs): out_shape = api.eval_shape(lambda x: x, x) # doesn't crash self.assertEqual(out_shape.shape, (3,)) - def test_eval_shape_names(self): - raise unittest.SkipTest("named shape are deprecated") - - def fun(x, y): - return lax.psum(x, 'i') + y - - class MyArgArray: - def __init__(self, shape, dtype, named_shape): - self.shape = shape - self.dtype = jnp.dtype(dtype) - self.named_shape = named_shape - - x = MyArgArray((3, 2), jnp.float32, {'i': 10}) - y = MyArgArray((3, 2), jnp.float32, {'j': 5}) - with core.extend_axis_env('i', 10, None): - with core.extend_axis_env('j', 5, None): - out_shape = api.eval_shape(fun, x, y) - - self.assertEqual(out_shape.named_shape, {'j': 5}) - def test_issue_871(self): T = jnp.array([[1., 2.], [3., 4.], [5., 6.]]) x = jnp.array([1, 2, 3]) @@ -2890,74 +2882,6 @@ def test_jacfwd_of_complex_errors(self): r"sub-dtype of np.floating\), but got complex.*"), lambda: dfn(3. + 1j)) - def test_xla_computation(self): - # these tests basically check the examples in the xla_computation docstring - - def e(x): - return jnp.sin(jnp.cos(x)) - c = api.xla_computation(e)(2.) - self.assertIn('cosine', c.as_hlo_text()) - self.assertIn('sine', c.as_hlo_text()) - - def f(x): - return x - lax.psum(x, 'i') - axis_env = [('i', 4)] - c = api.xla_computation(f, axis_env=axis_env)(2) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1,2,3}}', c.as_hlo_text()) - - def g(x): - rowsum = lax.psum(x, 'i') - colsum = lax.psum(x, 'j') - allsum = lax.psum(x, ('i', 'j')) - return rowsum, colsum, allsum - axis_env = [('i', 4), ('j', 2)] - c = api.xla_computation(g, axis_env=axis_env)(5.) 
- self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,2,4,6},{1,3,5,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1,2,3,4,5,6,7}}', c.as_hlo_text()) - - def h(x): - rowsum = lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]) - colsum = lax.psum(x, 'j') - return rowsum, colsum - axis_env = [('i', 4), ('j', 2)] - c = api.xla_computation(h, axis_env=axis_env)(5.) - self.assertIn('all-reduce', c.as_hlo_text()) - self.assertIn('replica_groups={{0,2},{4,6},{1,3},{5,7}}', c.as_hlo_text()) - self.assertIn('replica_groups={{0,1},{2,3},{4,5},{6,7}}', c.as_hlo_text()) - - def test_xla_computation_args(self): - def foo(x, y, z): - return x + y + z - - c = api.xla_computation(foo)(1., 2., 3.) - self.assertEqual(len(c.program_shape().parameter_shapes()), 3) - - c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.) - param_shapes = c.program_shape().parameter_shapes() - self.assertEqual(len(param_shapes), 1) - self.assertEqual(param_shapes[0].xla_element_type(), - xla_client.PrimitiveType.TUPLE) - - def test_xla_computation_duck_typing(self): - def foo(x, y, z): - return x + y + z - - x = jax.ShapeDtypeStruct((), np.float32) - y = jax.ShapeDtypeStruct((), np.float32) - z = jax.ShapeDtypeStruct((), np.float32) - - c = api.xla_computation(foo)(x, y, z) - self.assertEqual(len(c.program_shape().parameter_shapes()), 3) - - c = api.xla_computation(foo, tuple_args=True)(1., 2., 3.) - param_shapes = c.program_shape().parameter_shapes() - self.assertEqual(len(param_shapes), 1) - self.assertEqual(param_shapes[0].xla_element_type(), - xla_client.PrimitiveType.TUPLE) - def test_compiler_ir(self): # TODO(phawkins): merge these tests with the `xla_computation` tests. def e(x): @@ -2969,72 +2893,6 @@ def e(x): self.assertIn("stablehlo.cosine", stablehlo) self.assertIn("stablehlo.sine", stablehlo) - def test_staging_out_multi_replica(self): - def f(x): - return api.pmap(jnp.mean)(x) - xla_comp = api.xla_computation(f) - xla_comp(jnp.arange(8)).as_hlo_text() # doesn't crash - - def test_xla_computation_instantiate_constant_outputs(self): - def f(): - return jnp.zeros((3, 4)) - - xla_comp = api.xla_computation(f)() - out_shape, = xla_comp.program_shape().result_shape().tuple_shapes() - self.assertEqual(out_shape.dimensions(), (3, 4)) - - def test_xla_computation_static_argnums(self): - def f(x, y): - return x + y - - xla_comp = api.xla_computation(f, static_argnums=(1,))(2, 3) - hlo_text = xla_comp.as_hlo_text() - self.assertIn("constant(3)", hlo_text) - # The static arguments should be removed from the function being compiled, - # thus the function should have only a single argument. 
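A minimal sketch of the AOT lowering API that replaces the deleted jax.xla_computation calls in this region (the function body here is illustrative, not from the patch); the 'hlo' dialect argument matches the updated assertions later in the file:

import jax
import jax.numpy as jnp

def f(x, y):
    return x + y

lowered = jax.jit(f, static_argnums=(1,)).lower(2.0, 3)  # the static arg is baked into the lowering
print(lowered.as_text())       # StableHLO text (default dialect)
print(lowered.as_text('hlo'))  # HLO text, the form used by the updated tests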
- self.assertIn("parameter(0)", hlo_text) - self.assertNotIn("parameter(1)", hlo_text) - - def test_xla_computation_return_shape(self): - _, shape_tree = api.xla_computation(lambda x: (x + 1, jnp.zeros(2, jnp.float32)), - return_shape=True)(np.int32(1)) - expected = (api.ShapeDtypeStruct(shape=(), dtype=jnp.int32), - api.ShapeDtypeStruct(shape=(2,), dtype=jnp.float32)) - self.assertEqual(shape_tree, expected) - - def test_xla_computation_psum_constant(self): - f = lambda: jax.lax.psum(1, "i") - api.xla_computation(f, axis_env=[("i", 2)])() # doesn't crash - - @jtu.ignore_warning(message="Some donated buffers were not usable") - def test_xla_computation_donate_argnums(self): - api.xla_computation(lambda x: None, donate_argnums=(0,))(3) # doesn't crash - - def test_xla_computation_lower_fun_axis_env(self): - axis_name = 'i' - def fn(x): - y = lax.all_gather( - x, axis_name=axis_name) - return y * lax.axis_index(axis_name).astype(jnp.float32) - - input_x = jnp.ones((5,6,4), dtype=jnp.float32) - axis_env = [(axis_name, jax.local_device_count())] - _ = api.xla_computation(fn, axis_env=axis_env, backend='cpu')(input_x) - - @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated') - def test_xla_computation_axis_env(self): - is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation') - xla_computation = api.xla_computation if is_accelerated else jax.xla_computation - - def fn(x): - z = x * jax.lax.axis_index('i').astype(jnp.float32) - def inner_fn(carry, a): - return carry + a, () - return jax.lax.scan(inner_fn, jnp.zeros_like(z[0]), z) - - x = jnp.ones((5, 6, 4), dtype=jnp.float32) - _ = xla_computation(fn, axis_env=(('i', 8),), backend='cpu')(x) - def test_concurrent_device_get_and_put(self): def f(x): for _ in range(100): @@ -3101,8 +2959,10 @@ def check_warning(warn, nowarn): lambda: jnp.arange(1.0).astype(int)) def test_error_for_invalid_dtype(self): + err_str = ("Error interpreting argument to .* as an abstract array. The problematic " + r"value is of type .* and was passed to the function at path args\[1\].") with jax.enable_checks(False): - with self.assertRaisesRegex(TypeError, ".*not a valid JAX array type.*"): + with self.assertRaisesRegex(TypeError, err_str): lax.add(jnp.array(7), np.array("hello")) with jax.enable_checks(True): with self.assertRaises(AssertionError): @@ -3122,7 +2982,7 @@ def superfun(a): ])) def test_vmap_in_axes_list(self): - # https://github.com/google/jax/issues/2367 + # https://github.com/jax-ml/jax/issues/2367 dictionary = {'a': 5., 'b': jnp.ones(2)} x = jnp.zeros(3) y = jnp.arange(3.) 
@@ -3135,7 +2995,7 @@ def f(dct, x, y): self.assertAllClose(out1, out2) def test_vmap_in_axes_non_tuple_error(self): - # https://github.com/google/jax/issues/18548 + # https://github.com/jax-ml/jax/issues/18548 with self.assertRaisesRegex( TypeError, re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding " @@ -3143,7 +3003,7 @@ def test_vmap_in_axes_non_tuple_error(self): jax.vmap(lambda x: x['a'], in_axes={'a': 0}) def test_vmap_in_axes_wrong_length_tuple_error(self): - # https://github.com/google/jax/issues/18548 + # https://github.com/jax-ml/jax/issues/18548 with self.assertRaisesRegex( ValueError, re.escape("vmap in_axes must be an int, None, or a tuple of entries corresponding to the " @@ -3151,7 +3011,7 @@ def test_vmap_in_axes_wrong_length_tuple_error(self): jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))}) def test_vmap_in_axes_tree_prefix_error(self): - # https://github.com/google/jax/issues/795 + # https://github.com/jax-ml/jax/issues/795 value_tree = jnp.ones(3) self.assertRaisesRegex( ValueError, @@ -3172,14 +3032,14 @@ def test_vmap_out_axes_leaf_types(self): api.vmap(lambda x: x, out_axes=(jnp.array([1., 2.]),))(jnp.array([1., 2.])) def test_vmap_unbatched_object_passthrough_issue_183(self): - # https://github.com/google/jax/issues/183 + # https://github.com/jax-ml/jax/issues/183 fun = lambda f, x: f(x) vfun = api.vmap(fun, (None, 0)) ans = vfun(lambda x: x + 1, jnp.arange(3)) self.assertAllClose(ans, np.arange(1, 4), check_dtypes=False) def test_vmap_mismatched_keyword(self): - # https://github.com/google/jax/issues/10193 + # https://github.com/jax-ml/jax/issues/10193 @jax.vmap def f(x, y): return x + y @@ -3193,7 +3053,7 @@ def f(x, y): f(jnp.array([1], 'int32'), y=jnp.array([1, 2], 'int32')) def test_vmap_mismatched_axis_sizes_error_message_issue_705(self): - # https://github.com/google/jax/issues/705 + # https://github.com/jax-ml/jax/issues/705 def h(a, b): return jnp.sum(a) + jnp.sum(b) @@ -3298,12 +3158,12 @@ def foo(tree_arg): self.assertEqual(vfoo(tree).shape, (6, 2, 5)) def test_vmap_in_axes_bool_error(self): - # https://github.com/google/jax/issues/6372 + # https://github.com/jax-ml/jax/issues/6372 with self.assertRaisesRegex(TypeError, "must be an int"): api.vmap(lambda x: x, in_axes=False)(jnp.zeros(3)) def test_pmap_in_axes_bool_error(self): - # https://github.com/google/jax/issues/6372 + # https://github.com/jax-ml/jax/issues/6372 with self.assertRaisesRegex(TypeError, "must be an int"): api.pmap(lambda x: x, in_axes=False)(jnp.zeros(1)) @@ -3365,7 +3225,7 @@ def test_device_array_hash(self): hash(rep) def test_grad_without_enough_args_error_message(self): - # https://github.com/google/jax/issues/1696 + # https://github.com/jax-ml/jax/issues/1696 def f(x, y): return x + y df = api.grad(f, argnums=0) self.assertRaisesRegex( @@ -3433,17 +3293,17 @@ def test_grad_of_jit_compilation_caching2(self): def f(x): return jnp.sin(x) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _ = jax.grad(f)(3.) self.assertEqual(count[0], 2) # one for fwd, one for bwd - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _ = jax.grad(f)(3.) _ = jax.grad(f)(4.) 
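A short illustration of the in_axes validation the vmap hunks below refer to; per the tests, a dict is rejected when vmap wraps the function, and a wrong-length tuple at call time:

import jax
import jax.numpy as jnp

try:
    jax.vmap(lambda x: x['a'], in_axes={'a': 0})  # top-level in_axes must be an int, None, or a tuple
except TypeError as e:
    print(e)

try:
    jax.vmap(lambda x: x['a'], in_axes=(0, {'a': 0}))({'a': jnp.zeros((3, 3))})  # tuple length != number of args
except ValueError as e:
    print(e)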
self.assertEqual(count[0], 0) # cache hits on both fwd and bwd def test_grad_does_not_unflatten_tree_with_none(self): - # https://github.com/google/jax/issues/7546 + # https://github.com/jax-ml/jax/issues/7546 class CustomNode(list): pass @@ -3512,7 +3372,7 @@ def test_primitive_compilation_cache(self): self.assertEqual(count[0], 1) def test_arange_jit(self): - # see https://github.com/google/jax/issues/553 + # see https://github.com/jax-ml/jax/issues/553 def fun(x): r = jnp.arange(x.shape[0])[x] return r @@ -3638,7 +3498,7 @@ def test_escaped_tracer_shape_dtype(self): _ = self._saved_tracer+1 def test_pmap_static_kwarg_error_message(self): - # https://github.com/google/jax/issues/3007 + # https://github.com/jax-ml/jax/issues/3007 def f(a, b): return a + b @@ -3664,7 +3524,7 @@ def f(x): return x + y + y x = np.array([1, 2], dtype=np.float32) - hlo_lines = jax.xla_computation(f)(x).as_hlo_text().split('\n') + hlo_lines = jax.jit(f).lower(x).as_text('hlo').split('\n') hlo_lines = {s.strip() for s in hlo_lines} self.assertIn('constant.1 = f32[2]{0} constant({7, 14})', hlo_lines) self.assertNotIn('constant.2 = f32[2]{0} constant({7, 14})', hlo_lines) @@ -3791,13 +3651,8 @@ def g(x): with self.assertRaisesRegex(core.ConcretizationTypeError, msg): g(1) - def test_xla_computation_zeros_doesnt_device_put(self): - with jtu.count_device_put() as count: - api.xla_computation(lambda: jnp.zeros(3))() - self.assertEqual(count[0], 0) - def test_join_concrete_arrays_with_omnistaging(self): - # https://github.com/google/jax/issues/4622 + # https://github.com/jax-ml/jax/issues/4622 x = jnp.array([1., 2., 3.]) y = jnp.array([1., 2., 4.]) @@ -3820,7 +3675,7 @@ def fn(x): self.assertEqual(aux, True) def test_linearize_aval_error(self): - # https://github.com/google/jax/issues/4622 + # https://github.com/jax-ml/jax/issues/4622 f = lambda x: x # these should not error @@ -3838,7 +3693,7 @@ def test_linearize_aval_error(self): f_jvp(np.ones(2, np.int32)) def test_grad_of_token_consuming_primitive(self): - # https://github.com/google/jax/issues/5463 + # https://github.com/jax-ml/jax/issues/5463 tokentest_p = core.Primitive("tokentest") tokentest_p.def_impl(partial(xla.apply_primitive, tokentest_p)) tokentest_p.def_abstract_eval(lambda x, y: x) @@ -3970,7 +3825,7 @@ def g(x): f(3) def test_leak_checker_avoids_false_positive_custom_jvp(self): - # see https://github.com/google/jax/issues/5636 + # see https://github.com/jax-ml/jax/issues/5636 with jax.checking_leaks(): @jax.custom_jvp def t(y): @@ -4053,7 +3908,7 @@ def test_default_device(self): self.assertEqual(jnp.ones(1).devices(), system_default_devices) def test_dunder_jax_array(self): - # https://github.com/google/jax/pull/4725 + # https://github.com/jax-ml/jax/pull/4725 class AlexArray: def __init__(self, jax_val): @@ -4085,6 +3940,18 @@ def __jax_array__(self): a2 = jnp.array(((x, x), [x, x])) self.assertAllClose(np.array(((1, 1), (1, 1))), a2) + def test_eval_shape_weak_type(self): + # https://github.com/jax-ml/jax/issues/23302 + arr = jax.numpy.array(1) + + with jtu.count_jit_tracing_cache_miss() as count: + jax.eval_shape(jax.numpy.array, 1) + out = jax.eval_shape(jax.numpy.array, 1) + + self.assertEqual(count[0], 1) + self.assertTrue(out.weak_type) + self.assertEqual(out.weak_type, arr.weak_type) + def test_dunder_jax_array_bug(self): @jax.tree_util.register_pytree_node_class class A: @@ -4115,7 +3982,7 @@ def __jax_array__(self) -> jax.Array: f(a, a) # don't crash def test_constant_handler_mro(self): - # https://github.com/google/jax/issues/6129 + # 
https://github.com/jax-ml/jax/issues/6129 class Foo(enum.IntEnum): bar = 1 @@ -4132,7 +3999,7 @@ def f(_): {"testcase_name": f"{dtype.__name__}", "dtype": dtype} for dtype in jtu.dtypes.all]) def test_constant_handlers(self, dtype): - # https://github.com/google/jax/issues/9380 + # https://github.com/jax-ml/jax/issues/9380 @jax.jit def f(): return jnp.exp(dtype(0)) @@ -4270,7 +4137,7 @@ def f(x): jaxpr = api.make_jaxpr(f)(3) self.assertNotIn('pjit', str(jaxpr)) - # Repro for https://github.com/google/jax/issues/7229. + # Repro for https://github.com/jax-ml/jax/issues/7229. def test_compute_with_large_transfer(self): def f(x, delta): return x + jnp.asarray(delta, x.dtype) @@ -4328,7 +4195,7 @@ def transpose(f, x): self.assertEqual(actual, expected) def test_leaked_tracer_issue_7613(self): - # from https://github.com/google/jax/issues/7613 + # from https://github.com/jax-ml/jax/issues/7613 import numpy.random as npr def sigmoid(x): @@ -4346,13 +4213,13 @@ def loss(A, x): _ = jax.grad(loss)(A, x) # doesn't crash def test_vmap_caching(self): - # https://github.com/google/jax/issues/7621 + # https://github.com/jax-ml/jax/issues/7621 f = lambda x: jnp.square(x).mean() jf = jax.jit(f) x = jax.random.uniform(jax.random.key(0), shape=(8, 4)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(5): jax.hessian(jf)(x).block_until_ready() @@ -4434,7 +4301,7 @@ def g(x, y): self.assertEqual(2 * i, g(2, i), msg=i) def test_fastpath_cache_confusion(self): - # https://github.com/google/jax/issues/12542 + # https://github.com/jax-ml/jax/issues/12542 @jax.jit def a(x): return () @@ -4479,7 +4346,7 @@ def h(x): b(8) # don't crash def test_vjp_multiple_arguments_error_message(self): - # https://github.com/google/jax/issues/13099 + # https://github.com/jax-ml/jax/issues/13099 def foo(x): return (x, x) _, f_vjp = jax.vjp(foo, 1.0) @@ -4511,7 +4378,7 @@ def foo(x, y, z): self.assertEqual(jfoo.__module__, "jax") def test_inner_jit_function_retracing(self): - # https://github.com/google/jax/issues/7155 + # https://github.com/jax-ml/jax/issues/7155 inner_count = outer_count = 0 @jax.jit @@ -4538,7 +4405,7 @@ def outer_fn(x): self.assertEqual(outer_count, 1) def test_grad_conj_symbolic_zeros(self): - # https://github.com/google/jax/issues/15400 + # https://github.com/jax-ml/jax/issues/15400 f = lambda x: jax.jit(lambda x, y: (x, y))(x, jax.lax.conj(x))[0] out = jax.grad(f)(3.0) # doesn't crash self.assertAllClose(out, 1., check_dtypes=False) @@ -4690,7 +4557,7 @@ def test_jit_custom_floats(self, dtype): self._CompileAndCheck(f, args_maker) def test_jvp_asarray_returns_array(self): - # https://github.com/google/jax/issues/15676 + # https://github.com/jax-ml/jax/issues/15676 p, t = jax.jvp(jax.numpy.asarray, (1.,), (2.,)) _check_instance(self, p) _check_instance(self, t) @@ -4850,6 +4717,19 @@ def g(): with self.assertRaisesRegex(TracerBoolConversionError, "Attempted boolean"): f() + def test_inline_return_twice(self): + # https://github.com/jax-ml/jax/issues/22944 + @jax.jit + def add_one(x: int) -> int: + return x + 1 + + def add_one_and_dupe(x: int) -> tuple[int, int]: + y = add_one(x) + return (y, y) + + jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True) + jax.eval_shape(jit_add_one_dupe, 0) # don't crash + class RematTest(jtu.JaxTestCase): @@ -5196,7 +5076,7 @@ def f_yesremat(x): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_no_redundant_flops(self, remat): - # see 
https://github.com/google/jax/pull/1749#issuecomment-558267584 + # see https://github.com/jax-ml/jax/pull/1749#issuecomment-558267584 @api.jit def g(x): @@ -5246,7 +5126,7 @@ def binom_checkpoint(funs): ('_new', new_checkpoint), ]) def test_remat_symbolic_zeros(self, remat): - # code from https://github.com/google/jax/issues/1907 + # code from https://github.com/jax-ml/jax/issues/1907 key = jax.random.key(0) key, split = jax.random.split(key) @@ -5299,7 +5179,7 @@ def g(): ('_new', new_checkpoint), ]) def test_remat_nontrivial_env(self, remat): - # simplified from https://github.com/google/jax/issues/2030 + # simplified from https://github.com/jax-ml/jax/issues/2030 @remat def foo(state, dt=0.5, c=1): @@ -5333,7 +5213,7 @@ def loss(u0, target, steps, dt=1/jnp.sqrt(2), c=1): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_jit3(self, remat): - # https://github.com/google/jax/issues/2180 + # https://github.com/jax-ml/jax/issues/2180 def f(w, x): a = jnp.dot(x, w) b = jnp.einsum("btd,bTd->btT", a, a) @@ -5366,7 +5246,7 @@ def f(w, x): ('_new', new_checkpoint), ]) def test_remat_scan2(self, remat): - # https://github.com/google/jax/issues/1963 + # https://github.com/jax-ml/jax/issues/1963 def scan_bug(x0): f = lambda x, _: (x + 1, None) @@ -5378,7 +5258,7 @@ def scanned_f(x, _): jax.grad(scan_bug)(1.0) # doesn't crash def test_remat_jit_static_argnum_omnistaging(self): - # https://github.com/google/jax/issues/2833 + # https://github.com/jax-ml/jax/issues/2833 # NOTE(mattjj): after #3370, this test doesn't actually call remat... def named_call(f): def named_f(*args): @@ -5403,7 +5283,7 @@ def f(a_bool, y): ('_new', partial(new_checkpoint, policy=lambda *_, **__: False)), ]) def test_remat_eval_counter(self, remat): - # https://github.com/google/jax/issues/2737 + # https://github.com/jax-ml/jax/issues/2737 add_one_p = core.Primitive('add_one') add_one = add_one_p.bind @@ -5493,13 +5373,12 @@ def f(x): x, _ = g(x) return x - c = api.xla_computation(f)(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) - self.assertNotIn('opt-barrier', c.as_hlo_text()) + text = jax.jit(f).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) + self.assertNotIn('opt-barrier', text) - c = api.xla_computation(grad(f))(2.) - text = c.as_hlo_text() + text = jax.jit(grad(f)).lower(2.).as_text('hlo') self.assertTrue('while' in text or 'conditional' in text or 'opt-barrier' in text) @@ -5518,13 +5397,13 @@ def f(x): x, _ = g(x) return x - c = api.xla_computation(f)(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) + text = jax.jit(f).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) - c = api.xla_computation(grad(f))(2.) - self.assertNotIn('while', c.as_hlo_text()) - self.assertNotIn('conditional', c.as_hlo_text()) + text = jax.jit(grad(f)).lower(2.).as_text('hlo') + self.assertNotIn('while', text) + self.assertNotIn('conditional', text) @parameterized.named_parameters( {"testcase_name": f"_{policy_name}_{remat_name}", "remat": remat, @@ -5788,7 +5667,7 @@ def test_constants_not_hoisted(self): # The old implementation of remat worked by data dependence, and so # (potentially large) constants would not be rematerialized and could be # wastefully instantiated. This test checks that the newer remat - # implementation avoids that. See https://github.com/google/jax/pull/8191. 
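Earlier in this hunk, the new test_inline_return_twice (issue 22944) guards against a crash when an inline-jitted wrapper returns the same jitted value twice; the pattern, in sketch form:

import jax

@jax.jit
def add_one(x):
    return x + 1

def add_one_and_dupe(x):
    y = add_one(x)
    return y, y

jit_add_one_dupe = jax.jit(add_one_and_dupe, inline=True)
jax.eval_shape(jit_add_one_dupe, 0)  # previously crashed; should return two ShapeDtypeStructs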
+ # implementation avoids that. See https://github.com/jax-ml/jax/pull/8191. # no residuals from constants created inside jnp.einsum @partial(new_checkpoint, policy=lambda *_, **__: False) @@ -5889,6 +5768,27 @@ def f(x, y): self.assertStartsWith(res[4][1], "named 'z'") self.assertEqual(res[5][0].shape, ()) + def test_saved_residuals_utility_jit(self): + @jax.jit + def f(x, y): + x1, x2 = x + z = checkpoint_name(jnp.sin(3.), 'z') + return z * ((x1 * x2) * y) * np.array([3.]) + + res = saved_residuals(f, (2., 3.), y=4.) + self.assertLen(res, 6) + self.assertEqual(res[0][0].shape, ()) + self.assertEqual(res[0][1], "from the argument x[0]") + self.assertEqual(res[1][0].shape, ()) + self.assertEqual(res[1][1], "from the argument x[1]") + self.assertEqual(res[2][0].shape, ()) + self.assertEqual(res[2][1], "from the argument y") + self.assertEqual(res[3][0].shape, ()) + self.assertStartsWith(res[3][1], "output of jitted function 'f'") + self.assertEqual(res[4][0].shape, ()) + self.assertEqual(res[5][0].shape, (1,)) + self.assertStartsWith(res[5][1], "output of jitted function 'f'") + @parameterized.named_parameters( {"testcase_name": f"{suffix}", "remat": remat} for suffix, remat in [ @@ -5913,16 +5813,16 @@ def f(x): _ = jax.grad(f)(3.) # doesn't crash def test_linearize_caching(self): - # https://github.com/google/jax/issues/9661 + # https://github.com/jax-ml/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_lin = jax.linearize(identity, 1.) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_lin(1.).block_until_ready() self.assertEqual(count[0], 1) # cached after first execution def test_vjp_caching(self): - # https://github.com/google/jax/issues/9661 + # https://github.com/jax-ml/jax/issues/9661 identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) _, f_vjp = jax.vjp(identity, 1.) with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841 @@ -5934,7 +5834,7 @@ def test_vjp_caching_static_argnums(self): identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x), static_argnums=(1,)) _, f_vjp = jax.vjp(lambda x: identity(x, True), 1.) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): f_vjp(1.)[0].block_until_ready() self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd @@ -5942,7 +5842,7 @@ def test_vjp_caching_static_argnums(self): def test_fwd_caching(self): # see above test also identity = jax.checkpoint(jax.jit(lambda x: 2 * x)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): y, _ = jax.vjp(identity, 1.) y.block_until_ready() @@ -5951,7 +5851,7 @@ def test_fwd_caching(self): def test_fwd_caching_static_argnums(self): # see above test also identity = jax.checkpoint(jax.jit(lambda x: 2 * x), static_argnums=(0,)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 for _ in range(20): y = identity(1.) 
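The caching tests above now count lowerings rather than compiles; the pattern they exercise is repeated application of a linearized, checkpointed jit function, which should hit the cache after the first call:

import jax

identity = jax.checkpoint(jax.jit(lambda x: 2 * x))
_, f_lin = jax.linearize(identity, 1.)
for _ in range(20):
    f_lin(1.).block_until_ready()  # only the first iteration should trigger a lowering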
y.block_until_ready() @@ -6569,49 +6469,6 @@ def f(x): jaxpr = api.make_jaxpr(f, axis_env=[('i', 4)])(2) self.assertIn('psum', str(jaxpr)) - def test_make_jaxpr_named(self): - raise unittest.SkipTest("named shape are deprecated") - def f(x): - return x - lax.psum(x, 'i') - - x = api.ShapeDtypeStruct( - shape=(2, 3), dtype=jnp.dtype(jnp.float32), named_shape={'i': 10}) - jaxpr = api.make_jaxpr(f, axis_env=[('i', 10)])(x) - named_shapes = [v.aval.named_shape for v in jaxpr.jaxpr.eqns[1].invars] - self.assertEqual(named_shapes, [{'i': 10}, {}]) - - @parameterized.parameters(True, False) - def test_vjp_reduce_axes_jaxpr(self, gy_batched): - raise unittest.SkipTest("reduce_axes autodiff is removed") - def f(w, x): - return jnp.sin(jnp.dot(x, w)) - - w = api.ShapeDtypeStruct( - shape=(3, 4), dtype=jnp.float32, named_shape={}) - x = api.ShapeDtypeStruct( - shape=(3,), dtype=jnp.float32, named_shape={'batch': 2}) - gy = api.ShapeDtypeStruct( - shape=(4,), dtype=jnp.float32, - named_shape={'batch': 2} if gy_batched else {}) - - # per-example - jaxpr, shapes = api.make_jaxpr( - lambda w, x, gy: api.vjp(f, w, x)[1](gy), axis_env=[('batch', 2)], - return_shape=True)(w, x, gy) - expected = (api.ShapeDtypeStruct( - shape=(3, 4), dtype=jnp.float32, named_shape={'batch': 2}), x) - self.assertEqual(shapes, expected) - self.assertNotIn('psum', str(jaxpr)) - - # reduced - jaxpr, shapes = api.make_jaxpr( - lambda w, x, gy: api.vjp(f, w, x, reduce_axes=('batch',))[1](gy), - axis_env=[('batch', 2)], - return_shape=True)(w, x, gy) - expected = (w, x) - self.assertEqual(shapes, expected) - self.assertIn('psum', str(jaxpr)) - def test_weak_type_jit_invariance(self): y = jnp.broadcast_to(3., (3,)) self.assertTrue(y.aval.weak_type) @@ -6640,7 +6497,7 @@ def test_elide_trivial_broadcasts(self): self.assertLen(jaxpr.jaxpr.eqns, 0) def test_convert_element_type_literal_constant_folding(self): - # this convert_elemnt_type is nontrivial, but because it's on a scalar we + # this convert_element_type is nontrivial, but because it's on a scalar we # constant-fold it cet = partial(lax.convert_element_type, new_dtype='float16') jaxpr = api.make_jaxpr(lambda: cet(3.))() @@ -7094,7 +6951,7 @@ def f_jvp(primals, tangents): check_dtypes=False) def test_kwargs(self): - # from https://github.com/google/jax/issues/1938 + # from https://github.com/jax-ml/jax/issues/1938 @jax.custom_jvp def my_fun(x, y, c=1.): return c * (x + y) @@ -7369,12 +7226,13 @@ def foo_jvp(primals, tangents): TypeError, re.escape( "Custom JVP rule must produce primal and tangent outputs " - "with equal shapes and dtypes, but got float32[] and float32[1] " - "respectively."), + "with corresponding shapes and dtypes. 
" + "Expected float32[] (tangent type of float32[]) but got float32[1]."), lambda: api.jvp(f, (jnp.float32(2.),), (jnp.float32(1.),))) + def test_jvp_rule_doesnt_return_pair_error_message(self): - # https://github.com/google/jax/issues/2516 + # https://github.com/jax-ml/jax/issues/2516 @jax.custom_jvp def f(x): @@ -7539,7 +7397,7 @@ def _expit_jvp(primals, tangents): api.eval_shape(api.grad(lambda x: expit(x).sum()), jnp.ones((2, 3))) def test_jaxpr_zeros(self): - # from https://github.com/google/jax/issues/2657 + # from https://github.com/jax-ml/jax/issues/2657 @jax.custom_jvp def f(A, b): return A @ b @@ -7585,7 +7443,7 @@ def foo(x): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_jvps_first_rule_is_none(self): - # https://github.com/google/jax/issues/3389 + # https://github.com/jax-ml/jax/issues/3389 @jax.custom_jvp def f(x, y): return x ** 2 * y @@ -7596,7 +7454,7 @@ def f(x, y): self.assertAllClose(ans, expected, check_dtypes=False) def test_concurrent_initial_style(self): - # https://github.com/google/jax/issues/3843 + # https://github.com/jax-ml/jax/issues/3843 def unroll(param, sequence): def scan_f(prev_state, inputs): return prev_state, jax.nn.sigmoid(param * inputs) @@ -7618,7 +7476,7 @@ def run(): self.assertAllClose(ans, expected) def test_nondiff_argnums_vmap_tracer(self): - # https://github.com/google/jax/issues/3964 + # https://github.com/jax-ml/jax/issues/3964 @partial(jax.custom_jvp, nondiff_argnums=(0, 2)) def sample(shape, param, seed): return jax.random.uniform(key=seed, shape=shape, minval=param) @@ -7660,7 +7518,7 @@ def baz(w): api.vmap(fun_with_nested_calls_2)(jnp.arange(3.)) def test_closure_with_vmap(self): - # https://github.com/google/jax/issues/3822 + # https://github.com/jax-ml/jax/issues/3822 alpha = np.float32(2.) 
def sample(seed): @@ -7680,7 +7538,7 @@ def f_jvp(primal, tangent): api.vmap(sample)(jax.random.split(jax.random.key(1), 3)) # don't crash def test_closure_with_vmap2(self): - # https://github.com/google/jax/issues/8783 + # https://github.com/jax-ml/jax/issues/8783 def h(z): def f(x): @jax.custom_jvp @@ -7702,12 +7560,13 @@ def g_jvp(primals, tangents): self.assertAllClose(tangents, 2 * jnp.arange(3., dtype='float32')) def test_float0(self): + scalar_float0 = jnp.zeros((), dtype=float0) @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): - # we need a defined (non-float0) tangent to trigger the rule - return primals, (2., 1) + x, y = primals + return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) f.defjvp(f_jvp) primals = (2., 3) @@ -7717,12 +7576,13 @@ def f_jvp(primals, _): (primals, expected_tangents)) def test_float0_initial_style(self): + scalar_float0 = jnp.zeros((), dtype=float0) @jax.custom_jvp def f(x, y): return x, y def f_jvp(primals, _): x, y = primals - return (x, y), (2., 1) + return (x, y), (2., custom_derivatives_public.zero_from_primal(y)) f.defjvp(f_jvp) def foo(x, y): @@ -7730,8 +7590,9 @@ def foo(x, y): return out primals = (2., 3) - tangents = (np.ones(()), np.zeros((), float0),) - expected_tangents = (2., np.zeros((), float0)) + tangents = (np.ones(()), scalar_float0) + expected_tangents = (2., scalar_float0) + self.assertAllClose(api.jvp(foo, primals, tangents), (primals, expected_tangents)) @@ -7822,7 +7683,7 @@ def foo(x): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_jvp_vmap_broadcasting_interaction(self): - # https://github.com/google/jax/issues/6452 + # https://github.com/jax-ml/jax/issues/6452 def f2(y, z): v1 = z v2 = jnp.sum(y) + z @@ -7840,7 +7701,7 @@ def f1(y, z): self.assertEqual(g.shape, ()) def test_custom_jvp_vmap_broadcasting_interaction_2(self): - # https://github.com/google/jax/issues/5849 + # https://github.com/jax-ml/jax/issues/5849 @jax.custom_jvp def transform(box, R): if jnp.isscalar(box) or box.size == 1: @@ -7878,7 +7739,7 @@ def energy_fn(box): self.assertEqual(grad(energy_fn)(scalar_box).shape, ()) def test_custom_jvp_implicit_broadcasting(self): - # https://github.com/google/jax/issues/6357 + # https://github.com/jax-ml/jax/issues/6357 if config.enable_x64.value: raise unittest.SkipTest("test only applies when x64 is disabled") @@ -7936,7 +7797,7 @@ def fun(X): self.assertAllClose(dir_deriv, dir_deriv_num, atol=1e-3) def test_vmap_inside_defjvp(self): - # https://github.com/google/jax/issues/3201 + # https://github.com/jax-ml/jax/issues/3201 seed = 47 key = jax.random.key(seed) mat = jax.random.normal(key, (2, 3)) @@ -7985,7 +7846,7 @@ def operate(mx, val): jax.grad(lambda mat, aux: jnp.sum(f(mat, aux)))(mat, 0.5) # doesn't crash def test_custom_jvp_unbroadcasting(self): - # https://github.com/google/jax/issues/3056 + # https://github.com/jax-ml/jax/issues/3056 a = jnp.array([1., 1.]) @jax.custom_jvp @@ -8003,8 +7864,8 @@ def f_jvp(primals, tangents): def test_maybe_perturbed_internal_helper_function(self): # This is a unit test for an internal API. We include it so as not to - # regress https://github.com/google/jax/issues/9567. For an explanation of - # this helper function, see https://github.com/google/jax/issues/6415. + # regress https://github.com/jax-ml/jax/issues/9567. For an explanation of + # this helper function, see https://github.com/jax-ml/jax/issues/6415. 
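The float0 hunks above replace the dummy integer tangent with a proper float0 zero; a sketch assuming zero_from_primal is importable from jax.custom_derivatives, as the test's custom_derivatives_public alias suggests:

import numpy as np
import jax
from jax import custom_derivatives  # assumed public home of zero_from_primal
from jax.dtypes import float0

@jax.custom_jvp
def f(x, y):
    return x, y

@f.defjvp
def f_jvp(primals, _):
    x, y = primals
    # y is an integer, so its output tangent must be a float0 zero, not a float dummy
    return (x, y), (2., custom_derivatives.zero_from_primal(y))

primals = (2., 3)
tangents = (np.ones(()), np.zeros((), dtype=float0))
jax.jvp(f, primals, tangents)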
def f(x): def g(y, _): z = y * x @@ -8016,7 +7877,7 @@ def g(y, _): jax.jvp(f, (1.0,), (1.0,)) # assertions inside f def test_maybe_perturbed_int_regression(self): - # see https://github.com/google/jax/discussions/9951 + # see https://github.com/jax-ml/jax/discussions/9951 @jax.jit def f(): @@ -8026,7 +7887,7 @@ def f(): f() def test_sinc_constant_function_batching(self): - # https://github.com/google/jax/pull/10756 + # https://github.com/jax-ml/jax/pull/10756 batch_data = jnp.arange(15.).reshape(5, 3) @jax.vmap @@ -8143,7 +8004,7 @@ def f_jvp(primals, tangents): _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! def test_symbolic_zeros_under_jit(self): - # https://github.com/google/jax/issues/14833 + # https://github.com/jax-ml/jax/issues/14833 Zero = jax.custom_derivatives.SymbolicZero @jax.custom_jvp @@ -8177,7 +8038,7 @@ def jvp_fn(primals, tangents): self.assertEqual((1.0, 0.1), jax.grad(lambda args: fn(*args))((1.0, 2.0))) def test_run_rules_more_than_once(self): - # https://github.com/google/jax/issues/16614 + # https://github.com/jax-ml/jax/issues/16614 @jax.custom_jvp def f(x, y): @@ -8368,7 +8229,7 @@ def f_rev(cos_x, g): lambda: api.jvp(jit(f), (3.,), (1.,))) def test_kwargs(self): - # from https://github.com/google/jax/issues/1938 + # from https://github.com/jax-ml/jax/issues/1938 @jax.custom_vjp def my_fun(x, y, c=1.): return c * (x + y) @@ -8664,7 +8525,7 @@ def test_issue2511(self): api.jit(foo)(arr) # doesn't crash def test_lowering_out_of_traces(self): - # https://github.com/google/jax/issues/2578 + # https://github.com/jax-ml/jax/issues/2578 class F(collections.namedtuple("F", ["a"])): def __call__(self, x): @@ -8677,7 +8538,7 @@ def g(f, x): jax.grad(g, argnums=(1,))(F(2.0), 0.) # doesn't crash def test_clip_gradient(self): - # https://github.com/google/jax/issues/2784 + # https://github.com/jax-ml/jax/issues/2784 @jax.custom_vjp def _clip_gradient(lo, hi, x): return x # identity function when not differentiating @@ -8700,7 +8561,7 @@ def clip_gradient(x): self.assertAllClose(g, jnp.array(0.2)) def test_nestable_vjp(self): - # Verify that https://github.com/google/jax/issues/3667 is resolved. + # Verify that https://github.com/jax-ml/jax/issues/3667 is resolved. def f(x): return x ** 2 @@ -8733,7 +8594,7 @@ def z(x): self.assertAllClose(y, jnp.array(6.0)) def test_initial_style_vmap_2(self): - # https://github.com/google/jax/issues/4173 + # https://github.com/jax-ml/jax/issues/4173 x = jnp.ones((10, 3)) # Create the custom function @@ -8896,7 +8757,7 @@ def f(x): def f_fwd(x): return x, (2., x) def f_rev(*_): - return ((2., 1),) + return ((2., jnp.zeros(shape=(), dtype=float0)),) f.defvjp(f_fwd, f_rev) def foo(x, y): @@ -8999,7 +8860,7 @@ def f_rev(cos, g): self.assertAllClose(ans, expected, check_dtypes=False) def test_custom_vjp_closure_4521(self): - # https://github.com/google/jax/issues/4521 + # https://github.com/jax-ml/jax/issues/4521 @jax.custom_vjp def g(x, y): return None @@ -9116,7 +8977,7 @@ def closure(x): def test_closure_convert_mixed_consts(self): # Like test_closure_convert, but close over values that # participate in AD as well as values that do not. - # See https://github.com/google/jax/issues/6415 + # See https://github.com/jax-ml/jax/issues/6415 def cos_after(fn, x): converted_fn, aux_args = jax.closure_convert(fn, x) @@ -9154,6 +9015,19 @@ def closure(x): self.assertAllClose(g_c, 42. * c, check_dtypes=False) self.assertAllClose(g_x, 17. 
* x, check_dtypes=False) + def test_closure_convert_pytree_mismatch(self): + # See https://github.com/jax-ml/jax/issues/23588 + def f(x, z): + return z * x + + x, z = 2.0, 3.0 + _, vjp = api.vjp(f, x, z) + vjp_pure, vjp_aux_args = jax.closure_convert(vjp, x) + vjp_pure(x, *vjp_aux_args) + with self.assertRaisesRegex( + TypeError, "The inputs to the closure produced by closure_convert"): + vjp_pure(x, vjp_aux_args) + def test_float0_cotangents_automatically_handled(self): @jax.custom_vjp def f(x, y): @@ -9170,7 +9044,7 @@ def f_bwd(_, zbar): jax.jit(lambda x: jax.vjp(f, 0., x)[1](1.))(1) # doesn't crash def test_custom_vjp_scan_batching_edge_case(self): - # https://github.com/google/jax/issues/5832 + # https://github.com/jax-ml/jax/issues/5832 @jax.custom_vjp def mul(x, coeff): return x * coeff def mul_fwd(x, coeff): return mul(x, coeff), (x, coeff) @@ -9201,7 +9075,7 @@ def f_(x, t): modes=['rev']) def test_closure_with_vmap2(self): - # https://github.com/google/jax/issues/8783 + # https://github.com/jax-ml/jax/issues/8783 def h(z): def f(x): @jax.custom_vjp @@ -9243,7 +9117,7 @@ def f_bwd(_, g): jax.grad(f)(A([1.])) # doesn't crash def test_vmap_vjp_called_twice(self): - # https://github.com/google/jax/pull/14728 + # https://github.com/jax-ml/jax/pull/14728 @jax.custom_vjp def f(x): return x @@ -9539,7 +9413,7 @@ def f_bwd(_, z_bar): _ = jax.linearize(lambda x: f_(x, 3.), 2.) # don't crash! def test_run_rules_more_than_once(self): - # https://github.com/google/jax/issues/16614 + # https://github.com/jax-ml/jax/issues/16614 @jax.custom_vjp def f(x, y): @@ -9569,7 +9443,7 @@ def g(x): g(1.) # doesn't crash def test_nones_representing_zeros_in_subtrees_returned_by_bwd(self): - # https://github.com/google/jax/issues/8356 + # https://github.com/jax-ml/jax/issues/8356 @jax.custom_vjp def f(x): return x[0] @@ -9749,6 +9623,48 @@ def f_bwd(res, g): x, y = 3.2, 1.0 self.assertAllClose(jax.grad(f)(x, y), jax.grad(f_)(x, y)) + def test_optimize_remat_kwargs(self): + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + def f_fwd(x, y, *, keyword=False): + del keyword + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = 3.2, 1.0 + jax.grad(f)(x, y) # Doesn't error + + def test_optimize_remat_custom_vmap(self): + # See https://github.com/jax-ml/jax/pull/23000 + @jax.custom_vjp + def f(x, y): + return jnp.sin(x) * y + + @jax.custom_batching.custom_vmap + def f_fwd(x, y): + return f(x, y), (jnp.cos(x), jnp.sin(x), y) + + @f_fwd.def_vmap + def f_fwd_vmap(_, in_batched, x, y): + # Insert a new const here to test the optimize_remat batching rule. + out = np.array([2.0])*f(x, y) + out_batched = (True, (True, True, True)) + return (out, (jnp.cos(x), jnp.sin(x), y)), out_batched + + def f_bwd(res, g): + cos_x, sin_x, y = res + return (cos_x * g * y, sin_x * g) + + f.defvjp(f_fwd, f_bwd, optimize_remat=True) + x, y = jnp.linspace(0.0, 1.0, 5), jnp.linspace(2.0, 5.0, 5) + jax.jit(jax.vmap(jax.grad(f)))(x, y) # Doesn't error + def transpose_unary(f, x_example): def transposed(y): @@ -9781,12 +9697,12 @@ def __call__(self, *args): # an option of inferring output types. 
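The new custom_vjp tests above introduce the optimize_remat flag (and its interaction with custom_batching.custom_vmap); the basic spelling, lifted from the test in sketch form:

import jax
import jax.numpy as jnp

@jax.custom_vjp
def f(x, y):
    return jnp.sin(x) * y

def f_fwd(x, y):
    return f(x, y), (jnp.cos(x), jnp.sin(x), y)

def f_bwd(res, g):
    cos_x, sin_x, y = res
    return (cos_x * g * y, sin_x * g)

f.defvjp(f_fwd, f_bwd, optimize_remat=True)  # flag exercised by the tests above
jax.grad(f)(3.2, 1.0)  # should not error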
def custom_transpose(example_out): if isinstance(example_out, Callable): - out_type = core.get_aval(0.).at_least_vspace() + out_type = core.get_aval(0.).to_tangent_aval() return _custom_transpose(out_type, example_out) return partial( _custom_transpose, jax.tree.map( - lambda x: core.get_aval(x).at_least_vspace(), example_out)) + lambda x: core.get_aval(x).to_tangent_aval(), example_out)) class CustomTransposeTest(jtu.JaxTestCase): @@ -10780,7 +10696,6 @@ def test_batch_map_pytrees(self, batch_size: int): ) self.assertAllClose(outputs['b'], expected) - def test_batch_divides_axis(self): def f(t): x, a = t @@ -10798,6 +10713,32 @@ def g(x, a): self.assertAllClose(y, (x + a)**2) + def test_undefined_rule(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + with self.assertRaisesRegex( + AttributeError, "No batching rule defined for custom_vmap function f"): + f(0.5) + + def test_kwargs(self): + @jax.custom_batching.custom_vmap + def f(x): return jnp.sin(x) + + @f.def_vmap + def rule(axis_size, in_batched, xs): + xs_batched, = in_batched + self.assertEqual(xs_batched, True) + self.assertEqual(axis_size, xs.shape[0]) + return jnp.cos(xs), xs_batched + + x, xs = jnp.array(1.), jnp.arange(3) + y = f(x=x) + self.assertAllClose(y, jnp.sin(x)) + ys = api.vmap(f)(x=xs) + self.assertAllClose(ys, jnp.cos(xs)) + + class CustomApiTest(jtu.JaxTestCase): """Test interactions among the custom_{vmap,jvp,vjp,transpose,*} APIs""" @@ -10860,25 +10801,6 @@ def test_pmap_nested_donate_ignored(self): class NamedCallTest(jtu.JaxTestCase): - @jtu.ignore_warning(category=DeprecationWarning, message='jax.xla_computation is deprecated') - def test_default_name(self): - is_accelerated = deprecations.is_accelerated_attribute(jax, 'xla_computation') - xla_computation = api.xla_computation if is_accelerated else jax.xla_computation - - @api.named_call - def my_test_function(x): - return x**2 - - @jax.jit - def f(x): - return my_test_function(x) - - c = xla_computation(f)(2) - print_opts = xla_client._xla.HloPrintOptions.short_parsable() - print_opts.print_metadata = True - hlo_text = c.as_hlo_module().to_string(print_opts) - self.assertIn("my_test_function", hlo_text) - def test_non_jaxtype_arg(self): # For the test to fail without the invalid JaxType filter we need to pass # in a valid JaxType that forces the invalid Jaxtype to be raised to an @@ -11009,7 +10931,7 @@ def test_autodidax_smoketest(self): class GarbageCollectionTest(jtu.JaxTestCase): def test_xla_gc_callback(self): - # https://github.com/google/jax/issues/14882 + # https://github.com/jax-ml/jax/issues/14882 x_np = np.arange(10, dtype='int32') x_jax = jax.device_put(x_np) x_np_weakref = weakref.ref(x_np) diff --git a/tests/api_util_test.py b/tests/api_util_test.py index 46bed8c86b8a..e34611c6e785 100644 --- a/tests/api_util_test.py +++ b/tests/api_util_test.py @@ -69,5 +69,21 @@ def test_rebase_donate_argnums(self, donate, static, expected): self.assertEqual(expected, api_util.rebase_donate_argnums(donate, static)) + def test_resolve_kwargs(self): + def fun(x, y, z=3): + return x, y, z + assert api_util.resolve_kwargs(fun, (1,), {"y": 2}) == (1, 2, 3) + assert api_util.resolve_kwargs(fun, (1, 2), {"z": 3}) == (1, 2, 3) + assert api_util.resolve_kwargs( + fun, (), {"x": 1, "y": 2, "z": 3}) == (1, 2, 3) + + def test_resolve_kwargs_with_keyword(self): + def fun(x, y, z, *, kw=True): + del kw + return x, y, z + assert api_util.resolve_kwargs(fun, (1, 2), {"z": 3}) == (1, 2, 3) + with self.assertRaisesRegex(TypeError, "keyword arguments"): 
+ api_util.resolve_kwargs(fun, (1, 2), {"z": 3, "kw": False}) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/array_api_skips.txt b/tests/array_api_skips.txt index f7d80d94f96f..2ac2edcdfd99 100644 --- a/tests/array_api_skips.txt +++ b/tests/array_api_skips.txt @@ -4,36 +4,11 @@ array_api_tests/test_data_type_functions.py::test_finfo[float32] # Test suite attempts in-place mutation: -array_api_tests/test_special_cases.py::test_iop -array_api_tests/test_special_cases.py::test_nan_propagation array_api_tests/test_array_object.py::test_setitem +array_api_tests/test_array_object.py::test_setitem_masking -# Raises NonInteractiveExampleWarning -array_api_tests/test_special_cases.py::test_binary -array_api_tests/test_special_cases.py::test_unary - -# Pending implementation update for proper dtype promotion behavior, -# see https://github.com/data-apis/array-api-tests/issues/234 -array_api_tests/test_statistical_functions.py::test_sum -array_api_tests/test_statistical_functions.py::test_prod - -# Pending bugfix, see https://github.com/data-apis/array-api-tests/issues/256 -array_api_tests/test_signatures.py::test_func_signature[logical_and] -array_api_tests/test_signatures.py::test_func_signature[logical_or] -array_api_tests/test_signatures.py::test_func_signature[logical_xor] +# Returns wrong zero sign +array_api_tests/test_special_cases.py::test_unary[sign((x_i is -0 or x_i == +0)) -> 0] # Returns int32 when int64 is expected array_api_tests/test_searching_functions.py::test_searchsorted - -# Various info functions not yet defined -# Pending bugfix, see https://github.com/data-apis/array-api-tests/pull/262 -array_api_tests/test_has_names.py::test_has_names[info-capabilities] -array_api_tests/test_has_names.py::test_has_names[info-default_device] -array_api_tests/test_has_names.py::test_has_names[info-default_dtypes] -array_api_tests/test_has_names.py::test_has_names[info-devices] -array_api_tests/test_has_names.py::test_has_names[info-dtypes] -array_api_tests/test_signatures.py::test_func_signature[capabilities] -array_api_tests/test_signatures.py::test_func_signature[default_device] -array_api_tests/test_signatures.py::test_func_signature[default_dtypes] -array_api_tests/test_signatures.py::test_func_signature[devices] -array_api_tests/test_signatures.py::test_func_signature[dtypes] diff --git a/tests/array_interoperability_test.py b/tests/array_interoperability_test.py index c2cd4c0f968d..02f5ad527c61 100644 --- a/tests/array_interoperability_test.py +++ b/tests/array_interoperability_test.py @@ -75,6 +75,8 @@ def setUp(self): use_stream=[False, True], ) @jtu.run_on_devices("gpu") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testJaxRoundTrip(self, shape, dtype, copy, use_stream): rng = jtu.rand_default(self.rng()) np = rng(shape, dtype) @@ -142,6 +144,8 @@ def testJaxArrayRoundTrip(self, shape, dtype, gpu): dtype=dlpack_dtypes, ) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTensorFlowToJax(self, shape, dtype): if (not config.enable_x64.value and dtype in [jnp.int64, jnp.uint64, jnp.float64]): @@ -184,8 +188,10 @@ def testJaxToTensorFlow(self, shape, dtype): self.assertAllClose(np, y.numpy()) @unittest.skipIf(not tf, "Test requires TensorFlow") + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def 
testTensorFlowToJaxInt64(self): - # See https://github.com/google/jax/issues/11895 + # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( tf.experimental.dlpack.to_dlpack(tf.ones((2, 3), tf.int64))) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 @@ -221,6 +227,8 @@ def testJaxToNumpy(self, shape, dtype): x_np = np.from_dlpack(x_jax) self.assertAllClose(x_np, x_jax) + @jtu.ignore_warning(message="Calling from_dlpack.*", + category=DeprecationWarning) def testNondefaultLayout(self): # Generate numpy array with nonstandard layout a = np.arange(4).reshape(2, 2) diff --git a/tests/array_test.py b/tests/array_test.py index f13ecbb51adb..080356d3490a 100644 --- a/tests/array_test.py +++ b/tests/array_test.py @@ -81,7 +81,7 @@ def test_array_impl_name(self): ("mesh_fully_replicated", P()), ) def test_jax_array_value(self, mesh_axes): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, global_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, mesh_axes)) @@ -121,7 +121,7 @@ def test_jax_array_value(self, mesh_axes): ) def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape, expected_replica_ids, expected_is_fully_replicated): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True) global_input_shape = (8, 2) s = jax.sharding.NamedSharding(global_mesh, mesh_axes) arr, global_input_data = create_array(global_input_shape, s) @@ -148,7 +148,7 @@ def test_array_2d_shard(self, mesh_axes, expected_index, expected_shard_shape, self.assertArraysEqual(g.data, l.data) def test_addressable_data(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) shape = (8, 2) s = jax.sharding.NamedSharding(global_mesh, P(None)) arr, inp_data = create_array(shape, s) @@ -156,7 +156,7 @@ def test_addressable_data(self): self.assertArraysEqual(inp_data, arr.addressable_data(i)) def test_array_delete(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -174,7 +174,7 @@ def test_single_device_array_usage_after_delete(self): _ = x + 1 def test_multi_device_array_usage_after_delete(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) shape = (8, 2) arr = jax.device_put(np.arange(math.prod(shape), dtype=np.int32), jax.sharding.NamedSharding(global_mesh, P('x'))) @@ -205,14 +205,14 @@ def test_device_put_array_delete(self): self.assertIsNone(arr._arrays) def test_array_device_get(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) self.assertArraysEqual(jax.device_get(arr), input_data) def test_repr(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -254,7 +254,7 @@ def test_jnp_array_normal_add(self): self.assertIsInstance(arr.sharding, jax.sharding.SingleDeviceSharding) def test_array_sharded_astype(self): - global_mesh = 
jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -272,7 +272,7 @@ def test_jnp_array_astype(self): self.assertArraysEqual(arr_float32, arr.astype(np.float32)) def test_array_delete_idempotent(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) arr = jax.device_put(np.arange(8), jax.sharding.NamedSharding(mesh, P('x'))) arr.delete() @@ -282,7 +282,7 @@ def test_array_delete_idempotent(self): self.assertTrue(arr.is_deleted()) def test_sharded_add(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) a, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -296,7 +296,7 @@ def test_sharded_add(self): self.assertArraysEqual(i.data, expected[i.index]) def test_sharded_zeros_like(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) a, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -318,7 +318,7 @@ def test_wrong_num_arrays(self): if jax.device_count() < 4: self.skipTest('Requires more than 4 devices') shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) devices = jax.local_devices()[:2] # Taking up to 2 devices s = jax.sharding.NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -342,7 +342,7 @@ def test_arrays_not_in_device_assignment(self): if jax.device_count() < 4: self.skipTest('Requires more than 4 devices') shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) # sharding device ids = {0, 1} s = jax.sharding.NamedSharding(mesh, P('x')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -378,7 +378,7 @@ def test_duplicated_devices_in_arrays(self): if xc._version <= 274: self.skipTest('Test requires jaxlib version 275') shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) # Sharding device ids = {0, 1} s = jax.sharding.NamedSharding(mesh, P('x')) inp_data = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) @@ -401,7 +401,7 @@ def test_duplicated_devices_in_arrays(self): ) def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shape): shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) inp_data = np.arange(5) @@ -415,7 +415,7 @@ def test_shard_shape_mismatch_with_buffer_shape(self, pspec, expected_shard_shap def test_mismatch_dtype(self): shape = (8, 2) - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) indices = s.devices_indices_map(shape) @@ -452,7 +452,7 @@ def test_array_iter_pmap_sharding_last_dim_sharded(self): self.assertArraysAllClose(i, j) def test_array_iter_mesh_pspec_sharding_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( 
input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -462,7 +462,7 @@ def test_array_iter_mesh_pspec_sharding_multi_device(self): self.assertArraysEqual(i, j) def test_array_iter_replicated_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P(None))) @@ -477,7 +477,7 @@ def test_array_iter_replicated_multi_device(self): i.sharding._to_xla_hlo_sharding(i.ndim))) def test_array_getitem_mesh_pspec_sharding_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -496,7 +496,7 @@ def _check(out, inp, shard_shape): self.assertEqual(out.sharding.shard_shape(out.shape), shard_shape) self.assertNotIsInstance(out.sharding, jax.sharding.SingleDeviceSharding) - global_mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + global_mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) input_shape = (4, 4, 2) arr, np_inp = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y', 'z'))) @@ -523,7 +523,7 @@ def _check(out, inp, shard_shape): _check(arr[1], np_inp[1], (2, 1)) def test_array_getitem_replicated_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, input_data = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P(None))) @@ -575,7 +575,7 @@ def test_array_shards_committed(self): self.assertTrue(s.data._committed) def test_array_jnp_array_copy_multi_device(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -592,7 +592,7 @@ def test_array_jnp_array_copy_multi_device(self): c.data.unsafe_buffer_pointer()) def test_array_addressable_shards(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) arr, _ = create_array( input_shape, jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) @@ -620,7 +620,7 @@ def check_tracer_hash(x): check_tracer_hash(x) def test_shape_dtype_struct_sharding_jit(self): - mesh = jtu.create_global_mesh((8,), ('x')) + mesh = jtu.create_mesh((8,), ('x')) s = jax.sharding.NamedSharding(mesh, P('x')) x_dummy = jax.ShapeDtypeStruct( @@ -647,7 +647,7 @@ def f(x): s._to_xla_hlo_sharding(x_dummy.ndim))) def test_shape_dtype_struct_sharding_pjit(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) def f(x): @@ -677,7 +677,7 @@ def test_defragment(self): self.skipTest("Manual defragment not exposed via PJRT C API") # Create a few arrays - global_mesh = jtu.create_global_mesh((jax.local_device_count(),), ('x',)) + global_mesh = jtu.create_mesh((jax.local_device_count(),), ('x',)) shape = (8, 2) mpsharding = jax.sharding.NamedSharding(global_mesh, P('x',)) arr1, data = create_array(shape, mpsharding) @@ -700,7 +700,7 @@ def test_defragment(self): # OOM, and exposing allocator stats in Python. 
def test_on_device_size_in_bytes(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) a, _ = create_array( (8, 2), jax.sharding.NamedSharding(global_mesh, P('x', 'y'))) shard_size = a.addressable_shards[0].data.on_device_size_in_bytes() @@ -756,7 +756,7 @@ def test_buffer_protocol_deletion(self): self.assertEqual(x_bytes, y_bytes) def test_array_copy_to_host_async(self): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = pjit(lambda: jnp.arange(8.), out_shardings=jax.sharding.NamedSharding(global_mesh, P(None)))() self.assertLen(x.sharding.device_set, 4) @@ -765,7 +765,7 @@ def test_array_copy_to_host_async(self): def test_array_fully_replicated_shard(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_shape = (8, 2) arr, inp_data = create_array( inp_shape, jax.sharding.NamedSharding(global_mesh, P())) @@ -776,7 +776,7 @@ def test_array_fully_replicated_shard(self): self.assertArraysEqual(arr.addressable_data(0), inp_data) def test_shard_array_to_fully_replicated(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(global_mesh, P()) arr = jnp.arange(16) self.assertFalse(arr._committed) @@ -786,7 +786,7 @@ def test_shard_array_to_fully_replicated(self): self.assertArraysEqual(out, arr * 2) def test_fully_replicated_donated_array_is_deleted(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(global_mesh, P()) arr = jnp.arange(16) arr_copy = arr.copy() @@ -804,7 +804,7 @@ def test_shards_have_correct_dtype(self, dtype): self.assertEqual(shard.data.dtype, dtype) def test_make_array_from_callback_global_array(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(mesh, P()) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, sharding) @@ -822,23 +822,17 @@ def test_make_array_from_callback_global_array(self): self.assertEqual(out2.sharding, sharding2) def test_make_array_from_process_data_single_host_data_sharding(self): - data = np.ones((1, 512)) - mesh = jtu.create_global_mesh((1, 1), ('x', 'unused')) - sharding_spec = jax.sharding.NamedSharding( - mesh, jax.sharding.PartitionSpec('x') - ) - global_shape = data.shape - result = jax.make_array_from_process_local_data( - sharding_spec, data, global_shape - ) - self.assertIsInstance(result, jax.Array) - self.assertEqual(result.shape, data.shape) - self.assertEqual(result.sharding, sharding_spec) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) + data = np.ones((256, 512)) + s = jax.NamedSharding(mesh, P('x')) + result = jax.make_array_from_process_local_data(s, data) + self.assertArraysEqual(result, data) + self.assertEqual(result.sharding, s) class ShardingTest(jtu.JaxTestCase): def test_mesh_pspec_sharding_interface(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) pspec = P('y', 'x') global_shape = (8, 4) mp_sharding = jax.sharding.NamedSharding(mesh, pspec) @@ -855,7 +849,7 @@ def test_mesh_pspec_sharding_interface(self): [0, 2, 4, 6, 1, 3, 5, 7]) def test_util_clear_cache(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P()) 
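The array_test.py hunks are a mechanical rename of jtu.create_global_mesh to jtu.create_mesh, with iota_order=True added where the old device ordering matters; a sketch of the new spelling, assuming the internal test_util API shown in the diff:

from jax._src import test_util as jtu  # internal test utility, as used throughout these tests
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jtu.create_mesh((4, 2), ('x', 'y'))
ordered_mesh = jtu.create_mesh((4, 2), ('x', 'y'), iota_order=True)  # preserves iota device order
s = NamedSharding(mesh, P('x', 'y'))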
s.devices_indices_map((8,)) jax.clear_caches() @@ -874,7 +868,7 @@ def test_util_clear_cache(self): ) def test_op_sharding_indices(self, pspec): shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) ops = jax.sharding.GSPMDSharding( list(mesh.devices.flat), mps._to_xla_hlo_sharding(len(shape))) @@ -892,12 +886,12 @@ def test_op_sharding_indices(self, pspec): ) def test_shard_shape(self, pspec, expected_shard_shape): shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) self.assertEqual(mps.shard_shape(shape), expected_shard_shape) def test_uneven_shard_error(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) with self.assertRaisesRegex( ValueError, @@ -930,7 +924,7 @@ def test_pmap_sharding_hash_eq(self): def test_is_compatible_error(self): shape = (8, 2) - mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl')) + mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl')) mps = jax.sharding.NamedSharding(mesh, P(None, ('mdl',), None, None)) new_mps = jax.sharding.NamedSharding._from_parsed_pspec( mps.mesh, mps._parsed_pspec) @@ -982,7 +976,7 @@ def test_positional_sharding_op_sharding_lowering( self, pspec, shape, axes, transpose): value_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, pspec) devices = jax.local_devices()[:8] # Taking up to 8 devices @@ -1038,7 +1032,7 @@ def test_positional_sharding_aval_compatible(self): ) def test_positional_sharding_from_op_sharding(self, mesh_shape, pspec): ndim = len(mesh_shape) - mesh = jtu.create_global_mesh( + mesh = jtu.create_mesh( mesh_shape, ('x', 'y') if ndim == 2 else ('x', 'y', 'z')) mps = jax.sharding.NamedSharding(mesh, pspec) original_op_sharding = mps._to_xla_hlo_sharding(ndim) @@ -1071,7 +1065,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): axis_names = ('x', 'y', 'z') else: axis_names = ('x',) - mesh = jtu.create_global_mesh(mesh_shape, axis_names) + mesh = jtu.create_mesh(mesh_shape, axis_names) mps = jax.sharding.NamedSharding(mesh, pspec) shape = (8, 2, 4) mps_op_sharding = mps._to_xla_hlo_sharding(len(shape)) @@ -1086,7 +1080,7 @@ def test_is_fully_replicated_named_sharding(self, mesh_shape, pspec): def test_devices_sharding_respects_init_mesh_shape(self): value_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps = jax.sharding.NamedSharding(mesh, P('x', 'y')) devices_sharding = jax.sharding.PositionalSharding(mesh.devices) @@ -1140,14 +1134,14 @@ def test_default_pmap_sharding_with_devices(self): self.assertEqual(ps._device_assignment, new_order) def test_mesh_repr(self): - mesh = jtu.create_global_mesh((1, 1), ('x', 'y')) + mesh = jtu.create_mesh((1, 1), ('x', 'y')) mesh_repr = repr(mesh) self.assertIn('device_ids', mesh_repr) self.assertIn('axis_names', mesh_repr) def test_are_shardings_equivalent(self): - mesh = jtu.create_global_mesh((1,), ('x')) - mesh2 = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((1,), ('x')) + mesh2 = jtu.create_mesh((2, 1), ('x', 'y')) s1 = jax.sharding.NamedSharding(mesh, P('x')) s2 = jax.sharding.SingleDeviceSharding(jax.devices()[0]) @@ -1196,7 +1190,7 @@ def 
test_are_shardings_equivalent(self): def test_devices_indices_map_good_error_message(self): shape = (1, 2) - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) with self.assertRaisesRegex( ValueError, @@ -1205,7 +1199,7 @@ def test_devices_indices_map_good_error_message(self): s.devices_indices_map(shape) def test_scalar_input_wrong_pspec(self): - mesh = jtu.create_global_mesh((1, ), ('x')) + mesh = jtu.create_mesh((1, ), ('x')) shape = () s = jax.sharding.NamedSharding(mesh, P('x')) with self.assertRaisesRegex( @@ -1222,13 +1216,13 @@ def test_mesh_caching_during_construction(self): self.assertIs(mesh1, mesh2) def test_mesh_str(self): - mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) self.assertEqual(str(mesh), "Mesh('x': 2, 'y': 2, 'z': 2)") def test_make_array_from_callback_error(self): mesh_shape = (2, 3) global_shape = tuple(np.square(mesh_shape)) - mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) + mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True) pspec = P('x', 'y') sharding = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) @@ -1257,7 +1251,7 @@ def f(x): def test_make_array_from_single_device_arrays_bad_inputs(self): x = jnp.arange(10) - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) s = jax.sharding.NamedSharding(mesh, P('x')) x = jax.device_put(x, s) @@ -1268,7 +1262,7 @@ def test_make_array_from_single_device_arrays_bad_inputs(self): def test_gspmd_sharding_hash_eq(self): - mesh = jtu.create_global_mesh((1, 1, 1), ('x', 'y', 'z')) + mesh = jtu.create_mesh((1, 1, 1), ('x', 'y', 'z')) ns = NamedSharding(mesh, P('x', 'y', 'z')) x1 = GSPMDSharding(mesh._flat_devices_tuple, ns._to_xla_hlo_sharding(3)) @@ -1283,14 +1277,14 @@ def test_device_attr(self): self.assertEqual(x.device, list(x.devices())[0]) # For sharded arrays, x.device returns the sharding - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) x = jax.device_put(x, sharding) self.assertEqual(x.device, sharding) def test_to_device(self): device = jax.devices()[-1] - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) x = jnp.ones((2, 10)) @@ -1306,7 +1300,7 @@ def test_to_device(self): class ShardyShardingTest(jtu.JaxTestCase): def test_long_axis_names(self): - mesh = jtu.create_global_mesh((2, 2, 2), ('sequence', 'data', 'model')) + mesh = jtu.create_mesh((2, 2, 2), ('sequence', 'data', 'model')) s = jax.sharding.NamedSharding(mesh, P(('sequence', 'data'), 'model')) sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( @@ -1323,7 +1317,7 @@ def test_long_axis_names(self): '#sdy.sharding<@mesh, [{"sequence", "data"}, {"model"}, {}]>') def test_unconstrained(self): - mesh = jtu.create_global_mesh((8,), ('x',)) + mesh = jtu.create_mesh((8,), ('x',)) s = jax.sharding.NamedSharding(mesh, P(None, P.UNCONSTRAINED, 'x')) sdy_sharding = s._to_sdy_sharding(3) self.assertEqual( @@ -1351,7 +1345,7 @@ def f(x): 32, x.shape) return bits + x - mesh = jtu.create_global_mesh((num_devices,), ('x',)) + mesh = jtu.create_mesh((num_devices,), ('x',), iota_order=True) s = jax.sharding.NamedSharding(mesh, P('x')) n = num_devices ** 2 @@ -1387,7 +1381,7 @@ def f(x): global_shape = tuple(np.square(mesh_shape)) - mesh = jtu.create_global_mesh(mesh_shape, 
('x', 'y')) + mesh = jtu.create_mesh(mesh_shape, ('x', 'y'), iota_order=True) s = jax.sharding.NamedSharding(mesh, pspec) n = math.prod(global_shape) @@ -1409,6 +1403,19 @@ def f(x): y_ref1 = f(jax.device_put(x, jax.devices()[0])) self.assertArraysEqual(y, y_ref1) + def test_empty_mesh_creation(self): + mesh = jax.sharding.Mesh(devices=np.empty((), dtype=object), axis_names=[]) + self.assertTrue(mesh.empty) + self.assertEqual(mesh.size, 0) + + abstract_mesh = mesh.abstract_mesh + self.assertTrue(abstract_mesh.empty) + self.assertEqual(abstract_mesh.size, 0) + + abstract_mesh2 = jax.sharding.AbstractMesh(()) + self.assertTrue(abstract_mesh2.empty) + self.assertEqual(abstract_mesh2.size, 0) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/attrs_test.py b/tests/attrs_test.py index 5c834f314270..4eb354a8d50f 100644 --- a/tests/attrs_test.py +++ b/tests/attrs_test.py @@ -323,7 +323,7 @@ def f(obj, x): self.assertEqual(count, 1) def test_tracer_lifetime_bug(self): - # regression test for https://github.com/google/jax/issues/20082 + # regression test for https://github.com/jax-ml/jax/issues/20082 class StatefulRNG: key: jax.Array @@ -344,6 +344,21 @@ def jitted(): jax.jit(jitted)() # don't crash + def test_scan_carry(self): + class A: + ... + + a = A() + + jax_setattr(a, 'x', jnp.zeros(3)) + + def body(i, _): + x = jax_getattr(a, 'x') + x = x.at[i].set(x[i] + 1) + jax_setattr(a, 'x', x) + return i + 1, None + _, _ = jax.lax.scan(body, 0, None, length=3) # don't crash + class AttrsJVPTest(jtu.JaxTestCase): diff --git a/tests/batching_test.py b/tests/batching_test.py index 6cd8c7bc20ac..2b0b0d63a6f5 100644 --- a/tests/batching_test.py +++ b/tests/batching_test.py @@ -335,7 +335,7 @@ def testConcatenate(self): self.assertAllClose(ans, expected_ans, check_dtypes=False) def testJacobianIssue54(self): - # test modeling the code in https://github.com/google/jax/issues/54 + # test modeling the code in https://github.com/jax-ml/jax/issues/54 def func(xs): return jnp.array(list(xs)) @@ -345,7 +345,7 @@ def func(xs): jacfwd(func)(xs) # don't crash def testAny(self): - # test modeling the code in https://github.com/google/jax/issues/108 + # test modeling the code in https://github.com/jax-ml/jax/issues/108 ans = vmap(jnp.any)(jnp.array([[True, False], [False, False]])) expected = jnp.array([True, False]) @@ -368,7 +368,7 @@ def fun(x, t): def testDynamicSlice(self): # test dynamic_slice via numpy indexing syntax - # see https://github.com/google/jax/issues/1613 for an explanation of why we + # see https://github.com/jax-ml/jax/issues/1613 for an explanation of why we # need to use np rather than np to create x and idx x = jnp.arange(30).reshape((10, 3)) @@ -933,7 +933,7 @@ def f(scale): rtol=jtu.default_gradient_tolerance) def testIssue387(self): - # https://github.com/google/jax/issues/387 + # https://github.com/jax-ml/jax/issues/387 R = self.rng().rand(100, 2) def dist_sq(R): @@ -951,7 +951,7 @@ def f(R): @jax.legacy_prng_key('allow') def testIssue489(self): - # https://github.com/google/jax/issues/489 + # https://github.com/jax-ml/jax/issues/489 def f(key): def body_fn(uk): key = uk[1] @@ -1131,7 +1131,7 @@ def testAxisIndex(self): x - np.arange(x.shape[0], dtype='int32')) def testVmapKwargs(self): - # https://github.com/google/jax/issues/912 + # https://github.com/jax-ml/jax/issues/912 def f(a, b): return (2*a, 3*b) @@ -1242,7 +1242,7 @@ def f(x): self.assertEqual(jax.vmap(f)(jnp.ones((2, 3))).shape, (2, 3)) def testPpermuteBatcherTrivial(self): - # 
https://github.com/google/jax/issues/8688 + # https://github.com/jax-ml/jax/issues/8688 def ppermute(input): return jax.lax.ppermute(input, axis_name="i", perm=[[0, 1], [1, 0]]) @@ -1255,7 +1255,7 @@ def ppermute(input): self.assertAllClose(ans, jnp.ones(2), check_dtypes=False) def testBatchingPreservesWeakType(self): - # Regression test for https://github.com/google/jax/issues/10025 + # Regression test for https://github.com/jax-ml/jax/issues/10025 x = jnp.ravel(1) self.assertTrue(dtypes.is_weakly_typed(x)) @vmap diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py index 508dbacc2a98..00925c5f7dfc 100644 --- a/tests/cache_key_test.py +++ b/tests/cache_key_test.py @@ -14,8 +14,10 @@ import hashlib import os +import re import sys import unittest +from typing import cast as type_cast import numpy as np @@ -29,6 +31,11 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.lib import xla_client +from jax._src.lib.mlir import ir +from jax._src.mesh import Mesh +from jax._src.partition_spec import PartitionSpec as P +from jax._src.sharding_impls import NamedSharding +from jax._src.custom_partitioning import custom_partitioning config.parse_flags_with_absl() @@ -155,6 +162,49 @@ def test_different_computations(self): cache_key.get(computation2, devices, compile_options, backend), ) + def test_custom_partitioning_ptr_removal(self): + def _partition(mesh, arg_shapes, result_shape): + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec) + return mesh, jax.numpy.add, result_shardings, arg_shardings + + def _infer_sharding_from_operands(mesh, arg_shapes, result_shape): + return NamedSharding(mesh, arg_shapes[0].sharding.spec) + + @custom_partitioning + def _cp_add(x, y): + return jax.numpy.add(x, y) + + _cp_add.def_partition( + infer_sharding_from_operands=_infer_sharding_from_operands, + partition=_partition) + + devices = np.asarray(jax.devices()) + with Mesh(devices, ('x',)) as m: + computation = jax.jit( + _cp_add, + in_shardings=(NamedSharding(m, P('x')), + NamedSharding(m, P('x'))), + out_shardings=NamedSharding(m, P('x')) + ).lower( + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32), + ).compiler_ir() + pattern = ( + r'stablehlo\.custom_call @CustomSPMDPartitioning\(' + r'(.*?)\) \{' + r'(.*?backend_config\s*=\s*"([^"]*)".*?)' + r'\}' + ) + with config.remove_custom_partitioning_ptr_from_cache_key(True): + with computation.context: + updated_module = cache_key._remove_custom_partitioning_ptr( + type_cast(ir.Module, computation.operation.clone())) + bcs = [match[2] for + match in re.findall(pattern, str(updated_module), re.DOTALL)] + for bc in bcs: + self.assertEqual(bc, "REMOVED") + def test_different_device_assignment(self): computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir() devices = np.array([[jax.local_devices()[0]]]) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 730f14ddcdd1..24387a767659 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -23,6 +23,7 @@ from jax import lax from jax.experimental import checkify from jax.experimental import pjit +from jax.experimental import shard_map from jax.sharding import NamedSharding from jax._src import array from jax._src import config @@ -539,6 +540,46 @@ def g(x, y): self.assertIsNotNone(b_err.get()) self.assertStartsWith(b_err.get(), "division by zero") + @parameterized.parameters(True, False) + def 
test_shard_map(self, check_rep): + def f(x): + # unary func + return jax.lax.axis_index("dev") * x / x + + def g(x, y): + # binary func + return jax.lax.axis_index("dev") * x / y + + devices = jax.local_devices()[:8] # Taking up to 8 devices + mesh = jax.sharding.Mesh(np.array(devices), ["dev"]) + pspec = jax.sharding.PartitionSpec("dev") + ps = NamedSharding(mesh, pspec) + inp = np.tile(np.arange(4, dtype=np.int32), 2) + x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx]) + + f = shard_map.shard_map( + f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep + ) + f = jax.jit(f, in_shardings=ps, out_shardings=ps) + f = checkify.checkify(f, errors=checkify.float_checks) + g = shard_map.shard_map( + g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep + ) + g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps) + g = checkify.checkify(g, errors=checkify.float_checks) + u_err, _ = f(x) + b_err, _ = g(x, x) + + divbyzero = "division by zero" + expected_err = f"at mapped index 0: {divbyzero}" + if (next_device_with_zero := len(devices) // 2) != 0: + expected_err += f"\nat mapped index {next_device_with_zero}: {divbyzero}" + + self.assertIsNotNone(u_err.get()) + self.assertEqual(u_err.get(), expected_err) + self.assertIsNotNone(b_err.get()) + self.assertEqual(b_err.get(), expected_err) + def test_empty_enabled_errors(self): def multi_errors(x): x = x/0 # DIV @@ -815,7 +856,7 @@ def g(x): def test_retracing(self): f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2)) _ = f(3.) - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: _ = f(3.) self.assertEqual(count[0], 0) diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index fd02f59826cc..75c52822a223 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -243,6 +243,7 @@ def test_cache_write_warning(self): mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put, warnings.catch_warnings(record=True) as w, ): + warnings.simplefilter("always") mock_put.side_effect = RuntimeError("test error") self.assertEqual(f(2).item(), 4) if len(w) != 1: @@ -265,6 +266,7 @@ def test_cache_read_warning(self): mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get, warnings.catch_warnings(record=True) as w, ): + warnings.simplefilter("always") mock_get.side_effect = RuntimeError("test error") # Calling assertEqual with the jitted f will generate two PJIT # executables: Equal and the lambda function itself. diff --git a/tests/core_test.py b/tests/core_test.py index 0838702c4be6..94b7010907a9 100644 --- a/tests/core_test.py +++ b/tests/core_test.py @@ -349,7 +349,7 @@ def g_vmap(x): g_vmap(jnp.ones((1, ))) def test_concrete_array_string_representation(self): - # https://github.com/google/jax/issues/5364 + # https://github.com/jax-ml/jax/issues/5364 self.assertEqual( str(core.ConcreteArray(np.dtype(np.int32), np.array([1], dtype=np.int32))), @@ -369,7 +369,7 @@ def body(c, _): self.assertEqual(dropvar.aval, aval) def test_input_residual_forwarding(self): - # https://github.com/google/jax/pull/11151 + # https://github.com/jax-ml/jax/pull/11151 x = jnp.arange(3 * 4.).reshape(3, 4) y = jnp.arange(4 * 3.).reshape(4, 3) diff --git a/tests/cudnn_fusion_test.py b/tests/cudnn_fusion_test.py new file mode 100644 index 000000000000..e70ba12361a2 --- /dev/null +++ b/tests/cudnn_fusion_test.py @@ -0,0 +1,69 @@ +# Copyright 2024 The JAX Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest, parameterized +from unittest import SkipTest +from jax._src import test_util as jtu +import jax +import jax.numpy as jnp +from jax._src.cudnn import cudnn_fusion + + +jax.config.parse_flags_with_absl() + + +class CudnnFusionTest(jtu.JaxTestCase): + def setUp(self): + if (not jtu.test_device_matches(["cuda"]) or + not jtu.is_cuda_compute_capability_at_least("8.0")): + self.skipTest("Only works on >= sm80 GPUs") + super().setUp() + + @parameterized.parameters(["", "pmap"]) + @jtu.run_on_devices("cuda") + def test_cudnn_fusion(self, mode): + batch_size = 2 + if mode == "pmap" and jax.device_count() < batch_size: + raise SkipTest("pmap test requires 2 GPUs") + + @cudnn_fusion + def comp1(x, y, z): + return jnp.float32(jax.lax.batch_matmul(jnp.bfloat16(x), y)) + z + + k = jax.random.key(0) + s = batch_size, 16, 16 + x = jnp.int8(jax.random.normal(k, shape=s)) + y = jnp.bfloat16(jax.random.normal(k, shape=s)) + z = jnp.float32(jax.random.normal(k, shape=s)) + + fn = jax.pmap(comp1) if mode == "pmap" else comp1 + jitted = jax.jit(comp1) + lowered = jitted.lower(x, y, z) + stablehlo = lowered.as_text("stablehlo") + self.assertIn("func.func private @comp1", stablehlo) + self.assertIn("__cudnn$fusion", stablehlo) + + hlo = lowered.as_text("hlo") + self.assertIn('custom_call_target="__cudnn$fusion"', hlo) + self.assertIn("called_computations=", hlo) + + hlo_after_opt = lowered.compile().as_text() + self.assertIn("kind=kCustom", hlo_after_opt) + self.assertIn("plan_id", hlo_after_opt) + + self.assertAllClose(jitted(x, y, z), fn(x, y, z)) + + +if __name__ == '__main__': + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/custom_linear_solve_test.py b/tests/custom_linear_solve_test.py index 830526826059..857dc34d430e 100644 --- a/tests/custom_linear_solve_test.py +++ b/tests/custom_linear_solve_test.py @@ -291,7 +291,7 @@ def transpose_solve(vecmat, x): jtu.check_grads(linear_solve, (a, b), order=2, rtol=2e-3) - # regression test for https://github.com/google/jax/issues/1536 + # regression test for https://github.com/jax-ml/jax/issues/1536 jtu.check_grads(jax.jit(linear_solve), (a, b), order=2, rtol={np.float32: 2e-3}) @@ -396,7 +396,7 @@ def custom_unrolled_lower_tri_solve(mat, b): def test_custom_linear_solve_pytree_with_aux(self): # Check that lax.custom_linear_solve handles # pytree inputs + has_aux=True - # https://github.com/google/jax/pull/13093 + # https://github.com/jax-ml/jax/pull/13093 aux_orig = {'a': 1, 'b': 2} b = {'c': jnp.ones(2), 'd': jnp.ones(3)} diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index 75ff39630705..4b1182e16b5a 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -68,20 +68,17 @@ def __repr__(self): class AbstractSparseArray(core.ShapedArray): __slots__ = ['index_dtype', 'nnz', 'data_aval', 'indices_aval'] - def __init__(self, shape, dtype, index_dtype, nnz, weak_type=False, - named_shape=None): + def __init__(self, shape, 
dtype, index_dtype, nnz, weak_type=False): super().__init__(shape, dtypes.canonicalize_dtype(dtype)) - named_shape = {} if named_shape is None else named_shape self.index_dtype = index_dtype self.nnz = nnz - self.data_aval = core.ShapedArray((nnz,), dtypes.canonicalize_dtype(dtype), - weak_type, named_shape) + self.data_aval = core.ShapedArray( + (nnz,), dtypes.canonicalize_dtype(dtype), weak_type) self.indices_aval = core.ShapedArray( - (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype), - named_shape=named_shape) + (nnz, len(shape)), dtypes.canonicalize_dtype(index_dtype)) def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, - weak_type=None, named_shape=None): + weak_type=None): if shape is None: shape = self.shape if dtype is None: @@ -92,10 +89,7 @@ def update(self, shape=None, dtype=None, index_dtype=None, nnz=None, nnz = self.nnz if weak_type is None: weak_type = self.weak_type - if named_shape is None: - named_shape = self.named_shape - return AbstractSparseArray( - shape, dtype, index_dtype, nnz, weak_type, named_shape) + return AbstractSparseArray(shape, dtype, index_dtype, nnz, weak_type) def strip_weak_type(self): return self diff --git a/tests/debug_nans_test.py b/tests/debug_nans_test.py index 19e2a5893835..020c9f744833 100644 --- a/tests/debug_nans_test.py +++ b/tests/debug_nans_test.py @@ -127,7 +127,7 @@ def testPjit(self): ans.block_until_ready() def testDebugNansJitWithDonation(self): - # https://github.com/google/jax/issues/12514 + # https://github.com/jax-ml/jax/issues/12514 a = jnp.array(0.) with self.assertRaises(FloatingPointError): ans = jax.jit(lambda x: 0. / x, donate_argnums=(0,))(a) @@ -214,7 +214,7 @@ def f(x): f(1) def testDebugNansDoesntCorruptCaches(self): - # https://github.com/google/jax/issues/6614 + # https://github.com/jax-ml/jax/issues/6614 @jax.jit def f(x): return jnp.divide(x, x) diff --git a/tests/debugging_primitives_test.py b/tests/debugging_primitives_test.py index a508373b61a7..5532fdf0303f 100644 --- a/tests/debugging_primitives_test.py +++ b/tests/debugging_primitives_test.py @@ -80,6 +80,18 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "x: 2\n") + def test_static_args(self): + @jax.jit + def f(arr): + jax.debug.print("arr {array}, dtype: {dtype}, arr {array2}", + array=arr, dtype=arr.dtype, array2=arr) + arr = jnp.array([1, 2, 3], dtype=jnp.float32) + with jtu.capture_stdout() as output: + f(arr) + jax.effects_barrier() + self.assertEqual( + output(), "arr [1. 2. 3.], dtype: float32, arr [1. 2. 
3.]\n") + def test_debug_print_works_with_named_format_strings(self): def f(x): debug_print('x: {x}', x=x) @@ -106,6 +118,16 @@ def f(x): jax.effects_barrier() self.assertEqual(output(), "x: 2\n") + def test_can_stage_out_debug_print_with_formatting(self): + @jax.jit + def f(x): + debug_print('x: {x:.2f}', x=x) + + with jtu.capture_stdout() as output: + f(2) + jax.effects_barrier() + self.assertEqual(output(), "x: 2.00\n") + @jtu.device_supports_buffer_donation() def test_can_stage_out_debug_print_with_donate_argnums(self): def f(x, y): @@ -1098,7 +1120,7 @@ def f_(x): return jnp.square(x) f = jax.jit(f_) - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s) @@ -1114,7 +1136,7 @@ def f_(x): return jnp.square(x) f = pjit.pjit(f_) - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) arr = jax.device_put(np.arange(8).reshape(2, 2, 2), s) diff --git a/tests/dtypes_test.py b/tests/dtypes_test.py index c6b12e2d8a16..89d70871a8f9 100644 --- a/tests/dtypes_test.py +++ b/tests/dtypes_test.py @@ -19,6 +19,7 @@ from functools import partial import itertools import operator +import types from absl.testing import absltest from absl.testing import parameterized @@ -300,16 +301,6 @@ def testIsSubdtype(self): self.assertEqual(dtypes.issubdtype(t, category), np.issubdtype(np.dtype(t).type, category)) - def testIsSubdtypeExtended(self): - self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended)) - self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic)) - self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number)) - - self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key)) - self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended)) - self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic)) - self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number)) - @parameterized.product(dtype=custom_float_dtypes) def testIsSubdtypeCustomFloats(self, dtype): for dt in [dtype, np.dtype(dtype), str(np.dtype(dtype))]: @@ -408,6 +399,34 @@ def testDefaultDtypes(self): self.assertEqual(dtypes.float_, np.float32 if precision == '32' else np.float64) self.assertEqual(dtypes.complex_, np.complex64 if precision == '32' else np.complex128) + def test_check_dtype_non_hashable(self): + # regression test for issue with checking non-hashable custom dtype + class MyDtype: + __hash__ = None + dtype = np.dtype('float32') + dtypes.check_user_dtype_supported(MyDtype()) + + def test_check_dtype_array(self): + x = jnp.arange(4) + msg = "Passing an array as a dtype argument is deprecated" + with self.assertWarnsRegex(DeprecationWarning, msg): + dtypes.check_user_dtype_supported(x) + with self.assertWarnsRegex(DeprecationWarning, msg): + jax.jit(dtypes.check_user_dtype_supported)(x) + + +class ExtendedDTypeTest(jtu.JaxTestCase): + + def testIsSubdtypeExtended(self): + self.assertTrue(dtypes.issubdtype(dtypes.extended, dtypes.extended)) + self.assertTrue(dtypes.issubdtype(dtypes.extended, np.generic)) + self.assertFalse(dtypes.issubdtype(dtypes.extended, np.number)) + + self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.prng_key)) + self.assertTrue(jnp.issubdtype(dtypes.prng_key, dtypes.extended)) + self.assertTrue(jnp.issubdtype(dtypes.prng_key, np.generic)) + self.assertFalse(dtypes.issubdtype(dtypes.prng_key, np.number)) + def 
test_custom_tangent_dtype(self): from jax._src import core @@ -415,6 +434,8 @@ class scale(dtypes.extended): pass class ScalesTyRules: + allow_conversion: bool = True + @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), dtype.float_dtype) @@ -435,14 +456,6 @@ def zero(dt): else dtypes.finfo(dt.float_dtype).min, dt.float_dtype) return jax.lax.convert_element_type(neginf, dt) - @staticmethod - def convert_from(dtype, other_dtype) -> bool: - return dtype.float_dtype == other_dtype - - @staticmethod - def convert_to(other_dtype, dtype) -> bool: - return dtype.float_dtype == other_dtype - @dataclasses.dataclass(frozen=True) class ScaleTy(dtypes.ExtendedDType): float_dtype: dtypes.DType @@ -485,19 +498,13 @@ def test_custom_tangent_dtype_with_scan(self): from jax._src import core class ScalesTyRules: - # tell JAX how to lower this dtype to an HLO dtype + # tell JAX how to lower this dtype to an HLO representation dtype @staticmethod def physical_element_aval(dtype) -> core.ShapedArray: return core.ShapedArray((), dtype.float_dtype) - # allow conversions to and from the corresponding float type - @staticmethod - def convert_from(scale_dtype, other_dtype) -> bool: - return scale_dtype.float_dtype == other_dtype - - @staticmethod - def convert_to(other_dtype, scale_dtype) -> bool: - return scale_dtype.float_dtype == other_dtype + # allow conversions to and from the corresponding representation type + allow_conversion: bool = True # define how autodiff should accumulate these values @staticmethod @@ -563,20 +570,150 @@ def inner_bwd(prev_scale, grads): _, new_scale = jax.jit(jax.grad(outer, (0, 1)))(jnp.float32(3.14), scale) self.assertAllClose(new_scale, jnp.float32(1.0)) - def test_check_dtype_non_hashable(self): - # regression test for issue with checking non-hashable custom dtype - class MyDtype: - __hash__ = None - dtype = np.dtype('float32') - dtypes.check_user_dtype_supported(MyDtype()) + @parameterized.parameters([True]) # TODO(mattjj): make jit=False work + def test_primal_tangent_dtype(self, jit): + dt = dtypes.primal_tangent_dtype(jnp.int8, jnp.bfloat16) - def test_check_dtype_array(self): - x = jnp.arange(4) - msg = "Passing an array as a dtype argument is deprecated" - with self.assertWarnsRegex(DeprecationWarning, msg): - dtypes.check_user_dtype_supported(x) - with self.assertWarnsRegex(DeprecationWarning, msg): - jax.jit(dtypes.check_user_dtype_supported)(x) + x = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10 + ).astype(jnp.int8) + g = jax.random.uniform(jax.random.key(0), (3,), minval=0, maxval=10 + ).astype(jnp.bfloat16) + + @jax.custom_gradient + def f(x): + def bwd(g): + return 2 * g, + return jnp.int8(x).astype(g.dtype) * 2 + 1, bwd + + def h(): + result, bwd = jax.vjp(f, x.astype(dt)) + bwd_result, = bwd(g) + return result, bwd_result + + if jit: + h = jax.jit(h) + + result, bwd_result = h() + self.assertEqual(result.dtype, jnp.bfloat16) + self.assertEqual(bwd_result.dtype, jnp.bfloat16) + self.assertAllClose(bwd_result, 2 * g) + self.assertEqual(repr(dt), 'PrimalTangentDType{i8/bf16}') + + @parameterized.parameters(itertools.product([(), (2,), (3, 4)], repeat=2)) + def test_edtype_conversion(self, shape_prefix, shape_suffix): + class scalar(dtypes.extended): ... 
+ + @dataclasses.dataclass(frozen=True) + class DType(dtypes.ExtendedDType): + name = 'dt' + type = scalar + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=shape_suffix, dtype='int32'), + allow_conversion=True) + dtype = DType() + + @jax.jit + def f(x): + self.assertEqual(x.shape, shape_prefix + shape_suffix) + self.assertEqual(x.dtype, jnp.dtype('int32')) + x = jax.lax.convert_element_type(x, dtype) + self.assertEqual(x.shape, shape_prefix) + self.assertEqual(x.dtype, dtype) + x = jax.lax.convert_element_type(x, 'int32') + self.assertEqual(x.shape, shape_prefix + shape_suffix) + self.assertEqual(x.dtype, jnp.dtype('int32')) + f(jnp.zeros(shape_prefix + shape_suffix, dtype='int32')) + + def test_edtype_conversion_errors(self): + class scalar(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType(dtypes.ExtendedDType): + name = 'dt' + type = scalar + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'), + allow_conversion=True) + dtype = DType() + + class scalar2(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType2(dtypes.ExtendedDType): + name = 'dt2' + type = scalar2 + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=(3,), dtype='int32'), + allow_conversion=True) + dtype2 = DType2() + + @jax.jit + def f(x): + y = jax.lax.convert_element_type(x, dtype) + with self.assertRaisesRegex(ValueError, "cannot directly"): + jax.lax.convert_element_type(y, dtype2) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(x.astype('float32'), dtype) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(x[:, :2], dtype) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(x[:, 0], dtype) + with self.assertRaisesRegex(ValueError, "can only convert"): + jax.lax.convert_element_type(y, 'float32') + f(jnp.zeros((5, 3), dtype='int32')) + + def test_edtype_conversion_autodiff(self): + + class scalar(dtypes.extended): ... + + @dataclasses.dataclass(frozen=True) + class DType(dtypes.ExtendedDType): + name = 'dt' + type = scalar + _rules = types.SimpleNamespace( + physical_element_aval= + lambda _: types.SimpleNamespace(shape=(), dtype='float32'), + tangent_dtype=lambda dtype: jnp.dtype('bfloat16'), + allow_conversion=True) + dtype = DType() + + @jax.jit + @jax.grad + def f(x): + x = jax.lax.convert_element_type(x, dtype) + + @jax.custom_jvp + def g(x): return x + @g.defjvp + def g_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + self.assertEqual(x.shape, (5,)) + self.assertEqual(x.dtype, dtype) + self.assertEqual(x_dot.shape, (5,)) + self.assertEqual(x_dot.dtype, jnp.dtype('bfloat16')) + return x, x_dot + x = g(x) + + x = jax.lax.convert_element_type(x, 'float32') + + @jax.custom_jvp + def h(x): return x + @h.defjvp + def h_jvp(primals, tangents): + (x,), (x_dot,) = primals, tangents + self.assertEqual(x.shape, (5,)) + self.assertEqual(x.dtype, jnp.dtype('float32')) + self.assertEqual(x_dot.shape, (5,)) + self.assertEqual(x_dot.dtype, jnp.dtype('float32')) + return x, x_dot + x = h(x) + + return 0. 
+ + f(jnp.zeros(5, dtype='float32')) # test assertions in the function class EArrayTest(jtu.JaxTestCase): @@ -590,10 +727,7 @@ def test_extended_dtypes_at_rest(self, jit): class foo(dtypes.extended): pass class FooTyRules: - - @staticmethod - def convert_to(foo_dtype, target_dtype): - return True + allow_conversion: bool = True @staticmethod def physical_element_aval(foo_dtype): @@ -667,7 +801,7 @@ def testJaxTypeWeak(self, dtype): {"testcase_name": f"_{typ}", "typ": typ} for typ in [bool, int, float, complex]) def testScalarWeakTypes(self, typ): - # Regression test for https://github.com/google/jax/issues/11377 + # Regression test for https://github.com/jax-ml/jax/issues/11377 val = typ(0) result1 = jnp.array(val) @@ -806,7 +940,7 @@ def testBinaryPromotionJitInvariance(self, xtype, ytype, xfun, yfun): for weak_type in [True, False] ) def testUnaryPromotion(self, dtype, weak_type): - # Regression test for https://github.com/google/jax/issues/6051 + # Regression test for https://github.com/jax-ml/jax/issues/6051 if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) @@ -852,7 +986,7 @@ def testBinaryNonPromotion(self, dtype, weak_type, promotion): self.skipTest("XLA support for float8 is incomplete.") if dtype in intn_dtypes: self.skipTest("XLA support for int2 and int4 is incomplete.") - # Regression test for https://github.com/google/jax/issues/6051 + # Regression test for https://github.com/jax-ml/jax/issues/6051 x = lax_internal._convert_element_type(0, dtype, weak_type=weak_type) with jax.numpy_dtype_promotion(promotion): y = (x + x) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 101dddccb7c1..f6625e86ca14 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -621,7 +621,7 @@ def test_flattening_basic(self): self.assertLessEqual(len(jaxpr.jaxpr.eqns), 3) def test_shape_validation(self): - # Regression test for https://github.com/google/jax/issues/18937 + # Regression test for https://github.com/jax-ml/jax/issues/18937 msg = r"Shapes must be 1D sequences of integer scalars, got .+" with self.assertRaisesRegex(TypeError, msg): jax.make_jaxpr(jnp.ones)(5.0) @@ -1486,6 +1486,9 @@ def f(i): jax_traceback_filtering='off') class JumbleTest(jtu.JaxTestCase): + def setUp(self): + if jax.config.x64_enabled: raise unittest.SkipTest() + @parameterized.parameters((True,), (False,)) def test_internal_jumble(self, disable_jit): with jax.disable_jit(disable_jit): diff --git a/tests/errors_test.py b/tests/errors_test.py index fa2dec95f0fa..7dfc4e51a6de 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -394,11 +394,19 @@ def test_grad_norm(self): class CustomErrorsTest(jtu.JaxTestCase): + @jtu.sample_product( - errorclass=[ - errorclass for errorclass in dir(jax.errors) - if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError'] - ], + errorclass=[ + errorclass + for errorclass in dir(jax.errors) + if errorclass.endswith('Error') + and errorclass + not in [ + 'JaxIndexError', + 'JAXTypeError', + 'JaxRuntimeError', + ] + ], ) def testErrorsURL(self, errorclass): class FakeTracer(core.Tracer): diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 045f2e233465..103357ac18ac 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -44,6 +44,8 @@ from jax._src.internal_test_util.export_back_compat_test_data import cpu_svd_lapack_gesdd from 
jax._src.internal_test_util.export_back_compat_test_data import cpu_triangular_solve_blas_trsm from jax._src.internal_test_util.export_back_compat_test_data import cuda_threefry2x32 +from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_pivots_to_permutation +from jax._src.internal_test_util.export_back_compat_test_data import cuda_lu_cusolver_getrf from jax._src.internal_test_util.export_back_compat_test_data import tpu_Eigh from jax._src.internal_test_util.export_back_compat_test_data import tpu_Lu from jax._src.internal_test_util.export_back_compat_test_data import tpu_ApproxTopK @@ -64,7 +66,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -112,7 +113,11 @@ def test_custom_call_coverage(self): targets_to_cover = set(_export._CUSTOM_CALL_TARGETS_GUARANTEED_STABLE) cpu_ffi_testdatas = [ cpu_cholesky_lapack_potrf.data_2024_05_31, + cpu_qr_lapack_geqrf.data_2024_08_22, + cpu_eig_lapack_geev.data_2024_08_19, + cpu_eigh_lapack_syev.data_2024_08_19, cpu_lu_lapack_getrf.data_2024_05_31, + cpu_svd_lapack_gesdd.data_2024_08_13, ] # Add here all the testdatas that should cover the targets guaranteed # stable @@ -124,7 +129,10 @@ def test_custom_call_coverage(self): cpu_qr_lapack_geqrf.data_2023_03_17, cuda_threefry2x32.data_2023_03_15, cuda_threefry2x32.data_2024_07_30, cpu_lu_lapack_getrf.data_2023_06_14, - cuda_qr_cusolver_geqrf.data_2023_03_18, cuda_eigh_cusolver_syev.data_2023_03_17, + cuda_lu_pivots_to_permutation.data_2024_08_08, + cuda_lu_cusolver_getrf.data_2024_08_19, + cuda_qr_cusolver_geqrf.data_2023_03_18, + cuda_eigh_cusolver_syev.data_2023_03_17, rocm_qr_hipsolver_geqrf.data_2024_08_05, rocm_eigh_hipsolver_syev.data_2024_08_05, cpu_schur_lapack_gees.data_2023_07_16, @@ -181,14 +189,11 @@ def test_cpu_cholesky_lapack_potrf(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2023_06_19[dtype_name]) - # TODO(b/344892332): Remove the check after the compatibility period. 
- has_xla_ffi_support = jaxlib_version >= (0, 4, 31) self.run_one_test(func, data, rtol=rtol, atol=atol) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_cholesky_lapack_potrf.data_2024_05_31[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}", dtype_name=dtype_name) @@ -249,6 +254,11 @@ def check_eigenvalue_is_in_array(eigenvalue, eigenvalues_array): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=check_eig_results) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_eig_lapack_geev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=check_eig_results) @staticmethod def eigh_input(shape, dtype): @@ -299,6 +309,11 @@ def test_cpu_eigh_lapack_syevd(self, dtype_name="f32"): atol = dict(f32=1e-4, f64=1e-12, c64=1e-4, c128=1e-12)[dtype_name] self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_eigh_results, operand)) + # FFI Kernel test + with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cpu_eigh_lapack_syev.data_2024_08_19[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_eigh_results, operand)) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{variant}", @@ -342,6 +357,33 @@ def test_tpu_Eigh(self): self.run_one_test(func, data, rtol=1e-3, check_results=partial(self.check_eigh_results, operand)) + @staticmethod + def lu_pivots_to_permutation_harness(shape): + operand = jnp.reshape(jnp.arange(math.prod(shape), dtype=np.int32), shape) + return lax.linalg.lu_pivots_to_permutation(operand, permutation_size=8) + + def test_cuda_lu_pivots_to_permutation(self): + shape = (2, 3, 4) + func = lambda: CompatTest.lu_pivots_to_permutation_harness(shape) + data = self.load_testdata(cuda_lu_pivots_to_permutation.data_2024_08_08) + self.run_one_test(func, data) + + @parameterized.named_parameters( + dict(testcase_name=f"_dtype={dtype_name}", + dtype_name=dtype_name) + for dtype_name in ("f32", "f64", "c64", "c128")) + def test_cuda_lu_lapack_getrf(self, dtype_name:str): + if not config.enable_x64.value and dtype_name in ["f64", "c128"]: + self.skipTest("Test disabled for x32 mode") + dtype = dict(f32=np.float32, f64=np.float64, + c64=np.complex64, c128=np.complex128)[dtype_name] + shape = (3, 4) + func = lambda: CompatTest.lu_harness(shape, dtype) + # TODO(b/360788062): Clean up after the compatibility period. 
+ with config.export_ignore_forward_compatibility(True): + data = self.load_testdata(cuda_lu_cusolver_getrf.data_2024_08_19[dtype_name]) + self.run_one_test(func, data) + @staticmethod def qr_harness(shape, dtype): # In order to keep inputs small, we construct the input programmatically @@ -362,6 +404,12 @@ def test_cpu_qr_lapack_geqrf(self, dtype_name="f32"): data = self.load_testdata(cpu_qr_lapack_geqrf.data_2023_03_17[dtype_name]) rtol = dict(f32=1e-3, f64=1e-5, c64=1e-3, c128=1e-5)[dtype_name] self.run_one_test(func, data, rtol=rtol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_qr_lapack_geqrf.data_2024_08_22[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol) @parameterized.named_parameters( dict(testcase_name=f"_dtype={dtype_name}_{batched}", @@ -439,14 +487,11 @@ def test_cpu_lu_lapack_getrf(self, dtype_name:str): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_lu_results, operand, dtype=dtype)) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 32) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) - self.run_one_test(func, data, rtol=rtol, atol=atol, - check_results=partial(self.check_lu_results, operand, + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata(cpu_lu_lapack_getrf.data_2024_05_31[dtype_name]) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_lu_results, operand, dtype=dtype)) def check_svd_results(self, input, res_run, res_exp, @@ -566,6 +611,13 @@ def func(input): self.run_one_test(func, data, rtol=rtol, atol=atol, check_results=partial(self.check_svd_results, input)) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_svd_lapack_gesdd.data_2024_08_13[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol, + check_results=partial(self.check_svd_results, input)) @jtu.parameterized_filterable( kwargs=[ diff --git a/tests/export_harnesses_multi_platform_test.py b/tests/export_harnesses_multi_platform_test.py index 21ad29c7a4c9..0f0c20fd78e3 100644 --- a/tests/export_harnesses_multi_platform_test.py +++ b/tests/export_harnesses_multi_platform_test.py @@ -100,6 +100,10 @@ def setUpClass(cls): ) @jtu.skip_on_flag("jax_skip_slow_tests", True) def test_prim(self, harness: test_harnesses.Harness): + if "eigh_" in harness.fullname: + self.skipTest("Eigenvalues are sorted and it is not correct to compare " + "decompositions for equality.") + if (jtu.device_under_test() == "gpu" and _known_failures_gpu.search(harness.fullname)): self.skipTest("failure to be investigated") diff --git a/tests/export_test.py b/tests/export_test.py index b269aef28d79..0d946d84d22b 100644 --- a/tests/export_test.py +++ b/tests/export_test.py @@ -473,7 +473,8 @@ def f(xi, xf): # Native JAX 1st order vjp (f_outi, f_outf), f_vjp = jax.vjp(f, xi, xf) - f_outi_ct = np.ones(f_outi.shape, dtype=f_outi.dtype) + f_outi_ct = np.ones(f_outi.shape, + dtype=core.primal_dtype_to_tangent_dtype(f_outi.dtype)) f_outf_ct = np.ones(f_outf.shape, dtype=f_outf.dtype) xi_ct, xf_ct = f_vjp((f_outi_ct, f_outf_ct)) @@ -1333,7 +1334,7 @@ def f_jax(x): # x: f32[10,20] -> f32[20,10] def test_grad_sharding_different_mesh(self): # Export and serialize with two 
similar meshes, the only difference being # the order of the devices. grad and serialization should not fail. - # https://github.com/google/jax/issues/21314 + # https://github.com/jax-ml/jax/issues/21314 def f(x): return jnp.sum(x * 2.) diff --git a/tests/extend_test.py b/tests/extend_test.py index f34d40cd3556..fff3314a7656 100644 --- a/tests/extend_test.py +++ b/tests/extend_test.py @@ -15,7 +15,8 @@ import os import numpy as np -from absl.testing import absltest, parameterized +from absl.testing import absltest +from absl.testing import parameterized import jax from jax import lax @@ -30,8 +31,8 @@ from jax._src import test_util as jtu from jax._src import xla_bridge from jax._src.interpreters import mlir -from jax._src.lib.mlir import ir -from jax._src.extend import ffi +from jax._src.layout import DeviceLocalLayout +from jax._src.lib.mlir.dialects import hlo jax.config.parse_flags_with_absl() @@ -97,30 +98,85 @@ def no_rule(*args, **kwargs): class FfiTest(jtu.JaxTestCase): + def find_custom_call_in_module(self, module): + for func in module.body.operations: + for block in func.body.blocks: + for op in block.operations: + if op.OPERATION_NAME == "stablehlo.custom_call": + return op + self.fail("No custom_call found in the lowered IR") + def testHeadersExist(self): base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api") for header in ["c_api.h", "api.h", "ffi.h"]: self.assertTrue(os.path.exists(os.path.join(base_dir, header))) - @parameterized.parameters( - [True, int(1), float(5.0), - np.int32(-5), np.float32(0.5)]) - def testIrAttribute(self, value): - with mlir.make_ir_context(), ir.Location.unknown(): - const = mlir.ir_constant(value) - attr = ffi._ir_attribute(value) - assert const.type.element_type == attr.type - - @parameterized.parameters([True, 1, 5.0, "param", np.float32(0.5)]) - def testParams(self, param): + @parameterized.parameters([ + (tuple(range(3)), tuple(range(3))), + (None, tuple(reversed(range(3)))), + (DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))), + ]) + def testLoweringLayouts(self, layout_spec, expected_layout): + # Regression test to ensure that the lowering rule properly captures + # layouts. + def lowering_rule(ctx, x): + aval, = ctx.avals_in + ndim = len(aval.shape) + return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec], + result_layouts=[layout_spec])(ctx, x) prim = core.Primitive("test_ffi") - prim.def_abstract_eval(lambda *args, **kwargs: args[0]) - mlir.register_lowering(prim, jex.ffi.ffi_lowering("test_ffi")) - - # TODO(dfm): Currently testing that lowering works with different types of - # parameters, but we should probably actually check the emitted HLO. 
- func = jax.jit(lambda *args: prim.bind(*args, param=param)) - func.lower(jnp.linspace(0, 5, 10)) + prim.def_impl(lambda x: x) + prim.def_abstract_eval(lambda x: x) + mlir.register_lowering(prim, lowering_rule) + + x = jnp.ones((3,) * len(expected_layout)) + lowered = jax.jit(prim.bind).lower(x) + module = lowered.compiler_ir("stablehlo") + op = self.find_custom_call_in_module(module) + self.assertIn("operand_layouts", op.attributes) + self.assertIn("result_layouts", op.attributes) + + text = lowered.as_text() + expected = ", ".join(map(str, expected_layout)) + pattern = rf"operand_layouts = \[dense<\[{expected}\]>" + self.assertRegex(text, pattern) + pattern = rf"result_layouts = \[dense<\[{expected}\]>" + self.assertRegex(text, pattern) + + @parameterized.parameters([ + (True, mlir.ir.BoolAttr.get), + (1, mlir.i64_attr), + (5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)), + ("param", mlir.ir.StringAttr.get), + (np.float32(0.5), + lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)), + ]) + def testParams(self, param, expected_builder): + def fun(x): + return jex.ffi.ffi_call("test_ffi", x, x, param=param) + + # Here we inspect the lowered IR to test that the parameter has been + # serialized with the appropriate type. + module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo") + op = self.find_custom_call_in_module(module) + config = op.attributes["mhlo.backend_config"] + self.assertIsInstance(config, mlir.ir.DictAttr) + self.assertIn("param", config) + with mlir.make_ir_context(), mlir.ir.Location.unknown(): + expected = expected_builder(param) + self.assertEqual(type(config["param"]), type(expected)) + self.assertTrue(expected.type.isinstance(config["param"].type)) + + def testToken(self): + def fun(): + token = lax.create_token() + return jex.ffi.ffi_call("test_ffi", core.abstract_token, token) + + # Ensure that token inputs and outputs are translated to the correct type + module = jax.jit(fun).lower().compiler_ir("stablehlo") + op = self.find_custom_call_in_module(module) + self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type)) + self.assertTrue(hlo.TokenType.isinstance(op.results[0].type)) @jtu.sample_product( shape=[(1,), (4,), (5,)], @@ -166,6 +222,7 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True) dtype=pivots.dtype, ), pivots, + # TODO(b/358275922): Remove this after jaxlib v0.4.32 is released. 
permutation_size=np.int32(permutation_size), vectorized=vectorized, ) diff --git a/tests/fft_test.py b/tests/fft_test.py index 05fa96a93fae..a87b7b66e150 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -175,7 +175,7 @@ def testFftn(self, inverse, real, shape, dtype, axes, s, norm): self.assertEqual(dtype, expected_dtype) def testIrfftTranspose(self): - # regression test for https://github.com/google/jax/issues/6223 + # regression test for https://github.com/jax-ml/jax/issues/6223 def build_matrix(linear_func, size): return jax.vmap(linear_func)(jnp.eye(size, size)) diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index e75e8e7d735f..f34b8211eb33 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -419,7 +419,7 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(jnp.bfloat16(0))(lax.sqrt) # CHECK-LABEL: TEST: tan float16[] - # CHECK: chlo.tan + # CHECK: hlo.tan # CHECK-SAME: tensor print_ir(np.float16(0))(lax.tan) diff --git a/tests/filecheck/subcomputations.filecheck.py b/tests/filecheck/subcomputations.filecheck.py index 1f8e9d32e5b1..b3c3191ca416 100644 --- a/tests/filecheck/subcomputations.filecheck.py +++ b/tests/filecheck/subcomputations.filecheck.py @@ -19,7 +19,6 @@ from absl import app import jax -from jax import numpy as jnp from jax.interpreters import mlir from jax._src.lib.mlir import ir import numpy as np @@ -39,7 +38,7 @@ def main(_): # CHECK-NOT: func private @cumsum @print_ir(np.empty([2, 7], np.int32), np.empty([2, 7], np.int32)) def cumsum_only_once(x, y): - return jnp.cumsum(x) + jnp.cumsum(y) + return jax.lax.cumsum(x) + jax.lax.cumsum(y) # Test merging modules # CHECK-LABEL: TEST: merge_modules diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index b79c233e6f2e..438ba55203a9 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -319,8 +319,10 @@ def f(a, b): _, f_lin = jax.linearize(f, a, b) expected_tangents = f_lin(a, b) _, actual_tangents = jax.jvp(f, (a, b), (a, b)) - np.testing.assert_allclose(actual_tangents[0], expected_tangents[0]) - np.testing.assert_allclose(actual_tangents[1], expected_tangents[1]) + np.testing.assert_allclose(actual_tangents[0], expected_tangents[0], + rtol=1e-6, atol=1e-6) + np.testing.assert_allclose(actual_tangents[1], expected_tangents[1], + rtol=1e-6, atol=1e-6) def body2(_, refs): # Here we use `i_ref` as a loop counter @@ -343,7 +345,8 @@ def g(a, b): expected_tangents = g_lin(a, b) _, actual_tangents = jax.jvp(g, (a, b), (a, b)) np.testing.assert_allclose(actual_tangents[0], expected_tangents[0]) - np.testing.assert_allclose(actual_tangents[1], expected_tangents[1]) + np.testing.assert_allclose(actual_tangents[1], expected_tangents[1], + rtol=1e-6) @jtu.sample_product( [dict(for_body_name=for_body_name, f=for_body, ref=ref, diff --git a/tests/fused_attention_stablehlo_test.py b/tests/fused_attention_stablehlo_test.py index bc05c4b2e85c..2cfcfa7c5ec6 100644 --- a/tests/fused_attention_stablehlo_test.py +++ b/tests/fused_attention_stablehlo_test.py @@ -47,7 +47,8 @@ def sdpa_train(query: Array, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, is_bnth: bool = False, - dropout_rate: float = 0.1) -> Array: + dropout_rate: float = 0.1, + sliding_window_length: int | None = None) -> Array: if mask_type == MaskType.PADDING: if is_bnth: B, _, S, _ = query.shape @@ -59,7 +60,8 @@ def sdpa_train(query: Array, out, sdpa_vjp = jax.vjp( partial(dot_product_attention, scale=scale, mask_type=mask_type, 
dropout_rate=dropout_rate, - qkv_layout="BNTH" if is_bnth else "BTNH"), + qkv_layout="BNTH" if is_bnth else "BTNH", + sliding_window_length=sliding_window_length), query, key, value, bias, mask, q_seqlen, kv_seqlen) query_grad, key_grad, value_grad, bias_grad, _, _, _ = sdpa_vjp(grad) if bias is not None and len(bias.shape) == 3: @@ -74,7 +76,8 @@ def sdpa_ref(query: Array, mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, - dropout_rate: float = 0.1) -> Array: + dropout_rate: float = 0.1, + sliding_window_length: int | None = None) -> Array: def get_causal_mask(logits): large_negative_number = get_large_negative_number(logits.dtype) @@ -99,6 +102,16 @@ def get_encoded_padding_mask(encoded): return jax.lax.broadcast_in_dim( encoded_padding, encoded.shape, broadcast_dimensions=[1]) + def get_sliding_window_mask(logits, window_length): + large_negative_number = get_large_negative_number(logits.dtype) + T = logits.shape[-2] + col_idx = jax.lax.broadcasted_iota(np.int32, (T, T), 1) + row_idx = jax.lax.broadcasted_iota(np.int32, (T, T), 0) + mask = jnp.logical_or( + row_idx < col_idx, + col_idx <= row_idx - window_length).astype(logits.dtype) * large_negative_number + return mask[(*([jnp.newaxis]*(len(logits.shape) - 2)), ...)] + B, T, qN, H = query.shape _, _, kN, _ = key.shape logits = jnp.einsum("bqhd,bkhd->bhqk", query, key) @@ -108,6 +121,11 @@ def get_encoded_padding_mask(encoded): bias = get_causal_mask(logits) elif mask_type == MaskType.PADDING: bias = get_padding_mask(logits) + elif sliding_window_length is not None: + if sliding_window_length <= 0: + raise ValueError( + f"Expect sliding_window_length > 0, got {sliding_window_length}.") + bias = get_sliding_window_mask(logits, sliding_window_length) if mask is not None: large_negative_number = get_large_negative_number(logits.dtype) mask = jnp.where(mask, jnp.asarray(0, query.dtype), large_negative_number) @@ -141,10 +159,12 @@ def sdpa_train_ref(query: Array, mask: Array | None = None, scale: float = 0.5, mask_type: MaskType = MaskType.NO_MASK, - dropout_rate: float = 0.1) -> Array: + dropout_rate: float = 0.1, + sliding_window_length: int | None = None) -> Array: out_ref, sdpa_vjp_ref = jax.vjp( partial( - sdpa_ref, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate), + sdpa_ref, scale=scale, mask_type=mask_type, dropout_rate=dropout_rate, + sliding_window_length=sliding_window_length), query, key, value, bias, mask) query_grad_ref, key_grad_ref, value_grad_ref, bias_grad_ref, _ = sdpa_vjp_ref(grad) if bias is not None and len(bias.shape) == 3: @@ -399,6 +419,78 @@ def test_sdpa_broadcast_bias_and_dbias(self): self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) self.assertArraysAllClose(bias_grad_ref, bias_grad, rtol=1e-5, atol=1e-5) + @jtu.sample_product( + batch_size=[1, 16], + ) + @jtu.run_on_devices("cuda") + def test_sdpa_dbias(self, batch_size: int): + # cuDNN only supports dbias when batch size is 1. If the batch size is + # greater, dbias is silently set to all zeros. This test verifies this + # behavior for both vmap and regular use cases. + # TODO: Remove this test once cuDNN adds broader dbias support. 
+ dtype = jnp.bfloat16 + x_shape = (batch_size, 512, 16, 48) + bias_shape = (batch_size, 16, 512, 512) + mask_shape = (1, 1, 512) + + keys = jax.random.split(jax.random.key(0), 2) + x = jax.random.normal(keys[0], x_shape, dtype=dtype) + bias = jax.random.normal(keys[1], bias_shape, dtype=dtype) + mask = jnp.ones(mask_shape, dtype=jnp.bool_) + + def attn(x, bias, mask): + return dot_product_attention(x, x, x, bias, mask) + + def attn_vjp(x, bias, mask, target_fn): + _, f_vjp = jax.vjp(target_fn, x, bias, mask) + return f_vjp(x) + + attn_vmap = jax.vmap(attn, in_axes=(0, 0, None)) + attn_ref = jax.jit(partial(attn_vjp, target_fn=attn)) + attn_ans = jax.jit(partial(attn_vjp, target_fn=attn_vmap)) + + _, dbias_ref, _ = attn_ref(x, bias, mask) + x = jnp.expand_dims(x, axis=1) + bias = jnp.expand_dims(bias, axis=1) + _, dbias_ans, _ = attn_ans(x, bias, mask) + dbias_ans = jnp.squeeze(dbias_ans, axis=1) + self.assertArraysAllClose(dbias_ans, dbias_ref) + if batch_size != 1: + self.assertTrue(not jnp.any(dbias_ans)) + + @jtu.run_on_devices("cuda") + def test_sdpa_sliding_window_length(self): + k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4) + query = jax.random.normal( + k1, (4, 1024, 4, 64), dtype=jnp.bfloat16) + key = jax.random.normal( + k2, (4, 1024, 4, 64), dtype=jnp.bfloat16) + value = jax.random.normal( + k3, (4, 1024, 4, 64), dtype=jnp.bfloat16) + grad = jax.random.normal( + k4, (4, 1024, 4, 64), dtype=jnp.bfloat16) + jitted_sdpa_train = jax.jit( + partial( + sdpa_train, scale=1.0, mask_type=MaskType.CAUSAL, dropout_rate=0, + sliding_window_length=64), + ) + # for reference implementation + # sliding_window_length option itself will setup correct mask + jitted_sdpa_train_ref = jax.jit( + partial( + sdpa_train_ref, scale=1.0, mask_type=MaskType.NO_MASK, dropout_rate=0, + sliding_window_length=64), + ) + + out, (query_grad, key_grad, value_grad) = \ + jitted_sdpa_train(query, key, value, grad, None, None) + out_ref, (query_grad_ref, key_grad_ref, value_grad_ref) = \ + jitted_sdpa_train_ref(query, key, value, grad, None, None) + self.assertArraysAllClose(out_ref, out, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(query_grad_ref, query_grad, rtol=1e-2, atol=1e-2) + self.assertArraysAllClose(key_grad_ref, key_grad, rtol=1e-5, atol=1e-5) + self.assertArraysAllClose(value_grad_ref, value_grad, rtol=1e-5, atol=1e-5) + @jtu.run_on_devices("cuda") def test_layouts(self): dtype = "bfloat16" @@ -429,18 +521,27 @@ def _cvt_back(x): def test_sdpa_utils(self): test_cases = [ - (1, 257, 64, 8905, False, True), - (1, 1024, 64, 8905, False, False), - (1024, 1024, 64, 8905, False, False), - (1024, 1024, 128, 8905, False, False), + (1, 257, 64, 8905, False, True, True), + (1, 1024, 64, 8905, False, False, True), + (1024, 1024, 64, 8905, False, False, True), + (1024, 1024, 128, 8905, False, False, True), + (1024, 1024, 127, 8905, False, False, False), ] for k in test_cases: - sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training = k + sql_q, sql_v, head_dim, cudnn_version, has_bias, is_training, \ + expected_pass = k query = jnp.empty((4, sql_q, 4, head_dim)) key = jnp.empty((4, sql_v, 4, head_dim)) - check_is_flash_attention( - query, key, AttentionLayout.BNTH, cudnn_version, has_bias, is_training) + if expected_pass: + check_is_flash_attention( + query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, + is_training) + else: + with self.assertRaises(NotImplementedError): + check_is_flash_attention( + query, key, AttentionLayout.BNTH.value, cudnn_version, has_bias, + is_training) if 
__name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 5988b2774408..837d205fbbed 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -1035,7 +1035,7 @@ def func(x, yint): ( 5.00 2 )""", testing_stream.output) def test_tap_grad_float0_result(self): - # https://github.com/google/jax/issues/7340 + # https://github.com/jax-ml/jax/issues/7340 # x is a Tuple[f32[2], s32[3]] x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) @@ -1058,7 +1058,7 @@ def f_jax_vjp(x): ( [0.70 0.80] [11 12 13] )""", testing_stream.output) def test_tap_higher_order_grad_float0_result(self): - # https://github.com/google/jax/issues/7340 + # https://github.com/jax-ml/jax/issues/7340 # x is a Tuple[f32[2], s32[3]] x = (np.array([.7, .8], dtype=np.float32), np.array([11, 12, 13], dtype=np.int32)) @@ -1935,7 +1935,7 @@ def func(x, transforms, y): hcb.id_tap(func, 1, y=2) def test_tap_id_tap_random_key(self): - # See https://github.com/google/jax/issues/13949 + # See https://github.com/jax-ml/jax/issues/13949 with jax.enable_custom_prng(): @jax.jit def f(x): @@ -2240,7 +2240,7 @@ def f_outside(arg): def test_call_cond(self): def f_outside(args): x, y = args - return x * y + return x * y.astype(np.float32) def loop(x, use_outside=True): def body(i, acc): @@ -2253,8 +2253,8 @@ def body(i, acc): return lax.fori_loop(0, 18, body, x) - res_inside = loop(1.2, use_outside=False) - self.assertAllClose(res_inside, jax.jit(loop)(1.2)) + res_inside = loop(np.float32(1.2), use_outside=False) + self.assertAllClose(res_inside, jax.jit(loop)(np.float32(1.2))) def test_call_jit_scan_call(self): def f_outside(x): diff --git a/tests/host_callback_to_tf_test.py b/tests/host_callback_to_tf_test.py index fe80c90ace68..3a36ce1296a6 100644 --- a/tests/host_callback_to_tf_test.py +++ b/tests/host_callback_to_tf_test.py @@ -176,6 +176,8 @@ def supported_only_in_legacy_mode(self): testcase_name=f"_{ad=}", ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys()) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_impl(self, ad="simple"): self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] @@ -197,6 +199,8 @@ def f_outside(x): ad=ad) for ad in CALL_TF_IMPLEMENTATIONS.keys() if ad != "none") + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_grad(self, ad="simple"): self.supported_only_in_legacy_mode() call_tf = CALL_TF_IMPLEMENTATIONS[ad] @@ -217,6 +221,8 @@ def f_outside(x): self.assertAllClose(jax.grad(f_jax)(x), grad_f, check_dtypes=False) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_grad_pytree(self): self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad @@ -246,6 +252,8 @@ def f_outside(xy): testcase_name=f"_degree=_{degree}", degree=degree) for degree in [1, 2, 3, 4]) + @jtu.ignore_warning(message="The host_callback APIs are deprecated", + category=DeprecationWarning) def test_higher_order_grad(self, degree=4): self.supported_only_in_legacy_mode() call_tf = call_tf_full_ad diff --git a/tests/image_test.py b/tests/image_test.py index f3cd56ed7622..0f6341086d19 100644 --- a/tests/image_test.py +++ b/tests/image_test.py @@ -180,7 +180,7 @@ def testResizeGradients(self, dtype, image_shape, target_shape, method, antialias=[False, True], ) def testResizeEmpty(self, dtype, image_shape, target_shape, 
method, antialias): - # Regression test for https://github.com/google/jax/issues/7586 + # Regression test for https://github.com/jax-ml/jax/issues/7586 image = np.ones(image_shape, dtype) out = jax.image.resize(image, shape=target_shape, method=method, antialias=antialias) self.assertArraysEqual(out, jnp.zeros(target_shape, dtype)) diff --git a/tests/jet_test.py b/tests/jet_test.py index b1e2ef3f8380..4e437c044426 100644 --- a/tests/jet_test.py +++ b/tests/jet_test.py @@ -404,7 +404,7 @@ def g(x): self.assertArraysEqual(g_out_series, f_out_series) def test_add_any(self): - # https://github.com/google/jax/issues/5217 + # https://github.com/jax-ml/jax/issues/5217 f = lambda x, eps: x * eps + eps + x def g(eps): x = jnp.array(1.) @@ -412,7 +412,7 @@ def g(eps): jet(g, (1.,), ([1.],)) # doesn't crash def test_scatter_add(self): - # very basic test from https://github.com/google/jax/issues/5365 + # very basic test from https://github.com/jax-ml/jax/issues/5365 def f(x): x0 = x[0] x1 = x[1] diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index ab3a183177f6..78d90cb8a072 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -424,7 +424,7 @@ def testDotGeneralContractAndBatchGrads(self, lhs_shape, rhs_shape, dtype, assert "Precision.HIGHEST" in s def testDotPreferredElementType(self): - # https://github.com/google/jax/issues/10818 + # https://github.com/jax-ml/jax/issues/10818 x = jax.numpy.ones((), jax.numpy.float16) def f(x): return jax.lax.dot_general(x, x, (((), ()), ((), ())), @@ -513,7 +513,7 @@ def testReverseGrad(self): rtol={np.float32: 3e-3}) def testPowSecondDerivative(self): - # https://github.com/google/jax/issues/12033 + # https://github.com/jax-ml/jax/issues/12033 x, y = 4.0, 0.0 expected = ((0.0, 1/x), (1/x, np.log(x) ** 2)) @@ -528,18 +528,18 @@ def testPowSecondDerivative(self): with self.subTest("zero to the zero"): result = jax.grad(lax.pow)(0.0, 0.0) # TODO(jakevdp) special-case zero in a way that doesn't break other cases - # See https://github.com/google/jax/pull/12041#issuecomment-1222766191 + # See https://github.com/jax-ml/jax/pull/12041#issuecomment-1222766191 # self.assertEqual(result, 0.0) self.assertAllClose(result, np.nan) def testPowIntPowerAtZero(self): - # https://github.com/google/jax/issues/14397 + # https://github.com/jax-ml/jax/issues/14397 ans = jax.grad(jax.jit(lambda x, n: x ** n))(0., 0) self.assertAllClose(ans, 0., check_dtypes=False) @jax.numpy_dtype_promotion('standard') # This test explicitly exercises mixed type promotion def testPowIntPowerAtZero2(self): - # https://github.com/google/jax/issues/17995 + # https://github.com/jax-ml/jax/issues/17995 a = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=int)) b = lambda z: jax.numpy.sum(z**jax.numpy.arange(0, 2, dtype=float)) c = lambda z: 1 + z @@ -634,7 +634,7 @@ def testDynamicUpdateSliceGrad(self, shape, dtype, start_indices, check_grads(dus, (update,), 2, ["fwd", "rev"], eps=1.) def testDynamicSliceValueAndGrad(self): - # Regression test for https://github.com/google/jax/issues/10984 + # Regression test for https://github.com/jax-ml/jax/issues/10984 # Issue arose due to an out-of-range negative index. rng = jtu.rand_default(self.rng()) shape = (5, 5) @@ -649,7 +649,7 @@ def f(x): self.assertAllClose(result1, result2) def testDynamicUpdateSliceValueAndGrad(self): - # Regression test for https://github.com/google/jax/issues/10984 + # Regression test for https://github.com/jax-ml/jax/issues/10984 # Issue arose due to an out-of-range negative index. 
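For context on what these regression tests exercise: JAX documents that lax.dynamic_slice clamps start indices so the slice stays in bounds, so an out-of-range negative start selects from the clamped position rather than erroring. A minimal sketch of that behavior:

import jax
import jax.numpy as jnp

x = jnp.arange(5.0)
# start index -10 is clamped into range, so this returns the first two elements
print(jax.lax.dynamic_slice(x, (-10,), (2,)))  # [0. 1.]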
rng = jtu.rand_default(self.rng()) shape = (5, 5) @@ -1004,7 +1004,7 @@ def testScatterGrad(self, arg_shape, dtype, idxs, update_shape, dnums, check_grads(scatter, (x, y), 2, ["fwd", "rev"], 1e-2, 1e-2, 1.) def testScatterGradSymbolicZeroUpdate(self): - # https://github.com/google/jax/issues/1901 + # https://github.com/jax-ml/jax/issues/1901 def f(x): n = x.shape[0] y = np.arange(n, dtype=x.dtype) @@ -1111,7 +1111,7 @@ def gen_y(rng, size): check_grads(lax.rem, (x, y), 2, ["fwd", "rev"]) def testHigherOrderGradientOfReciprocal(self): - # Regression test for https://github.com/google/jax/issues/3136 + # Regression test for https://github.com/jax-ml/jax/issues/3136 def inv(x): # N.B.: intentionally written as 1/x, not x ** -1 or reciprocal(x) return 1 / x @@ -1150,7 +1150,7 @@ def f(x): jax.jacrev(f)(x) def testPowShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/17294 + # Regression test for https://github.com/jax-ml/jax/issues/17294 x = lax.iota('float32', 4) y = 2 actual = jax.jacrev(jax.jit(jax.lax.pow))(x, y) # no error diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index d52862ec42ac..7fb118d47256 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -733,7 +733,7 @@ def false_fun(x): self.assertEqual(fun(4), (8, 16)) def testCondPredIsNone(self): - # see https://github.com/google/jax/issues/11574 + # see https://github.com/jax-ml/jax/issues/11574 def f(pred, x): return lax.cond(pred, lambda x: x + 1, lambda x: x + 2, x) @@ -743,7 +743,7 @@ def f(pred, x): lambda: jax.jit(f)(None, 1.)) def testCondTwoOperands(self): - # see https://github.com/google/jax/issues/8469 + # see https://github.com/jax-ml/jax/issues/8469 add, mul = lax.add, lax.mul def fun(x): @@ -775,7 +775,7 @@ def cfun(x): self.assertEqual(fun(1), cfun(1)) def testCondCallableOperands(self): - # see https://github.com/google/jax/issues/16413 + # see https://github.com/jax-ml/jax/issues/16413 @tree_util.register_pytree_node_class class Foo: @@ -1560,7 +1560,7 @@ def f(x): {"testcase_name": f"_{name}", "cond": cond} for cond, name in COND_IMPLS) def testCondVmapGrad(self, cond): - # https://github.com/google/jax/issues/2264 + # https://github.com/jax-ml/jax/issues/2264 def f_1(x): return x ** 2 def f_2(x): return x ** 3 @@ -1680,7 +1680,8 @@ def f(c, a): tol = {np.float64: 1e-12, np.float32: 1e-4} self.assertAllClose(ans, expected, check_dtypes=False, rtol=tol, atol=tol) - jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"]) + jtu.check_grads(partial(scan, f), (c, as_), order=2, modes=["fwd"], + rtol={jnp.float32: 2e-1}) @parameterized.named_parameters( {"testcase_name": f"_{jit_scan=}_{jit_f=}_impl={scan_name}", @@ -1838,7 +1839,7 @@ def loss(params, inputs, targets): def testIssue711(self, scan): # Tests reverse-mode differentiation through a scan for which the scanned # function also involves reverse-mode differentiation. 
- # See https://github.com/google/jax/issues/711 + # See https://github.com/jax-ml/jax/issues/711 def harmonic_bond(conf, params): return jnp.sum(conf * params) @@ -2077,7 +2078,7 @@ def scan_body(c, x): self.assertAllClose(carry_out[0], jnp.array([2., 2., 2.]), check_dtypes = False) def testIssue757(self): - # code from https://github.com/google/jax/issues/757 + # code from https://github.com/jax-ml/jax/issues/757 def fn(a): return jnp.cos(a) @@ -2106,7 +2107,7 @@ def testMap(self): self.assertAllClose(actual, expected) def testMapEmpty(self): - # https://github.com/google/jax/issues/2412 + # https://github.com/jax-ml/jax/issues/2412 ans = lax.map(lambda x: x * x, jnp.array([])) expected = jnp.array([]) self.assertAllClose(ans, expected) @@ -2163,7 +2164,7 @@ def body(x): lax.while_loop(cond, body, 0) def test_caches_depend_on_axis_env(self): - # https://github.com/google/jax/issues/9187 + # https://github.com/jax-ml/jax/issues/9187 scanned_f = lambda _, __: (lax.psum(1, 'i'), None) f = lambda: lax.scan(scanned_f, 0, None, length=1)[0] ans = jax.vmap(f, axis_name='i', axis_size=2, out_axes=None)() @@ -2442,7 +2443,7 @@ def f(h, _): self.assertEqual(h, length) def test_disable_jit_cond_with_vmap(self): - # https://github.com/google/jax/issues/3093 + # https://github.com/jax-ml/jax/issues/3093 def fn(t): return lax.cond(t > 0, 0, lambda x: 0, 0, lambda x: 1) fn = jax.vmap(fn) @@ -2451,14 +2452,14 @@ def fn(t): _ = fn(jnp.array([1])) # doesn't crash def test_disable_jit_while_loop_with_vmap(self): - # https://github.com/google/jax/issues/2823 + # https://github.com/jax-ml/jax/issues/2823 def trivial_while(y): return lax.while_loop(lambda x: x < 10.0, lambda x: x + 1.0, y) with jax.disable_jit(): jax.vmap(trivial_while)(jnp.array([3.0,4.0])) # doesn't crash def test_vmaps_of_while_loop(self): - # https://github.com/google/jax/issues/3164 + # https://github.com/jax-ml/jax/issues/3164 def f(x, n): return lax.fori_loop(0, n, lambda _, x: x + 1, x) x, n = jnp.arange(3), jnp.arange(4) jax.vmap(jax.vmap(f, (None, 0)), (0, None))(x, n) # doesn't crash @@ -2566,7 +2567,7 @@ def new_jaxpr(): lambda: core.check_jaxpr(jaxpr)) def test_cond_transformation_rule_with_consts(self): - # https://github.com/google/jax/pull/9731 + # https://github.com/jax-ml/jax/pull/9731 @jax.custom_jvp def f(x): @@ -2583,14 +2584,14 @@ def f_jvp(primals, tangents): jax.jvp(g, (x,), (x,)) # doesn't crash def test_cond_excessive_compilation(self): - # Regression test for https://github.com/google/jax/issues/14058 + # Regression test for https://github.com/jax-ml/jax/issues/14058 def f(x): return x + 1 def g(x): return x + 2 - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: for x in range(10): lax.cond(x, f, g, x) # Should observe a maximum of 4 compiles: convert_element_type, f, g, cond @@ -2631,7 +2632,7 @@ def body_fun(val): ('new_remat', new_checkpoint), ]) def test_scan_vjp_forwards_extensive_residuals(self, remat): - # https://github.com/google/jax/issues/4510 + # https://github.com/jax-ml/jax/issues/4510 def cumprod(x): s = jnp.ones((2, 32), jnp.float32) return lax.scan(lambda s, x: (x*s, s), s, x) @@ -2670,7 +2671,7 @@ def scan(state, xs): (jnp.array([1.]), jnp.array([[0., 1., 2., 3., 4.]])), check_dtypes=False) def test_xla_cpu_gpu_loop_cond_bug(self): - # https://github.com/google/jax/issues/5900 + # https://github.com/jax-ml/jax/issues/5900 def deriv(f): return lambda x, *args: jax.linearize(lambda x: f(x, *args), x)[1](1.0) @@ -2749,7 +2750,7 @@ def body(c, _): 
jax.grad(f)(1.) # doesn't crash def test_custom_jvp_tangent_cond_transpose(self): - # https://github.com/google/jax/issues/14026 + # https://github.com/jax-ml/jax/issues/14026 def mask_fun(arr, choice): out = (1 - choice) * arr.sum() + choice * (1 - arr.sum()) return out @@ -2950,6 +2951,25 @@ def body(carry, x): hlo_text = fn.lower(init).as_text('hlo') self.assertNotIn('4,1,2,2', hlo_text) + def test_scan_length_concrete_error(self): + f = jax.jit(lambda n, x: jax.lax.scan(lambda c, z: (c, z), x, (), n)) + + with self.assertRaisesRegex( + core.ConcretizationTypeError, + "The `length` argument to `scan` expects a concrete `int` value.*"): + f(3, 1.) + + def test_scan_unroll_concrete_error(self): + f = jax.jit(lambda n, x: jax.lax.scan( + lambda c, z: (c, z), x, (), 10, unroll=n)) + + msg = ("The `unroll` argument to `scan` expects a concrete `int` or " + "`bool` value.*") + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(3, 1.) + with self.assertRaisesRegex(core.ConcretizationTypeError, msg): + f(True, 1.) + def test_cond_vmap_forwarding_doesnt_promote(self): def f(x, y): x, y = jax.lax.cond( @@ -2977,7 +2997,7 @@ def test_cond_casting(self): self.assertIsInstance(y, jax.Array) def test_cond_memory_leak(self): - # https://github.com/google/jax/issues/12719 + # https://github.com/jax-ml/jax/issues/12719 def leak(): data = jax.device_put(np.zeros((1024), dtype=np.float32) + 1) diff --git a/tests/lax_metal_test.py b/tests/lax_metal_test.py index 02fecb7b3f1a..d3dada0d750a 100644 --- a/tests/lax_metal_test.py +++ b/tests/lax_metal_test.py @@ -914,7 +914,7 @@ def testOperatorRound(self, jit): check_dtypes=False) def testRoundMethod(self): - # https://github.com/google/jax/issues/15190 + # https://github.com/jax-ml/jax/issues/15190 (jnp.arange(3.) / 5.).round() # doesn't crash @jtu.sample_product(shape=[(5,), (5, 2)]) @@ -1425,7 +1425,7 @@ def testIntegerPower(self, ptype): y=[0, 32, 64, 128], ) def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 + # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] self._CheckAgainstNumpy(np.power, jnp.power, args_maker) self._CompileAndCheck(jnp.power, args_maker) @@ -1536,7 +1536,7 @@ def testConcatenateArray(self, shape, dtype, axis): self._CompileAndCheck(jnp_fun, args_maker) def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 + # https://github.com/jax-ml/jax/issues/3419 a = jnp.array([[1, 2], [3, 4]]) b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) @@ -2768,7 +2768,7 @@ def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): self._CompileAndCheck(jnp_fun, args_maker) def testDiffPrepoendScalar(self): - # Regression test for https://github.com/google/jax/issues/19362 + # Regression test for https://github.com/jax-ml/jax/issues/19362 x = jnp.arange(10) result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) @@ -3359,7 +3359,7 @@ def _check(obj, out_dtype, weak_type): _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. 
- # (regression test for https://github.com/google/jax/issues/8945) + # (regression test for https://github.com/jax-ml/jax/issues/8945) _check([0, np.int16(1)], np.int16, False) _check([0.0, np.float16(1)], np.float16, False) @@ -3932,17 +3932,17 @@ def testPathologicalFloats(self): # TODO(mattjj): test other ndarray-like method overrides def testNpMean(self): - # from https://github.com/google/jax/issues/125 + # from https://github.com/jax-ml/jax/issues/125 x = jnp.eye(3, dtype=float) + 0. ans = np.mean(x) self.assertAllClose(ans, np.array(1./3), check_dtypes=False) def testArangeOnFloats(self): np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 + # from https://github.com/jax-ml/jax/issues/145 self.assertAllClose(np_arange(0.0, 1.0, 0.1), jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/google/jax/issues/3450 + # from https://github.com/jax-ml/jax/issues/3450 self.assertAllClose(np_arange(2.5), jnp.arange(2.5)) self.assertAllClose(np_arange(0., 2.5), @@ -4303,7 +4303,7 @@ def args_maker(): self._CompileAndCheck(jnp_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 + # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-1) @@ -4513,9 +4513,9 @@ def testSymmetrizeDtypePromotion(self): # NOTE(mattjj): I disabled this test when removing lax._safe_mul because # introducing the convention 0 * inf = 0 leads to silently wrong results in # some cases. See this comment for details: - # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 # def testIssue347(self): - # # https://github.com/google/jax/issues/347 + # # https://github.com/jax-ml/jax/issues/347 # def test_fail(x): # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) # ones = jnp.ones_like(x) @@ -4526,7 +4526,7 @@ def testSymmetrizeDtypePromotion(self): # assert not np.any(np.isnan(result)) def testIssue453(self): - # https://github.com/google/jax/issues/453 + # https://github.com/jax-ml/jax/issues/453 a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') @@ -4538,7 +4538,7 @@ def testIssue453(self): op=["atleast_1d", "atleast_2d", "atleast_3d"], ) def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 + # Fixes: https://github.com/jax-ml/jax/issues/634 np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [dtype(2)] @@ -5147,7 +5147,7 @@ def testDisableNumpyRankPromotionBroadcastingDecorator(self): jnp.ones(2) + 3 # don't want to warn for scalars def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 + # tests https://github.com/jax-ml/jax/issues/1271 @jax.jit def foo(x): return jnp.stack(x) @@ -5316,7 +5316,7 @@ def testGradient(self, shape, varargs, axis, dtype): self._CompileAndCheck(jnp_fun, args_maker) def testZerosShapeErrors(self): - # see https://github.com/google/jax/issues/1822 + # see https://github.com/jax-ml/jax/issues/1822 self.assertRaisesRegex( TypeError, "Shapes must be 1D sequences of concrete values of integer type.*", @@ -5334,7 +5334,7 @@ def testTraceMethod(self): self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) def 
testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 + # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. # It should be satisfied with all integers less than sqrt(2**24). x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) @@ -5405,7 +5405,7 @@ def testArange64Bit(self, dtype): self._CompileAndCheck(jnp_fun, args_maker) def testIssue2347(self): - # https://github.com/google/jax/issues/2347 + # https://github.com/jax-ml/jax/issues/2347 object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) @@ -5617,7 +5617,7 @@ def jax_metal_supported(target_ver): return False - #https://github.com/google/jax/issues/16420 + #https://github.com/jax-ml/jax/issues/16420 def test_broadcast_dim(self): x = jnp.arange(2) f = lambda x : jax.lax.broadcast_in_dim(x, (2, 2), (0,)) @@ -5640,7 +5640,7 @@ def test_triu(self): res = jnp.triu(x) jtu.check_eq(res, np.triu(x)) - #https://github.com/google/jax/issues/16471 + #https://github.com/jax-ml/jax/issues/16471 def test_matmul_1d(self): x = np.array(np.random.rand(3, 3)) y = np.array(np.random.rand(3)) @@ -5650,7 +5650,7 @@ def test_matmul_1d(self): res = jnp.dot(x, y) self.assertArraysAllClose(res, np.dot(x,y)) - #https://github.com/google/jax/issues/17175 + #https://github.com/jax-ml/jax/issues/17175 def test_indexing(self): x = jnp.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12]], dtype=jnp.float32) @jax.vmap @@ -5661,7 +5661,7 @@ def f(i): res = f(idx) jtu.check_eq(res, np.array([[4., 5., 6.], [4., 5., 6.], [7., 8., 9.], [7., 8., 9.], [1., 2., 3.]])) - #https://github.com/google/jax/issues/17344 + #https://github.com/jax-ml/jax/issues/17344 def test_take_along_axis(self): @jax.jit def f(): @@ -5672,7 +5672,7 @@ def f(): return jnp.take_along_axis(x, idx, axis=1) jtu.check_eq(f(), self.dispatchOn([], f)) - #https://github.com/google/jax/issues/17590 + #https://github.com/jax-ml/jax/issues/17590 def test_in1d(self): a = np.array([123,2,4]) b = np.array([123,1]) @@ -5688,7 +5688,7 @@ def f(x): res = f(x) jtu.check_eq(res, np.array([[1., 2., 3.], [1., 5., 6.,], [1., 8., 9.], [1., 11., 12.]])) - #https://github.com/google/jax/issues/16326 + #https://github.com/jax-ml/jax/issues/16326 def test_indexing_update2(self): @jax.jit def f(x, r): @@ -5722,7 +5722,7 @@ def test_gather_ir(self): print(res) jtu.check_eq(res, res_ref) - #https://github.com/google/jax/issues/16366 + #https://github.com/jax-ml/jax/issues/16366 def test_pad_interior_1(self): if not ReportedIssuesTests.jax_metal_supported('0.0.6'): raise unittest.SkipTest("jax-metal version doesn't support it.") diff --git a/tests/lax_numpy_einsum_test.py b/tests/lax_numpy_einsum_test.py index 7397cf3e4ee8..ea7bff1d09fc 100644 --- a/tests/lax_numpy_einsum_test.py +++ b/tests/lax_numpy_einsum_test.py @@ -89,7 +89,7 @@ def test_two_operands_5(self): self._check(s, x, y) def test_two_operands_6(self): - # based on https://github.com/google/jax/issues/37#issuecomment-448572187 + # based on https://github.com/jax-ml/jax/issues/37#issuecomment-448572187 r = self.rng() x = r.randn(2, 1) y = r.randn(2, 3, 4) diff --git a/tests/lax_numpy_indexing_test.py b/tests/lax_numpy_indexing_test.py index cbb8e92ed603..d58a5c2c3866 100644 --- a/tests/lax_numpy_indexing_test.py +++ b/tests/lax_numpy_indexing_test.py @@ -496,7 +496,7 @@ def jnp_op(x, idx): self._CompileAndCheck(jnp_op_idx, args_maker) def testIndexApplyBatchingBug(self): - # https://github.com/google/jax/issues/16655 + # 
https://github.com/jax-ml/jax/issues/16655 arr = jnp.array([[1, 2, 3, 4, 5, 6]]) ind = jnp.array([3]) func = lambda a, i: a.at[i].apply(lambda x: x - 1) @@ -505,7 +505,7 @@ def testIndexApplyBatchingBug(self): self.assertArraysEqual(out, expected) def testIndexUpdateScalarBug(self): - # https://github.com/google/jax/issues/14923 + # https://github.com/jax-ml/jax/issues/14923 a = jnp.arange(10.) out = a.at[0].apply(jnp.cos) self.assertArraysEqual(out, a.at[0].set(1)) @@ -835,7 +835,7 @@ def testBooleanIndexingArray2D(self): self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis(self): - # Regression test for https://github.com/google/jax/issues/8412 + # Regression test for https://github.com/jax-ml/jax/issues/8412 x = np.arange(24).reshape(4, 3, 2) idx = (..., np.array([True, False])) ans = jnp.array(x)[idx] @@ -843,7 +843,7 @@ def testBoolean1DIndexingWithEllipsis(self): self.assertAllClose(ans, expected, check_dtypes=False) def testBoolean1DIndexingWithEllipsis2(self): - # Regression test for https://github.com/google/jax/issues/9050 + # Regression test for https://github.com/jax-ml/jax/issues/9050 x = np.arange(3) idx = (..., np.array([True, False, True])) ans = jnp.array(x)[idx] @@ -936,7 +936,7 @@ def testSimpleIndexingUsesSlice(self): self.assertEqual(jaxpr.jaxpr.eqns[-1].primitive, lax.squeeze_p) def testTrivialGatherIsntGenerated(self): - # https://github.com/google/jax/issues/1621 + # https://github.com/jax-ml/jax/issues/1621 jaxpr = jax.make_jaxpr(lambda x: x[:, None])(np.arange(4)) self.assertEqual(len(jaxpr.jaxpr.eqns), 1) self.assertNotIn('gather', str(jaxpr)) @@ -988,14 +988,14 @@ def testBooleanIndexingWithEmptyResult(self): self.assertAllClose(ans, expected, check_dtypes=False) def testBooleanIndexingShapeMismatch(self): - # Regression test for https://github.com/google/jax/issues/7329 + # Regression test for https://github.com/jax-ml/jax/issues/7329 x = jnp.arange(4) idx = jnp.array([True, False]) with self.assertRaisesRegex(IndexError, "boolean index did not match shape.*"): x[idx] def testBooleanIndexingWithNone(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) idx = (None, jnp.array([True, False])) ans = x[idx] @@ -1003,7 +1003,7 @@ def testBooleanIndexingWithNone(self): self.assertAllClose(ans, expected) def testBooleanIndexingWithNoneAndEllipsis(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[None, ..., mask] @@ -1011,7 +1011,7 @@ def testBooleanIndexingWithNoneAndEllipsis(self): self.assertAllClose(ans, expected) def testBooleanIndexingWithEllipsisAndNone(self): - # Regression test for https://github.com/google/jax/issues/18542 + # Regression test for https://github.com/jax-ml/jax/issues/18542 x = jnp.arange(6).reshape(2, 3) mask = jnp.array([True, False, False]) ans = x[..., None, mask] @@ -1030,6 +1030,23 @@ def testNontrivialBooleanIndexing(self): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @parameterized.parameters( + [(3,), (0,)], + [(3, 4), (0,)], + [(3, 4), (0, 4)], + [(3, 4), (3, 0)], + [(3, 4, 5), (3, 0)], + ) + def testEmptyBooleanIndexing(self, x_shape, m_shape): + # Regression test for https://github.com/jax-ml/jax/issues/22886 + rng = jtu.rand_default(self.rng()) + args_maker 
= lambda: [rng(x_shape, np.int32), np.empty(m_shape, dtype=bool)] + + np_fun = lambda x, m: np.asarray(x)[np.asarray(m)] + jnp_fun = lambda x, m: jnp.asarray(x)[jnp.asarray(m)] + + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + @jtu.sample_product( shape=[(2, 3, 4, 5)], idx=[ @@ -1103,7 +1120,7 @@ def testStrIndexingError(self): with self.assertRaisesRegex(TypeError, msg): jnp.zeros((2, 3))[:, 'abc'] - def testIndexOutOfBounds(self): # https://github.com/google/jax/issues/2245 + def testIndexOutOfBounds(self): # https://github.com/jax-ml/jax/issues/2245 x = jnp.arange(5, dtype=jnp.int32) + 1 self.assertAllClose(x, x[:10]) @@ -1596,7 +1613,7 @@ def np_fun(data, segment_ids): self._CompileAndCheck(jnp_fun, args_maker) def testIndexDtypeError(self): - # https://github.com/google/jax/issues/2795 + # https://github.com/jax-ml/jax/issues/2795 jnp.array(1) # get rid of startup warning with self.assertNoWarnings(): jnp.zeros(5).at[::2].set(1) @@ -1630,13 +1647,13 @@ def testIndexSequenceDeprecation(self, idx, idx_type): x.at[normalize(idx)].set(0) def testIndexedUpdateAliasingBug(self): - # https://github.com/google/jax/issues/7461 + # https://github.com/jax-ml/jax/issues/7461 fn = lambda x: x.at[1:].set(1 + x[:-1]) y = jnp.zeros(8) self.assertArraysEqual(fn(y), jax.jit(fn)(y)) def testScatterValuesCastToTargetDType(self): - # https://github.com/google/jax/issues/15505 + # https://github.com/jax-ml/jax/issues/15505 a = jnp.zeros(1, dtype=jnp.uint32) val = 2**32 - 1 # too large for int32 diff --git a/tests/lax_numpy_operators_test.py b/tests/lax_numpy_operators_test.py index 4c31684e145f..45a780c9f721 100644 --- a/tests/lax_numpy_operators_test.py +++ b/tests/lax_numpy_operators_test.py @@ -654,7 +654,13 @@ def testShiftOpAgainstNumpy(self, op, dtypes, shapes): shift_rng = jtu.rand_int(self.rng(), high=max(info.bits, shift_info.bits)) args_maker = lambda: (x_rng(shapes[0], dtype), shift_rng(shapes[1], shift_dtype)) - np_op = getattr(np, op.__name__) + if jtu.numpy_version() < (2, 0, 0) and op.__name__ in ("bitwise_left_shift", "bitwise_right_shift"): + # numpy < 2.0.0 does not have bitwise shift functions. 
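This fallback relies on NumPy 2.0's bitwise_left_shift/bitwise_right_shift being aliases of the long-standing left_shift/right_shift ufuncs; an equivalent feature-based sketch (hasattr instead of the version compare) would be:

import numpy as np

op_name = "bitwise_left_shift"
if not hasattr(np, op_name):  # NumPy < 2.0 only ships left_shift / right_shift
  op_name = op_name.removeprefix("bitwise_")
np_op = getattr(np, op_name)
assert np.array_equal(np_op(np.array([1, 2, 3]), 1), np.array([2, 4, 6]))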
+ op_name = op.__name__.removeprefix("bitwise_") + else: + op_name = op.__name__ + + np_op = getattr(np, op_name) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CompileAndCheck(op, args_maker) @@ -691,7 +697,7 @@ def __rmul__(self, other): self.assertIsInstance(jax.jit(operator.mul)(b, a), MyArray) def testI0Grad(self): - # Regression test for https://github.com/google/jax/issues/11479 + # Regression test for https://github.com/jax-ml/jax/issues/11479 dx = jax.grad(jax.numpy.i0)(0.0) self.assertArraysEqual(dx, 0.0) diff --git a/tests/lax_numpy_reducers_test.py b/tests/lax_numpy_reducers_test.py index 4767e48c3f5e..623c11a51998 100644 --- a/tests/lax_numpy_reducers_test.py +++ b/tests/lax_numpy_reducers_test.py @@ -27,6 +27,7 @@ from jax import numpy as jnp from jax._src import config +from jax._src import deprecations from jax._src import dtypes from jax._src import test_util as jtu from jax._src.util import NumpyComplexWarning @@ -424,7 +425,7 @@ def testReducerWhere(self, name, rng_factory, shape, dtype, axis, if (shape in [()] + scalar_shapes and dtype in [jnp.int16, jnp.uint16] and jnp_op in [jnp.min, jnp.max]): - self.skipTest("Known XLA failure; see https://github.com/google/jax/issues/4971.") + self.skipTest("Known XLA failure; see https://github.com/jax-ml/jax/issues/4971.") rng = rng_factory(self.rng()) is_bf16_nan_test = dtype == jnp.bfloat16 and rng_factory.__name__ == 'rand_some_nan' # Do not pass where via args_maker as that is incompatible with _promote_like_jnp. @@ -581,7 +582,7 @@ def np_fun(x): size=[0, 1, 2] ) def testStdOrVarLargeDdofReturnsNan(self, jnp_fn, size): - # test for https://github.com/google/jax/issues/21330 + # test for https://github.com/jax-ml/jax/issues/21330 x = jnp.arange(size) self.assertTrue(np.isnan(jnp_fn(x, ddof=size))) self.assertTrue(np.isnan(jnp_fn(x, ddof=size + 1))) @@ -621,7 +622,7 @@ def np_fun(x): atol=tol) def testNanStdGrad(self): - # Regression test for https://github.com/google/jax/issues/8128 + # Regression test for https://github.com/jax-ml/jax/issues/8128 x = jnp.arange(5.0).at[0].set(jnp.nan) y = jax.grad(jnp.nanvar)(x) self.assertAllClose(y, jnp.array([0.0, -0.75, -0.25, 0.25, 0.75]), check_dtypes=False) @@ -715,17 +716,31 @@ def np_fun(*args): # TODO(phawkins): we currently set dtype=False because we aren't as # aggressive about promoting to float64. It's not clear we want to mimic # Numpy here. - tol_spec = {np.float16: 1E-2, np.float32: 2e-4, np.float64: 5e-6} + tol_spec = {np.float16: 4e-2, np.float32: 2e-4, np.float64: 5e-6} tol = max(jtu.tolerance(a_dtype, tol_spec), jtu.tolerance(q_dtype, tol_spec)) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) + @jtu.sample_product( + op=['quantile', 'nanquantile', 'percentile', 'nanpercentile'] + ) + def testQuantileDeprecatedArgs(self, op): + func = getattr(jnp, op) + msg = f"The interpolation= argument to '{op}' is deprecated. 
" + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-quantile-interpolation"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(msg): + func(jnp.arange(4), 0.5, interpolation='linear') + @unittest.skipIf(not config.enable_x64.value, "test requires X64") @jtu.run_on_devices("cpu") # test is for CPU float64 precision def testPercentilePrecision(self): - # Regression test for https://github.com/google/jax/issues/8513 + # Regression test for https://github.com/jax-ml/jax/issues/8513 x = jnp.float64([1, 2, 3, 4, 7, 10]) self.assertEqual(jnp.percentile(x, 50), 3.5) @@ -763,14 +778,14 @@ def np_fun(*args): self._CompileAndCheck(jnp_fun, args_maker, rtol=tol) def testMeanLargeArray(self): - # https://github.com/google/jax/issues/15068 + # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") x = jnp.ones((16, 32, 1280, 4096), dtype='int8') self.assertEqual(1.0, jnp.mean(x)) self.assertEqual(1.0, jnp.mean(x, where=True)) def testStdLargeArray(self): - # https://github.com/google/jax/issues/15068 + # https://github.com/jax-ml/jax/issues/15068 raise unittest.SkipTest("test is slow, but it passes!") x = jnp.ones((16, 32, 1280, 4096), dtype='int8') self.assertEqual(0.0, jnp.std(x)) @@ -846,6 +861,10 @@ def testCumulativeSumErrors(self, shape, dtype, include_initial): with self.assertRaisesRegex(ValueError, msg): jnp.cumulative_sum(x, include_initial=include_initial) + def testCumulativeSumBool(self): + out = jnp.cumulative_sum(jnp.array([[0.1], [0.1], [0.0]]), axis=-1, + dtype=jnp.bool_) + np.testing.assert_array_equal(np.array([[True], [True], [False]]), out) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index a4a7fa896ae4..6f8167df9c29 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -18,7 +18,7 @@ import collections from collections.abc import Iterator import copy -from functools import partial +from functools import partial, wraps import inspect import io import itertools @@ -161,6 +161,38 @@ def _shapes_are_equal_length(shapes): return all(len(shape) == len(shapes[0]) for shape in shapes[1:]) +def arrays_with_overlapping_values(rng, shapes, dtypes, unique=False, overlap=0.5) -> list[jax.Array]: + """Generate multiple arrays with some overlapping values. + + This is useful for tests of set-like operations. + """ + assert 0 <= overlap <= 1 + sizes = [math.prod(jtu._dims_of_shape(shape)) for shape in shapes] + total_size = int(sum(sizes) * (1 - overlap)) + max(sizes) # non-strict upper-bound. 
+ if unique: + vals = jtu.rand_unique_int(rng)((total_size,), 'int32') + else: + vals = jtu.rand_default(rng)((total_size,), 'int32') + offsets = [int(sum(sizes[:i]) * (1 - overlap)) for i in range(len(sizes))] + return [rng.permutation(vals[offset: offset + size]).reshape(shape).astype(dtype) + for (offset, size, shape, dtype) in zip(offsets, sizes, shapes, dtypes)] + + +def with_size_argument(fun): + @wraps(fun) + def wrapped(*args, size=None, fill_value=None, **kwargs): + result = fun(*args, **kwargs) + if size is None or size == len(result): + return result + elif size < len(result): + return result[:size] + else: + if fill_value is None: + fill_value = result.min() if result.size else 0 + return np.pad(result, (0, size - len(result)), constant_values=fill_value) + return wrapped + + class LaxBackedNumpyTests(jtu.JaxTestCase): """Tests for LAX-backed Numpy implementation.""" @@ -671,11 +703,12 @@ def testTensordotErrors(self): test_shape=all_shapes, dtype=default_dtypes, invert=[False, True], + method=['auto', 'compare_all', 'binary_search', 'sort'] ) - def testIsin(self, element_shape, test_shape, dtype, invert): + def testIsin(self, element_shape, test_shape, dtype, invert, method): rng = jtu.rand_default(self.rng()) args_maker = lambda: [rng(element_shape, dtype), rng(test_shape, dtype)] - jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert) + jnp_fun = lambda e, t: jnp.isin(e, t, invert=invert, method=method) np_fun = lambda e, t: np.isin(e, t, invert=invert) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) @@ -685,11 +718,12 @@ def testIsin(self, element_shape, test_shape, dtype, invert): dtype2=[s for s in default_dtypes if s != jnp.bfloat16], shape1=all_shapes, shape2=all_shapes, + overlap=[0.1, 0.5, 0.9], ) - def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - + def testSetdiff1d(self, shape1, shape2, dtype1, dtype2, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np.setdiff1d, jnp.setdiff1d, args_maker) @@ -700,10 +734,12 @@ def testSetdiff1d(self, shape1, shape2, dtype1, dtype2): shape2=all_shapes, size=[1, 5, 10], fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], ) - def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testSetdiff1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) def np_fun(arg1, arg2): result = np.setdiff1d(arg1, arg2) if size <= len(result): @@ -719,12 +755,14 @@ def jnp_fun(arg1, arg2): @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, + shape1=all_shapes, + shape2=all_shapes, + overlap=[0.1, 0.5, 0.9], ) - def testUnion1d(self, shape1, shape2, dtype1, dtype2): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testUnion1d(self, shape1, shape2, dtype1, dtype2, overlap): + args_maker = 
partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) def np_fun(arg1, arg2): dtype = jnp.promote_types(arg1.dtype, arg2.dtype) return np.union1d(arg1, arg2).astype(dtype) @@ -734,14 +772,16 @@ def np_fun(arg1, arg2): @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=nonempty_nonscalar_array_shapes, - shape2=nonempty_nonscalar_array_shapes, + shape1=nonempty_shapes, + shape2=nonempty_shapes, size=[1, 5, 10], fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], ) - def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] + def testUnion1dSize(self, shape1, shape2, dtype1, dtype2, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) def np_fun(arg1, arg2): dtype = jnp.promote_types(arg1.dtype, arg2.dtype) result = np.union1d(arg1, arg2).astype(dtype) @@ -762,34 +802,62 @@ def jnp_fun(arg1, arg2): shape1=all_shapes, shape2=all_shapes, assume_unique=[False, True], - ) - def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique) + size=[None, 2, 5], + fill_value=[None, 99], + overlap=[0.1, 0.5, 0.9], + ) + def testSetxor1d(self, shape1, dtype1, shape2, dtype2, assume_unique, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + overlap=overlap) + jnp_fun = lambda ar1, ar2: jnp.setxor1d(ar1, ar2, assume_unique=assume_unique, + size=size, fill_value=fill_value) def np_fun(ar1, ar2): if assume_unique: - # pre-flatten the arrays to match with jax implementation + # numpy requires 1D inputs when assume_unique is True. 
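A concrete example of the padded reference that with_size_argument (defined earlier in this file) produces for the np.setxor1d call just below, with small arbitrary inputs:

import numpy as np

ar1 = np.array([1, 2, 3, 4])
ar2 = np.array([3, 4, 5])
result = np.setxor1d(ar1, ar2)                # [1 2 5]
# for size=5, fill_value=99 the helper pads the tail: [ 1  2  5 99 99]
padded = np.pad(result, (0, 5 - len(result)), constant_values=99)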
ar1 = np.ravel(ar1) ar2 = np.ravel(ar2) - return np.setxor1d(ar1, ar2, assume_unique) + return with_size_argument(np.setxor1d)(ar1, ar2, assume_unique, size=size, fill_value=fill_value) with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) @jtu.sample_product( dtype1=[s for s in default_dtypes if s != jnp.bfloat16], dtype2=[s for s in default_dtypes if s != jnp.bfloat16], - shape1=all_shapes, - shape2=all_shapes, + shape1=nonempty_shapes, + shape2=nonempty_shapes, assume_unique=[False, True], return_indices=[False, True], + size=[None, 3, 5], + fill_value=[None, -1], + overlap=[0.1, 0.5, 0.9], ) def testIntersect1d(self, shape1, dtype1, shape2, dtype2, assume_unique, - return_indices): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape1, dtype1), rng(shape2, dtype2)] - jnp_fun = lambda ar1, ar2: jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - np_fun = lambda ar1, ar2: np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + return_indices, size, fill_value, overlap): + args_maker = partial(arrays_with_overlapping_values, self.rng(), + shapes=[shape1, shape2], dtypes=[dtype1, dtype2], + unique=assume_unique, overlap=overlap) + + def jnp_fun(ar1, ar2): + return jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices, + size=size, fill_value=fill_value) + + def np_fun(ar1, ar2): + result = np.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + def correct_size(x, fill_value): + if size is None or size == len(x): + return x + elif size < len(x): + return x[:size] + else: + if fill_value is None: + fill_value = x.min() + return np.pad(x, (0, size - len(x)), constant_values=fill_value) + if return_indices: + return tuple(correct_size(r, f) for r, f in zip(result, [fill_value, ar1.size, ar2.size])) + else: + return correct_size(result, fill_value) + with jtu.strict_promotion_if_dtypes_match([dtype1, dtype2]): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) @@ -915,6 +983,16 @@ def testClipComplexInputError(self): with self.assertRaisesRegex(ValueError, msg): jnp.clip(x, max=jnp.array([-1+5j])) + def testClipDeprecatedArgs(self): + msg = "Passing arguments 'a', 'a_min' or 'a_max' to jax.numpy.clip is deprecated" + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-clip-args"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(msg): + jnp.clip(jnp.arange(4), a_min=2, a_max=3) + def testHypotComplexInputError(self): rng = jtu.rand_default(self.rng()) x = rng((5,), dtype=jnp.complex64) @@ -965,7 +1043,7 @@ def testOperatorRound(self, jit): check_dtypes=False) def testRoundMethod(self): - # https://github.com/google/jax/issues/15190 + # https://github.com/jax-ml/jax/issues/15190 (jnp.arange(3.) / 5.).round() # doesn't crash @jtu.sample_product(shape=[(5,), (5, 2)]) @@ -1400,6 +1478,12 @@ def testTrimZeros(self, a_shape, dtype, trim): jnp_fun = lambda arg1: jnp.trim_zeros(arg1, trim) self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=True) + def testTrimZerosNotOneDArray(self): + # TODO: make this an error after the deprecation period. 
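For reference, the 1-D usage of trim_zeros that remains supported after this deprecation (input values are arbitrary):

import jax.numpy as jnp

print(jnp.trim_zeros(jnp.array([0, 0, 1, 2, 0])))  # [1 2]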
+ with self.assertWarnsRegex(DeprecationWarning, + r"Passing arrays with ndim != 1 to jnp.trim_zeros\(\)"): + jnp.trim_zeros(jnp.array([[0.0, 1.0, 0.0],[2.0, 4.5, 0.0]])) + @jtu.sample_product( rank=(1, 2), dtype=default_dtypes, @@ -1487,7 +1571,7 @@ def testIntegerPower(self, ptype): y=[0, 32, 64, 128], ) def testIntegerPowerOverflow(self, x, y): - # Regression test for https://github.com/google/jax/issues/5987 + # Regression test for https://github.com/jax-ml/jax/issues/5987 args_maker = lambda: [x, y] self._CheckAgainstNumpy(np.power, jnp.power, args_maker) self._CompileAndCheck(jnp.power, args_maker) @@ -1629,7 +1713,7 @@ def testConcatenateArray(self, shape, dtype, axis): self._CompileAndCheck(jnp_fun, args_maker) def testConcatenateAxisNone(self): - # https://github.com/google/jax/issues/3419 + # https://github.com/jax-ml/jax/issues/3419 a = jnp.array([[1, 2], [3, 4]]) b = jnp.array([[5]]) jnp.concatenate((a, b), axis=None) @@ -2050,6 +2134,9 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) @jtu.sample_product(dtype=inexact_dtypes, equal_nan=[True, False]) + @jtu.ignore_warning( + category=RuntimeWarning, message='invalid value encountered in cast' + ) def testUniqueEqualNan(self, dtype, equal_nan): shape = (20,) rng = jtu.rand_some_nan(self.rng()) @@ -2594,6 +2681,7 @@ def np_fun(x1, x2): shape=all_shapes, dtype=default_dtypes, ) + @jtu.ignore_warning(category=RuntimeWarning, message="overflow") def testFrexp(self, shape, dtype, rng_factory): # integer types are converted to float64 in numpy's implementation if (dtype not in [jnp.bfloat16, np.float16, np.float32] @@ -2637,6 +2725,11 @@ def np_fun(arg): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def testTraceSameAxesError(self): + a = jnp.arange(1, 13).reshape(2, 3, 2) + with self.assertRaisesRegex(ValueError, r"axis1 and axis2 can not be same"): + jnp.trace(a, axis1=1, axis2=-2) + @jtu.sample_product( ashape=[(15,), (16,), (17,)], vshape=[(), (5,), (5, 5)], @@ -2722,6 +2815,23 @@ def testDigitize(self, xshape, binshape, right, reverse, dtype): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + xshape=[(20,), (5, 4)], + binshape=[(0,), (1,), (5,)], + right=[True, False], + method=['scan', 'scan_unrolled', 'sort', 'compare_all'], + reverse=[True, False], + dtype=default_dtypes, + ) + def testDigitizeMethod(self, xshape, binshape, right, method, reverse, dtype): + order = jnp.index_exp[::-1] if reverse else jnp.index_exp[:] + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(xshape, dtype), jnp.sort(rng(binshape, dtype))[order]] + np_fun = lambda x, bins: np.digitize(x, bins, right=right).astype('int32') + jnp_fun = lambda x, bins: jnp.digitize(x, bins, right=right, method=method) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( dtypes=[ [np.float32], @@ -2873,7 +2983,7 @@ def np_fun(x, n=n, axis=axis, prepend=prepend, append=append): self._CompileAndCheck(jnp_fun, args_maker) def testDiffPrepoendScalar(self): - # Regression test for https://github.com/google/jax/issues/19362 + # Regression test for https://github.com/jax-ml/jax/issues/19362 x = jnp.arange(10) result_jax = jnp.diff(x, prepend=x[0], append=x[-1]) @@ -3298,6 +3408,16 @@ def testReshape(self, arg_shape, out_shape, dtype, order): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + def 
testReshapeDeprecatedArgs(self): + msg = "The newshape argument of jax.numpy.reshape is deprecated." + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-reshape-newshape"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + with assert_warns_or_errors(msg): + jnp.reshape(jnp.arange(4), newshape=(2, 2)) + @jtu.sample_product( [dict(arg_shape=arg_shape, out_shape=out_shape) for arg_shape, out_shape in [ @@ -3497,7 +3617,7 @@ def _check(obj, out_dtype, weak_type): _check([jnp.complex128(1)], np.complex128, False) # Mixed inputs use JAX-style promotion. - # (regression test for https://github.com/google/jax/issues/8945) + # (regression test for https://github.com/jax-ml/jax/issues/8945) _check([0, np.int16(1)], np.int16, False) _check([0.0, np.float16(1)], np.float16, False) @@ -4115,17 +4235,17 @@ def testPathologicalFloats(self): # TODO(mattjj): test other ndarray-like method overrides def testNpMean(self): - # from https://github.com/google/jax/issues/125 + # from https://github.com/jax-ml/jax/issues/125 x = jnp.eye(3, dtype=float) + 0. ans = np.mean(x) self.assertAllClose(ans, np.array(1./3), check_dtypes=False) def testArangeOnFloats(self): np_arange = jtu.with_jax_dtype_defaults(np.arange) - # from https://github.com/google/jax/issues/145 + # from https://github.com/jax-ml/jax/issues/145 self.assertAllClose(np_arange(0.0, 1.0, 0.1), jnp.arange(0.0, 1.0, 0.1)) - # from https://github.com/google/jax/issues/3450 + # from https://github.com/jax-ml/jax/issues/3450 self.assertAllClose(np_arange(2.5), jnp.arange(2.5)) self.assertAllClose(np_arange(0., 2.5), @@ -4184,14 +4304,8 @@ def testSortStableDescending(self): self.assertArraysEqual(jnp.argsort(x), argsorted_stable) self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable) - @jtu.sample_product( - [dict(shape=shape, axis=axis) - for shape in one_dim_array_shapes - for axis in [None] - ], - dtype=all_dtypes, - ) - def testSortComplex(self, dtype, shape, axis): + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) + def testSortComplex(self, shape, dtype): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np.sort_complex, jnp.sort_complex, args_maker, @@ -4292,7 +4406,7 @@ def testPartition(self, shape, dtype, axis, kth): dtype=unsigned_dtypes, ) def testPartitionUnsignedWithZeros(self, kth, dtype): - # https://github.com/google/jax/issues/22137 + # https://github.com/jax-ml/jax/issues/22137 max_val = np.iinfo(dtype).max arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2]], dtype=dtype) axis = -1 @@ -4333,7 +4447,7 @@ def testArgpartition(self, shape, dtype, axis, kth): dtype=unsigned_dtypes, ) def testArgpartitionUnsignedWithZeros(self, kth, dtype): - # https://github.com/google/jax/issues/22137 + # https://github.com/jax-ml/jax/issues/22137 max_val = np.iinfo(dtype).max arg = jnp.array([[6, max_val, 0, 4, 3, 1, 0, 7, 5, 2, 3]], dtype=dtype) axis = -1 @@ -4508,7 +4622,7 @@ def args_maker(): self._CompileAndCheck(jnp_op, args_maker) def testTakeAlongAxisWithUint8IndicesDoesNotOverflow(self): - # https://github.com/google/jax/issues/5088 + # https://github.com/jax-ml/jax/issues/5088 h = jtu.rand_default(self.rng())((256, 256, 100), np.float32) g = jtu.rand_int(self.rng(), 0, 100)((256, 256, 1), np.uint8) q0 = jnp.take_along_axis(h, g, axis=-1) @@ -4701,7 +4815,7 @@ def np_fun(condlist, choicelist, default): else x.astype(np.float32) for x in choicelist] dtype = 
jnp.result_type(default, *choicelist) return np.select(condlist, - [np.asarray(x, dtype=dtype) for x in choicelist], + [np.asarray(x).astype(dtype) for x in choicelist], np.asarray(default, dtype=dtype)) with jtu.strict_promotion_if_dtypes_match(dtypes): self._CheckAgainstNumpy(np_fun, jnp.select, args_maker, @@ -4729,9 +4843,9 @@ def testSymmetrizeDtypePromotion(self): # NOTE(mattjj): I disabled this test when removing lax._safe_mul because # introducing the convention 0 * inf = 0 leads to silently wrong results in # some cases. See this comment for details: - # https://github.com/google/jax/issues/1052#issuecomment-514083352 + # https://github.com/jax-ml/jax/issues/1052#issuecomment-514083352 # def testIssue347(self): - # # https://github.com/google/jax/issues/347 + # # https://github.com/jax-ml/jax/issues/347 # def test_fail(x): # x = jnp.sqrt(jnp.sum(x ** 2, axis=1)) # ones = jnp.ones_like(x) @@ -4742,7 +4856,7 @@ def testSymmetrizeDtypePromotion(self): # assert not np.any(np.isnan(result)) def testIssue453(self): - # https://github.com/google/jax/issues/453 + # https://github.com/jax-ml/jax/issues/453 a = np.arange(6) + 1 ans = jnp.reshape(a, (3, 2), order='F') expected = np.reshape(a, (3, 2), order='F') @@ -4753,7 +4867,7 @@ def testIssue453(self): op=["atleast_1d", "atleast_2d", "atleast_3d"], ) def testAtLeastNdLiterals(self, dtype, op): - # Fixes: https://github.com/google/jax/issues/634 + # Fixes: https://github.com/jax-ml/jax/issues/634 np_fun = lambda arg: getattr(np, op)(arg).astype(dtypes.python_scalar_dtypes[dtype]) jnp_fun = lambda arg: getattr(jnp, op)(arg) args_maker = lambda: [dtype(2)] @@ -5381,7 +5495,7 @@ def testDisableNumpyRankPromotionBroadcastingDecorator(self): jnp.ones(2) + 3 # don't want to warn for scalars def testStackArrayArgument(self): - # tests https://github.com/google/jax/issues/1271 + # tests https://github.com/jax-ml/jax/issues/1271 @jax.jit def foo(x): return jnp.stack(x) @@ -5428,7 +5542,7 @@ def testBroadcastTo(self, from_shape, to_shape): self._CompileAndCheck(jnp_op, args_maker) def testBroadcastToInvalidShape(self): - # Regression test for https://github.com/google/jax/issues/20533 + # Regression test for https://github.com/jax-ml/jax/issues/20533 x = jnp.zeros((3, 4, 5)) with self.assertRaisesRegex( ValueError, "Cannot broadcast to shape with fewer dimensions"): @@ -5580,7 +5694,7 @@ def testGradientNonConstant(self, shape, dtype): self._CompileAndCheck(jnp.gradient, args_maker) def testZerosShapeErrors(self): - # see https://github.com/google/jax/issues/1822 + # see https://github.com/jax-ml/jax/issues/1822 self.assertRaisesRegex( TypeError, "Shapes must be 1D sequences of concrete values of integer type.*", @@ -5597,8 +5711,9 @@ def testTraceMethod(self): self.assertAllClose(x.trace(), jnp.array(x).trace()) self.assertAllClose(x.trace(), jax.jit(lambda y: y.trace())(x)) + @jtu.ignore_warning(category=RuntimeWarning, message="divide by zero") def testIntegerPowersArePrecise(self): - # See https://github.com/google/jax/pull/3036 + # See https://github.com/jax-ml/jax/pull/3036 # Checks if the squares of float32 integers have no numerical errors. # It should be satisfied with all integers less than sqrt(2**24). 
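The reason this holds: float32 carries a 24-bit significand, so every integer up to 2**24 is exactly representable, and the square of any |n| <= 2**12 stays within that range. A quick NumPy check of the same claim:

import numpy as np

n = np.arange(-2**12, 2**12, dtype=np.int64)
squares_f32 = (n.astype(np.float32) * n.astype(np.float32)).astype(np.int64)
assert np.array_equal(squares_f32, n * n)  # exact, no rounding error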
x = jnp.arange(-2**12, 2**12, dtype=jnp.int32) @@ -5669,7 +5784,7 @@ def testArange64Bit(self, dtype): self._CompileAndCheck(jnp_fun, args_maker) def testIssue2347(self): - # https://github.com/google/jax/issues/2347 + # https://github.com/jax-ml/jax/issues/2347 object_list = list[tuple[jnp.array, float, float, jnp.array, bool]] self.assertRaises(TypeError, jnp.array, object_list) @@ -5988,7 +6103,7 @@ def testSincGradArrayInput(self): jax.grad(lambda x: jnp.sinc(x).sum())(jnp.arange(10.)) # doesn't crash def testTakeAlongAxisIssue1521(self): - # https://github.com/google/jax/issues/1521 + # https://github.com/jax-ml/jax/issues/1521 idx = jnp.repeat(jnp.arange(3), 10).reshape((30, 1)) def f(x): @@ -6099,7 +6214,7 @@ def testWrappedSignaturesMatch(self): if name == "clip": # JAX's support of the Array API spec for clip, and the way it handles # backwards compatibility was introduced in - # https://github.com/google/jax/pull/20550 with a different signature + # https://github.com/jax-ml/jax/pull/20550 with a different signature # from the one in numpy, introduced in # https://github.com/numpy/numpy/pull/26724 # TODO(dfm): After our deprecation period for the clip arguments ends @@ -6162,7 +6277,8 @@ def _dtypes_for_ufunc(name: str) -> Iterator[tuple[str, ...]]: for arg_dtypes in itertools.product(_available_numpy_dtypes, repeat=func.nin): args = (np.ones(1, dtype=dtype) for dtype in arg_dtypes) try: - with jtu.ignore_warning(category=RuntimeWarning, message="divide by zero"): + with jtu.ignore_warning( + category=RuntimeWarning, message="(divide by zero|invalid value)"): _ = func(*args) except TypeError: pass @@ -6184,7 +6300,7 @@ def testUfuncInputTypes(self, name, arg_dtypes): jnp_op = getattr(jnp, name) np_op = getattr(np, name) np_op = jtu.ignore_warning(category=RuntimeWarning, - message="divide by zero.*")(np_op) + message="(divide by zero|invalid value)")(np_op) args_maker = lambda: tuple(np.ones(1, dtype=dtype) for dtype in arg_dtypes) with jtu.strict_promotion_if_dtypes_match(arg_dtypes): @@ -6200,7 +6316,9 @@ def test_lax_numpy_docstrings(self): unimplemented = ['fromfile', 'fromiter'] aliases = ['abs', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atanh', 'atan2', - 'amax', 'amin'] + 'amax', 'amin', 'around', 'bitwise_right_shift', 'conj', 'degrees', + 'divide', 'mod', 'pow', 'radians', 'round_'] + skip_args_check = ['vsplit', 'hsplit', 'dsplit', 'array_split'] for name in dir(jnp): if name.startswith('_') or name in unimplemented: @@ -6225,12 +6343,14 @@ def test_lax_numpy_docstrings(self): raise Exception(f"jnp.{name} does not have a wrapped docstring.") elif name in aliases: assert "Alias of" in obj.__doc__ - else: + elif name not in skip_args_check: # Other functions should have nontrivial docs including "Args" and "Returns". 
doc = obj.__doc__ self.assertNotEmpty(doc) self.assertIn("Args:", doc, msg=f"'Args:' not found in docstring of jnp.{name}") self.assertIn("Returns:", doc, msg=f"'Returns:' not found in docstring of jnp.{name}") + if name not in ["frompyfunc", "isdtype", "promote_types"]: + self.assertIn("Examples:", doc, msg=f"'Examples:' not found in docstring of jnp.{name}") @parameterized.named_parameters( {"testcase_name": "_jit" if jit else "", "jit": jit} for jit in [True, False]) diff --git a/tests/lax_numpy_ufuncs_test.py b/tests/lax_numpy_ufuncs_test.py index 16eb9321c822..537146a215c7 100644 --- a/tests/lax_numpy_ufuncs_test.py +++ b/tests/lax_numpy_ufuncs_test.py @@ -14,6 +14,7 @@ """Tests for jax.numpy.ufunc and its methods.""" +import itertools from functools import partial from absl.testing import absltest @@ -22,7 +23,6 @@ import jax import jax.numpy as jnp from jax._src import test_util as jtu -from jax._src.numpy.ufunc_api import get_if_single_primitive jax.config.parse_flags_with_absl() @@ -54,19 +54,39 @@ def scalar_sub(x, y): {'func': scalar_sub, 'nin': 2, 'nout': 1, 'identity': None}, ] -FASTPATH_FUNCS = [ - {'func': jnp.add, 'nin': 2, 'nout': 1, 'identity': 0, - 'reducer': jax.lax.reduce_sum_p, 'accumulator': jax.lax.cumsum_p}, - {'func': jnp.multiply, 'nin': 2, 'nout': 1, 'identity': 1, - 'reducer': jax.lax.reduce_prod_p, 'accumulator': jax.lax.cumprod_p}, +def _jnp_ufunc_props(name): + jnp_func = getattr(jnp, name) + assert isinstance(jnp_func, jnp.ufunc) + np_func = getattr(np, name) + dtypes = [np.dtype(c) for c in "Ffi?" if f"{c}{c}->{c}" in np_func.types or f"{c}->{c}" in np_func.types] + return [dict(name=name, dtype=dtype) for dtype in dtypes] + + +JAX_NUMPY_UFUNCS = [ + name for name in dir(jnp) if isinstance(getattr(jnp, name), jnp.ufunc) +] + +BINARY_UFUNCS = [ + name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 2 ] -NON_FASTPATH_FUNCS = [ - {'func': lambda a, b: jnp.add(a, a), 'nin': 2, 'nout': 1, 'identity': 0}, - {'func': lambda a, b: jnp.multiply(b, a), 'nin': 2, 'nout': 1, 'identity': 1}, - {'func': jax.jit(lambda a, b: jax.jit(jnp.multiply)(b, a)), 'nin': 2, 'nout': 1, 'identity': 1}, +UNARY_UFUNCS = [ + name for name in JAX_NUMPY_UFUNCS if getattr(jnp, name).nin == 1 ] +JAX_NUMPY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in JAX_NUMPY_UFUNCS +)) + +BINARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in BINARY_UFUNCS +)) + +UNARY_UFUNCS_WITH_DTYPES = list(itertools.chain.from_iterable( + _jnp_ufunc_props(name) for name in UNARY_UFUNCS +)) + + broadcast_compatible_shapes = [(), (1,), (3,), (1, 3), (4, 1), (4, 3)] nonscalar_shapes = [(3,), (4,), (4, 3)] @@ -80,23 +100,40 @@ def wrapped(*args, **kwargs): class LaxNumpyUfuncTests(jtu.JaxTestCase): @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties(self, func, nin, nout, identity): + def test_frompyfunc_properties(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun.identity, identity) self.assertEqual(jnp_fun.nin, nin) self.assertEqual(jnp_fun.nout, nout) self.assertEqual(jnp_fun.nargs, nin) + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties(self, name): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + self.assertEqual(jnp_fun.identity, np_fun.identity) + self.assertEqual(jnp_fun.nin, np_fun.nin) + self.assertEqual(jnp_fun.nout, np_fun.nout) + self.assertEqual(jnp_fun.nargs, np_fun.nargs - 1) # -1 
because NumPy accepts `out` + @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_properties_readonly(self, func, nin, nout, identity): + def test_frompyfunc_properties_readonly(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) - for attr in ['nargs', 'nin', 'nout', 'identity', '_func', '_call']: + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: + getattr(jnp_fun, attr) # no error on attribute access. + with self.assertRaises(AttributeError): + setattr(jnp_fun, attr, None) # error when trying to mutate. + + @jtu.sample_product(name=JAX_NUMPY_UFUNCS) + def test_ufunc_properties_readonly(self, name): + jnp_fun = getattr(jnp, name) + for attr in ['nargs', 'nin', 'nout', 'identity', '_func']: getattr(jnp_fun, attr) # no error on attribute access. with self.assertRaises(AttributeError): setattr(jnp_fun, attr, None) # error when trying to mutate. @jtu.sample_product(SCALAR_FUNCS) - def test_ufunc_hash(self, func, nin, nout, identity): + def test_frompyfunc_hash(self, func, nin, nout, identity): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) jnp_fun_2 = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) self.assertEqual(jnp_fun, jnp_fun_2) @@ -113,7 +150,7 @@ def test_ufunc_hash(self, func, nin, nout, identity): dtype=jtu.dtypes.floating, ) @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. - def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity) np_fun = cast_outputs(np.frompyfunc(func, nin=nin, nout=nout, identity=identity)) @@ -123,13 +160,41 @@ def test_call(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + UNARY_UFUNCS_WITH_DTYPES, + shape=broadcast_compatible_shapes, + ) + def test_unary_ufunc_call(self, name, dtype, shape): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + @jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion. 
+ def test_binary_ufunc_call(self, name, dtype, lhs_shape, rhs_shape): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( SCALAR_FUNCS, lhs_shape=broadcast_compatible_shapes, rhs_shape=broadcast_compatible_shapes, dtype=jtu.dtypes.floating, ) - def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): + def test_frompyfunc_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"outer requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).outer @@ -141,6 +206,21 @@ def test_outer(self, func, nin, nout, identity, lhs_shape, rhs_shape, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + lhs_shape=broadcast_compatible_shapes, + rhs_shape=broadcast_compatible_shapes, + ) + def test_binary_ufunc_outer(self, name, lhs_shape, rhs_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun.outer, np_fun.outer, args_maker) + self._CompileAndCheck(jnp_fun.outer, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) @@ -160,6 +240,25 @@ def test_reduce(self, func, nin, nout, identity, shape, axis, dtype): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [None, *range(-len(shape), len(shape))]], + ) + def test_binary_ufunc_reduce(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + jnp_fun_reduce = partial(jnp_fun.reduce, axis=axis) + np_fun_reduce = partial(np_fun.reduce, axis=axis) + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) + @jtu.sample_product( SCALAR_FUNCS, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes for axis in [None, *range(-len(shape), len(shape))]], dtype=jtu.dtypes.floating, ) - def test_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_reduce_where(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") @@ -194,42 +293,27 @@ def np_fun(arr, where): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( -
FASTPATH_FUNCS, - [{'shape': shape, 'axis': axis} - for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, - ) - def test_reduce_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del accumulator # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), reducer) - - @jtu.sample_product( - NON_FASTPATH_FUNCS, + BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in [None, *range(-len(shape), len(shape))]], ) - def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") - rng = jtu.rand_default(self.rng()) - args = (rng(shape, dtype),) + def test_binary_ufunc_reduce_where(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) - _ = func(0, 0) # function should not error. + if jnp_fun.identity is None: + self.skipTest("reduce with where requires identity") - reduce_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduce, axis=axis) - self.assertIsNone(get_if_single_primitive(reduce_fun, *args)) + jnp_fun_reduce = lambda a, where: jnp_fun.reduce(a, axis=axis, where=where) + np_fun_reduce = lambda a, where: np_fun.reduce(a, axis=axis, where=where) - accum_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertIsNone(get_if_single_primitive(accum_fun, *args)) + rng = jtu.rand_default(self.rng()) + rng_where = jtu.rand_bool(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng_where(shape, bool)] + self._CheckAgainstNumpy(jnp_fun_reduce, np_fun_reduce, args_maker) + self._CompileAndCheck(jnp_fun_reduce, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -238,7 +322,7 @@ def test_non_fastpath(self, func, nin, nout, identity, shape, axis, dtype): for axis in range(-len(shape), len(shape))], dtype=jtu.dtypes.floating, ) - def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): + def test_frompyfunc_accumulate(self, func, nin, nout, identity, shape, axis, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) @@ -251,20 +335,26 @@ def test_accumulate(self, func, nin, nout, identity, shape, axis, dtype): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - FASTPATH_FUNCS, + BINARY_UFUNCS_WITH_DTYPES, [{'shape': shape, 'axis': axis} for shape in nonscalar_shapes - for axis in range(-len(shape), len(shape))], - dtype=jtu.dtypes.floating, + for axis in range(-len(shape), len(shape))] ) - def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype, reducer, accumulator): - del reducer # unused - if (nin, nout) != (2, 1): - self.skipTest(f"reduce requires (nin, nout)=(2, 1); got {(nin, nout)=}") + def test_binary_ufunc_accumulate(self, name, shape, axis, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + rng = jtu.rand_default(self.rng()) - args = (rng(shape, 
dtype),) - jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).accumulate, axis=axis) - self.assertEqual(get_if_single_primitive(jnp_fun, *args), accumulator) + args_maker = lambda: [rng(shape, dtype)] + + jnp_fun_accumulate = partial(jnp_fun.accumulate, axis=axis) + def np_fun_accumulate(x): + # numpy accumulate has different dtype casting behavior. + result = np_fun.accumulate(x, axis=axis) + return result if x.dtype == bool else result.astype(x.dtype) + + self._CheckAgainstNumpy(jnp_fun_accumulate, np_fun_accumulate, args_maker) + self._CompileAndCheck(jnp_fun_accumulate, args_maker) @jtu.sample_product( SCALAR_FUNCS, @@ -272,7 +362,7 @@ def test_accumulate_fastpath(self, func, nin, nout, identity, shape, axis, dtype idx_shape=[(), (2,)], dtype=jtu.dtypes.floating, ) - def test_at(self, func, nin, nout, identity, shape, idx_shape, dtype): + def test_frompyfunc_at(self, func, nin, nout, identity, shape, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).at, inplace=False) @@ -288,8 +378,52 @@ def np_fun(x, idx, y): self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) - def test_at_broadcasting(self): - # Regression test for https://github.com/google/jax/issues/18004 + @jtu.sample_product( + UNARY_UFUNCS_WITH_DTYPES, + shape=nonscalar_shapes, + idx_shape=[(), (2,)], + ) + def test_unary_ufunc_at(self, name, shape, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + + jnp_fun_at = partial(jnp_fun.at, inplace=False) + def np_fun_at(x, idx): + x_copy = x.copy() + np_fun.at(x_copy, idx) + return x_copy + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + self._CompileAndCheck(jnp_fun_at, args_maker) + + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + shape=nonscalar_shapes, + idx_shape=[(), (2,)], + ) + def test_binary_ufunc_at(self, name, shape, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=-shape[0], high=shape[0]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32'), rng(idx_shape[1:], dtype)] + + jnp_fun_at = partial(jnp_fun.at, inplace=False) + def np_fun_at(x, idx, y): + x_copy = x.copy() + np_fun.at(x_copy, idx, y) + return x_copy + + self._CheckAgainstNumpy(jnp_fun_at, np_fun_at, args_maker) + self._CompileAndCheck(jnp_fun_at, args_maker) + + def test_frompyfunc_at_broadcasting(self): + # Regression test for https://github.com/jax-ml/jax/issues/18004 args_maker = lambda: [np.ones((5, 3)), np.array([0, 4, 2]), np.arange(9.0).reshape(3, 3)] def np_fun(x, idx, y): @@ -309,7 +443,7 @@ def np_fun(x, idx, y): idx_shape=[(0,), (3,), (5,)], dtype=jtu.dtypes.floating, ) - def test_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): + def test_frompyfunc_reduceat(self, func, nin, nout, identity, shape, axis, idx_shape, dtype): if (nin, nout) != (2, 1): self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(nin, nout)=}") jnp_fun = partial(jnp.frompyfunc(func, nin=nin, nout=nout, identity=identity).reduceat, axis=axis) @@ -322,6 +456,33 @@ def test_reduceat(self, func, nin, nout, identity, shape, axis, 
idx_shape, dtype self._CheckAgainstNumpy(jnp_fun, np_fun, args_maker) self._CompileAndCheck(jnp_fun, args_maker) + @jtu.sample_product( + BINARY_UFUNCS_WITH_DTYPES, + [{'shape': shape, 'axis': axis} + for shape in nonscalar_shapes + for axis in [*range(-len(shape), len(shape))]], + idx_shape=[(0,), (3,), (5,)], + ) + def test_binary_ufunc_reduceat(self, name, shape, axis, idx_shape, dtype): + jnp_fun = getattr(jnp, name) + np_fun = getattr(np, name) + if (jnp_fun.nin, jnp_fun.nout) != (2, 1): + self.skipTest(f"accumulate requires (nin, nout)=(2, 1); got {(jnp_fun.nin, jnp_fun.nout)=}") + if name in ['add', 'multiply'] and dtype == bool: + # TODO(jakevdp): figure out how to fix these cases. + self.skipTest(f"known failure for {name}.reduceat with {dtype=}") + + rng = jtu.rand_default(self.rng()) + idx_rng = jtu.rand_int(self.rng(), low=0, high=shape[axis]) + args_maker = lambda: [rng(shape, dtype), idx_rng(idx_shape, 'int32')] + + def np_fun_reduceat(x, i): + # Numpy has different casting behavior. + return np_fun.reduceat(x, i).astype(x.dtype) + + self._CheckAgainstNumpy(jnp_fun.reduceat, np_fun_reduceat, args_maker) + self._CompileAndCheck(jnp_fun.reduceat, args_maker) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/lax_numpy_vectorize_test.py b/tests/lax_numpy_vectorize_test.py index 56fd0f7817e3..985dba484845 100644 --- a/tests/lax_numpy_vectorize_test.py +++ b/tests/lax_numpy_vectorize_test.py @@ -258,7 +258,7 @@ def test_none_arg_bad_signature(self): f(*args) def test_rank_promotion_error(self): - # Regression test for https://github.com/google/jax/issues/22305 + # Regression test for https://github.com/jax-ml/jax/issues/22305 f = jnp.vectorize(jnp.add, signature="(),()->()") rank2 = jnp.zeros((10, 10)) rank1 = jnp.zeros(10) diff --git a/tests/lax_scipy_sparse_test.py b/tests/lax_scipy_sparse_test.py index d2e64833b964..303c67c5860d 100644 --- a/tests/lax_scipy_sparse_test.py +++ b/tests/lax_scipy_sparse_test.py @@ -469,7 +469,7 @@ def test_gmres_weak_types(self): self.assertTrue(dtypes.is_weakly_typed(x)) def test_linear_solve_batching_via_jacrev(self): - # See https://github.com/google/jax/issues/14249 + # See https://github.com/jax-ml/jax/issues/14249 rng = np.random.RandomState(0) M = rng.randn(5, 5) A = np.dot(M, M.T) diff --git a/tests/lax_scipy_special_functions_test.py b/tests/lax_scipy_special_functions_test.py index 38607cae883b..bd3bca5385b7 100644 --- a/tests/lax_scipy_special_functions_test.py +++ b/tests/lax_scipy_special_functions_test.py @@ -95,6 +95,10 @@ def op_record(name, nargs, dtypes, rng_factory, test_grad, nondiff_argnums=(), t op_record( "factorial", 1, float_dtypes, jtu.rand_default, True ), + op_record( + "fresnel", 1, float_dtypes, + functools.partial(jtu.rand_default, scale=30), True + ), op_record( "i0", 1, float_dtypes, jtu.rand_default, True ), diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 1ed410cbaed8..4840972e9483 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -165,7 +165,7 @@ def testLogSumExpComplexSign(self): self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol) def testLogSumExpZeros(self): - # Regression test for https://github.com/google/jax/issues/5370 + # Regression test for https://github.com/jax-ml/jax/issues/5370 scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b) lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b) args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])] @@ -173,14 +173,14 @@ def 
testLogSumExpZeros(self): self._CompileAndCheck(lax_fun, args_maker) def testLogSumExpOnes(self): - # Regression test for https://github.com/google/jax/issues/7390 + # Regression test for https://github.com/jax-ml/jax/issues/7390 args_maker = lambda: [np.ones(4, dtype='float32')] with jax.debug_infs(True): self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker) self._CompileAndCheck(lsp_special.logsumexp, args_maker) def testLogSumExpNans(self): - # Regression test for https://github.com/google/jax/issues/7634 + # Regression test for https://github.com/jax-ml/jax/issues/7634 with jax.debug_nans(True): with jax.disable_jit(): result = lsp_special.logsumexp(1.0) @@ -218,7 +218,8 @@ def lax_fun(a): rng = jtu.rand_positive(self.rng()) args_maker = lambda: [rng(shape, dtype) + (d - 1) / 2.] self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, - tol={np.float32: 1e-3, np.float64: 1e-14}) + tol={np.float32: 1e-3, np.float64: 1e-14}, + check_dtypes=False) self._CompileAndCheck( lax_fun, args_maker, rtol={ np.float32: 5e-5 if jtu.test_device_matches(["tpu"]) else 1e-05, @@ -245,7 +246,7 @@ def testXlogyShouldReturnZero(self): self.assertAllClose(lsp_special.xlogy(0., 0.), 0., check_dtypes=False) def testGradOfXlogyAtZero(self): - # https://github.com/google/jax/issues/15598 + # https://github.com/jax-ml/jax/issues/15598 x0, y0 = 0.0, 3.0 d_xlog1py_dx = jax.grad(lsp_special.xlogy, argnums=0)(x0, y0) self.assertAllClose(d_xlog1py_dx, lax.log(y0)) @@ -259,7 +260,7 @@ def testXlog1pyShouldReturnZero(self): self.assertAllClose(lsp_special.xlog1py(0., -1.), 0., check_dtypes=False) def testGradOfXlog1pyAtZero(self): - # https://github.com/google/jax/issues/15598 + # https://github.com/jax-ml/jax/issues/15598 x0, y0 = 0.0, 3.0 d_xlog1py_dx = jax.grad(lsp_special.xlog1py, argnums=0)(x0, y0) self.assertAllClose(d_xlog1py_dx, lax.log1p(y0)) @@ -283,7 +284,7 @@ def testXLogX(self): rtol=.1, eps=1e-3) def testGradOfEntrAtZero(self): - # https://github.com/google/jax/issues/15709 + # https://github.com/jax-ml/jax/issues/15709 self.assertEqual(jax.jacfwd(lsp_special.entr)(0.0), jnp.inf) self.assertEqual(jax.jacrev(lsp_special.entr)(0.0), jnp.inf) @@ -332,6 +333,8 @@ def scipy_fun(z): dtype=float_dtypes, ) def testLpmn(self, l_max, shape, dtype): + if jtu.is_device_tpu(6, "e"): + self.skipTest("TODO(b/364258243): fails on TPU v6e") rng = jtu.rand_uniform(self.rng(), low=-0.2, high=0.9) args_maker = lambda: [rng(shape, dtype)] @@ -441,6 +444,8 @@ def testSphHarmOrderOneDegreeOne(self): @jax.numpy_dtype_promotion('standard') # This test explicitly exercises dtype promotion def testSphHarmForJitAndAgainstNumpy(self, l_max, num_z, dtype): """Tests against JIT compatibility and Numpy.""" + if jtu.is_device_tpu(6, "e"): + self.skipTest("TODO(b/364258243): fails on TPU v6e") n_max = l_max shape = (num_z,) rng = jtu.rand_int(self.rng(), -l_max, l_max + 1) @@ -518,8 +523,6 @@ def testPolar( tol = 650 * float(jnp.finfo(matrix.dtype).eps) eye_mat = np.eye(should_be_eye.shape[0], dtype=should_be_eye.dtype) with self.subTest('Test unitarity.'): - if jtu.test_device_matches(["cpu"]): - tol = max(tol, 1e-8) self.assertAllClose( eye_mat, should_be_eye, atol=tol * 1000 * min(shape)) diff --git a/tests/lax_test.py b/tests/lax_test.py index 73b21d12923e..c8f3ca797903 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -41,10 +41,12 @@ from jax._src import dtypes from jax._src import lax_reference from jax._src import test_util as jtu +from jax._src import xla_bridge from jax._src.interpreters 
import mlir from jax._src.interpreters import pxla from jax._src.internal_test_util import lax_test_util from jax._src.lax import lax as lax_internal +from jax._src.lib import version as jaxlib_version from jax._src.util import NumpyComplexWarning, safe_zip from jax._src.tree_util import tree_map @@ -110,6 +112,7 @@ def testOp(self, op_name, rng_factory, shapes, dtype): for shape_group in lax_test_util.compatible_shapes), dtype=rec.dtypes) for rec in lax_test_util.lax_ops())) + @jtu.ignore_warning(message="invalid value", category=RuntimeWarning) def testOpAgainstNumpy(self, op_name, rng_factory, shapes, dtype, tol): if (not config.enable_x64.value and op_name == "nextafter" and dtype == np.float64): @@ -1002,7 +1005,7 @@ def fun_via_grad(lhs, rhs): self._CheckAgainstNumpy(fun_via_grad, fun, args_maker) def testConvTransposePaddingList(self): - # Regression test for https://github.com/google/jax/discussions/8695 + # Regression test for https://github.com/jax-ml/jax/discussions/8695 a = jnp.ones((28,28)) b = jnp.ones((3,3)) c = lax.conv_general_dilated(a[None, None], b[None, None], (1,1), [(0,0),(0,0)], (1,1)) @@ -1040,6 +1043,178 @@ def testDot(self, lhs_shape, rhs_shape, lhs_dtype, rhs_dtype, precision): args_maker = lambda: [rng(lhs_shape, lhs_dtype), rng(rhs_shape, rhs_dtype)] self._CompileAndCheck(partial(lax.dot, precision=precision), args_maker) + @parameterized.parameters([ + (algorithm, dtype) + for algorithm, test_dtypes in [ + (lax.DotAlgorithm( + lhs_precision_type=np.float32, + rhs_precision_type=np.float32, + accumulation_type=np.float32, + lhs_component_count=1, + rhs_component_count=1, + num_primitive_operations=1, + allow_imprecise_accumulation=False, + ), [np.float32]), + (lax.DotAlgorithm( + lhs_precision_type=np.float16, + rhs_precision_type=np.float16, + accumulation_type=np.float32, + ), [np.float16]), + ("F16_F16_F32", [np.float16]), + (lax.DotAlgorithm.Preset.DEFAULT, lax_test_util.float_dtypes), + (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, dtypes._float8_dtypes), + (lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, dtypes._float8_dtypes), + (lax.DotAlgorithm.Preset.F16_F16_F16, [np.float16]), + (lax.DotAlgorithm.Preset.F16_F16_F32, [np.float16]), + (lax.DotAlgorithm.Preset.BF16_BF16_BF16, [dtypes.bfloat16]), + (lax.DotAlgorithm.Preset.BF16_BF16_F32, [dtypes.bfloat16]), + (lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, [np.float32]), + (lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, [np.float32]), + (lax.DotAlgorithm.Preset.TF32_TF32_F32, [np.float32]), + (lax.DotAlgorithm.Preset.TF32_TF32_F32_X3, [np.float32]), + (lax.DotAlgorithm.Preset.F32_F32_F32, [np.float32]), + (lax.DotAlgorithm.Preset.F64_F64_F64, [np.float64]), + ] for dtype in test_dtypes + if jtu.dtypes.supported([dtype]) + ]) + def testDotAlgorithm(self, algorithm, dtype): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + if jtu.test_device_matches(["gpu"]): + # GPU algorithm support is a little spotty. It is checked in + # xla/service/algorithm_util.cc and the logic is copied here. 
+ if algorithm in { + lax.DotAlgorithm.Preset.F16_F16_F32, + lax.DotAlgorithm.Preset.TF32_TF32_F32, + lax.DotAlgorithm.Preset.BF16_BF16_F32, + lax.DotAlgorithm.Preset.BF16_BF16_F32_X3, # Must have f32 input + lax.DotAlgorithm.Preset.BF16_BF16_F32_X6, # Must have f32 input + }: + if not jtu.is_cuda_compute_capability_at_least("8.0"): + raise SkipTest( + f"The dot algorithm '{algorithm}' requires CUDA compute " + "capability >= 8.0.") + elif algorithm not in { + lax.DotAlgorithm.Preset.DEFAULT, + lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32, + lax.DotAlgorithm.Preset.ANY_F8_ANY_F8_F32_FAST_ACCUM, + lax.DotAlgorithm.Preset.F32_F32_F32, + lax.DotAlgorithm.Preset.F64_F64_F64, + }: + raise SkipTest( + f"The dot algorithm '{algorithm}' is not supported on GPU.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(lhs_shape, dtype), rng(rhs_shape, dtype)] + self._CompileAndCheck(partial(lax.dot, algorithm=algorithm), args_maker) + # Check that accumulation type sets the output type + output = lax.dot(*args_maker(), algorithm=algorithm) + algorithm = lax_internal.canonicalize_dot_algorithm(algorithm) + expected_dtype = dtype if algorithm is None else algorithm.accumulation_type + self.assertEqual(output.dtype, expected_dtype) + + def testDotAlgorithmInvalidFloat8Type(self): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, dtypes.float8_e4m3fn) + with self.assertRaisesRegex(ValueError, "The dot algorithm"): + lax.dot(lhs, rhs, algorithm="ANY_F8_ANY_F8_F32") + + @parameterized.parameters([ + ({"precision": lax.Precision.HIGHEST}, "The dot_general precision must be None or DEFAULT"), + ({"preferred_element_type": np.float32}, "The preferred_element_type and algorithm arguments"), + ]) + def testDotAlgorithmInvalidParameters(self, kwargs, pattern): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + with self.assertRaisesRegex(ValueError, pattern): + lax.dot(lhs, rhs, algorithm="F32_F32_F32", **kwargs) + + def testDotAlgorithmTransposeRequired(self): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + fun = partial(lax.dot, algorithm="F32_F32_F32") + out = fun(lhs, rhs) + _, vjp_fun = jax.vjp(fun, lhs, rhs) + with self.assertRaisesRegex( + ValueError, "When a dot_general algorithm is specified"): + vjp_fun(out) + + @parameterized.parameters([ + ("F32_F32_F32", "F16_F16_F32"), + ("F32_F32_F32", ("F16_F16_F32", "F64_F64_F64")), + ]) + def testDotAlgorithmTranspose(self, algorithm, transpose_algorithm): + if xla_bridge.using_pjrt_c_api(): + raise SkipTest( + "The 
dot algorithm attribute is not supported by PJRT C API.") + if jaxlib_version <= (0, 4, 33): + raise SkipTest( + "The dot algorithm attribute is only supported for jaxlib >0.4.33.") + def fun(x, y): + return lax.dot(x, y, algorithm=algorithm, + transpose_algorithm=transpose_algorithm) + + algorithm_ = lax_internal.canonicalize_dot_algorithm(algorithm) + lhs_alg, rhs_alg = lax_internal.canonicalize_dot_transpose_algorithm( + transpose_algorithm) + + lhs_shape = (3, 4) + rhs_shape = (4, 3) + rng = jtu.rand_default(self.rng()) + lhs, rhs = rng(lhs_shape, np.float32), rng(rhs_shape, np.float32) + out = fun(lhs, rhs) + + def check_transpose_algorithm(f, arg, alg, trans_alg, trans_trans_alg): + fun_trans = jax.linear_transpose(f, arg) + jaxpr = jax.make_jaxpr(fun_trans)(out) + eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr.eqns)) + self.assertEqual(eqn.params["algorithm"], alg) + self.assertEqual(eqn.params["transpose_algorithm"], trans_alg) + + fun_ = jax.linear_transpose(lambda x: fun_trans(x)[0], out) + jaxpr_ = jax.make_jaxpr(fun_)(arg) + eqn = next(filter(lambda eqn: eqn.primitive == lax.dot_general_p, jaxpr_.eqns)) + self.assertEqual(eqn.params["algorithm"], algorithm_) + + # Note that transposing the RHS of a dot_general introduces extra + # transposes on the input and output, so we don't actually end up with + # the same `transpose_algorithm` parameter after 2 transposes. + self.assertEqual(eqn.params["transpose_algorithm"], trans_trans_alg) + + check_transpose_algorithm(partial(fun, y=rhs), lhs, lhs_alg, + (algorithm_, rhs_alg), (lhs_alg, rhs_alg)) + check_transpose_algorithm(partial(fun, lhs), rhs, rhs_alg, + (algorithm_, lhs_alg), (rhs_alg, lhs_alg)) + @jtu.sample_product( [dict(lhs_shape=lhs_shape, rhs_shape=rhs_shape) for lhs_shape in [(3,), (4, 3)] for rhs_shape in [(3,), (3, 6)]], @@ -1280,7 +1455,7 @@ def testBroadcastInDim(self, inshape, dtype, outshape, dimensions): self._CompileAndCheck(op, args_maker) def testBroadcastInDimOperandShapeTranspose(self): - # Regression test for https://github.com/google/jax/issues/5276 + # Regression test for https://github.com/jax-ml/jax/issues/5276 def f(x): return lax.broadcast_in_dim(x, (2, 3, 4), broadcast_dimensions=(0, 1, 2)).sum() def g(x): @@ -1681,7 +1856,7 @@ def args_maker(): lax.dynamic_update_slice, args_maker) def testDynamicUpdateSliceBatched(self): - # Regression test for https://github.com/google/jax/issues/9083 + # Regression test for https://github.com/jax-ml/jax/issues/9083 x = jnp.arange(5) y = jnp.arange(6, 9) ind = jnp.arange(6) @@ -2236,7 +2411,7 @@ def testReduceWindowShapeDilation(self, shape, window_dimensions, self.assertEqual(shape, result.shape) def testReduceWindowWithEmptyOutput(self): - # https://github.com/google/jax/issues/10315 + # https://github.com/jax-ml/jax/issues/10315 shape = (5, 3, 2) operand, padding, strides = np.ones(shape), 'VALID', (1,) * len(shape) out = jax.eval_shape(lambda x: lax.reduce_window(x, 0., lax.add, padding=padding, @@ -2844,9 +3019,10 @@ def testDynamicUpdateSliceTypeErrors(self): (np.int32(1), np.int16(2)))) def test_primitive_jaxtype_error(self): + err_str = ("Error interpreting argument to .* as an abstract array. 
The problematic " + r"value is of type .* and was passed to the function at path args\[1\].") with jax.enable_checks(False): - with self.assertRaisesRegex( - TypeError, "Argument .* of type .* is not a valid JAX type"): + with self.assertRaisesRegex(TypeError, err_str): lax.add(1, 'hi') def test_reduction_with_repeated_axes_error(self): @@ -2859,13 +3035,13 @@ def test_ops_do_not_accept_complex_dtypes(self, op): op(2+3j, 4+5j) def test_population_count_booleans_not_supported(self): - # https://github.com/google/jax/issues/3886 + # https://github.com/jax-ml/jax/issues/3886 msg = "population_count does not accept dtype bool" with self.assertRaisesRegex(TypeError, msg): lax.population_count(True) def test_conv_general_dilated_different_input_ranks_error(self): - # https://github.com/google/jax/issues/4316 + # https://github.com/jax-ml/jax/issues/4316 msg = ("conv_general_dilated lhs and rhs must have the same number of " "dimensions") dimension_numbers = lax.ConvDimensionNumbers(lhs_spec=(0, 1, 2), @@ -2885,7 +3061,7 @@ def test_conv_general_dilated_different_input_ranks_error(self): lax.conv_general_dilated(lhs, rhs, **kwargs) def test_window_strides_dimension_shape_rule(self): - # https://github.com/google/jax/issues/5087 + # https://github.com/jax-ml/jax/issues/5087 msg = ("conv_general_dilated window and window_strides must have " "the same number of dimensions") lhs = jax.numpy.zeros((1, 1, 3, 3)) @@ -2894,7 +3070,7 @@ def test_window_strides_dimension_shape_rule(self): jax.lax.conv(lhs, rhs, [1], 'SAME') def test_reduce_window_scalar_init_value_shape_rule(self): - # https://github.com/google/jax/issues/4574 + # https://github.com/jax-ml/jax/issues/4574 args = { "operand": np.ones((4, 4), dtype=np.int32) , "init_value": np.zeros((1,), dtype=np.int32) , "computation": lax.max @@ -3045,7 +3221,7 @@ def testDynamicSliceU8Index(self): np.array(lax.dynamic_slice(x, np.uint8([128]), (1,))), [128]) def test_dot_general_batching_python_builtin_arg(self): - # https://github.com/google/jax/issues/16805 + # https://github.com/jax-ml/jax/issues/16805 @jax.remat def f(x): return jax.lax.dot_general(x, x, (([], []), ([], []))) @@ -3053,7 +3229,7 @@ def f(x): jax.hessian(f)(1.0) # don't crash def test_constant_folding_complex_to_real_scan_regression(self): - # regression test for github.com/google/jax/issues/19059 + # regression test for github.com/jax-ml/jax/issues/19059 def g(hiddens): hiddens_aug = jnp.vstack((hiddens[0], hiddens)) new_hiddens = hiddens_aug.copy() @@ -3088,11 +3264,15 @@ def testAsarray(self, typ): jaxpr = jax.make_jaxpr(asarray_closure)() self.assertLen(jaxpr.eqns, 0) - # Regression test for https://github.com/google/jax/issues/19334 + # Regression test for https://github.com/jax-ml/jax/issues/19334 # lax.asarray as a closure should not trigger transfer guard. with jax.transfer_guard('disallow'): jax.jit(asarray_closure)() + def testOptimizationBarrier(self): + x = lax.optimization_barrier((2, 3)) + self.assertEqual((2, 3), x) + class LazyConstantTest(jtu.JaxTestCase): def _Check(self, make_const, expected): @@ -3250,7 +3430,7 @@ def testArgMaxOfNanChoosesNaN(self): def testUnaryWeakTypes(self, op_name, rec_dtypes): """Test that all lax unary ops propagate weak_type information appropriately.""" if op_name == "bitwise_not": - raise unittest.SkipTest("https://github.com/google/jax/issues/12066") + raise unittest.SkipTest("https://github.com/jax-ml/jax/issues/12066") # Find a valid dtype for the function. 
for dtype in [float, int, complex, bool]: dtype = dtypes.canonicalize_dtype(dtype) @@ -3373,7 +3553,7 @@ def __repr__(self) -> str: size = property(lambda self: self.data.size // 2) ndim = property(lambda self: self.data.ndim - 1) -def shard_foo_array_handler(xs, shardings): +def shard_foo_array_handler(xs, shardings, layouts): results = [] for x, sharding in safe_zip(xs, shardings): device, = sharding._addressable_device_assignment @@ -3644,7 +3824,7 @@ def test_gather(self): self.assertEqual(ys.shape, (3, 2, 1)) def test_gather_batched_index_dtype(self): - # Regression test for https://github.com/google/jax/issues/16557 + # Regression test for https://github.com/jax-ml/jax/issues/16557 dtype = jnp.int8 size = jnp.iinfo(dtype).max + 10 indices = jnp.zeros(size, dtype=dtype) @@ -3794,7 +3974,8 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): size_im = 11 atol = None - if name in {"arccos", "arcsin", "arcsinh", "arccosh"}: + if (name in {"arccos", "arcsin", "arcsinh", "arccosh"} + or name in {"arctan", "arctanh"} and jax._src.lib.version > (0, 4, 31)): # TODO(pearu): eliminate this if-block when a fix to mpmath#787 # becomes available extra_prec_multiplier = 20 @@ -3950,21 +4131,21 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') - elif name == 'arctan': + elif name == 'arctan' and jax._src.lib.version <= (0, 4, 31): if dtype == np.complex64: regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', - 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.imag', 'mposj.imag') + 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag') if dtype == np.complex128: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj') + regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real') - elif name == 'arctanh': + elif name == 'arctanh' and jax._src.lib.version <= (0, 4, 31): regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') elif name in {'positive', 'negative', 'conjugate', 'sin', 'cos', 'sqrt', 'expm1', 'log1p', 'tan', - 'arcsinh', 'arcsin', 'arccosh'}: + 'arcsinh', 'arcsin', 'arccosh', 'arctan', 'arctanh'}: regions_with_inaccuracies.clear() else: assert 0 # unreachable diff --git a/tests/lax_vmap_test.py b/tests/lax_vmap_test.py index 37d51c04f8de..37a0011e7bd0 100644 --- a/tests/lax_vmap_test.py +++ b/tests/lax_vmap_test.py @@ -693,8 +693,8 @@ def testSort(self, shape, dimension, arity, bdims, is_stable): # TODO(b/183233858): variadic reduce-window is not implemented on XLA:GPU @jtu.skip_on_devices("gpu") def test_variadic_reduce_window(self): - # https://github.com/google/jax/discussions/9818 and - # https://github.com/google/jax/issues/9837 + # https://github.com/jax-ml/jax/discussions/9818 and + # https://github.com/jax-ml/jax/issues/9837 def normpool(x): norms = jnp.linalg.norm(x, axis=-1) idxs = jnp.arange(x.shape[0]) diff --git a/tests/layout_test.py b/tests/layout_test.py index c72082d0a16c..1d18179ccfee 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -25,6 +25,7 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip +from 
jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -45,9 +46,10 @@ def setUp(self): super().setUp() def test_auto_layout(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) s1 = NamedSharding(mesh, P('x', 'y')) @@ -68,18 +70,18 @@ def init(x, y): out_shardings=Layout(DLL.AUTO)).lower(sds1, sds2) compiled_apply = lowered_apply.compile() - arg_layouts, kw_layouts = compiled_apply.input_layouts() + arg_layouts, kw_layouts = compiled_apply.input_layouts self.assertEmpty(kw_layouts) - for i, o in zip(arg_layouts, compiled_apply.output_layouts()): + for i, o in zip(arg_layouts, compiled_apply.output_layouts): self.assertEqual(i.device_local_layout.major_to_minor, o.device_local_layout.major_to_minor[::-1]) init_compiled = jax.jit( init, out_shardings=arg_layouts).lower(sds1, sds2).compile() - for i, o in zip(init_compiled.input_layouts()[0], - init_compiled.output_layouts()): + for i, o in zip(init_compiled.input_layouts[0], + init_compiled.output_layouts): self.assertEqual(i, o) arr1 = jax.device_put(np_inp1, s1) @@ -90,16 +92,16 @@ def init(x, y): init_compiled(arr1, arr2) self.assertEqual(init_count[0], 1) - self.assertEqual(init_out[0].layout, init_compiled.output_layouts()[0]) - self.assertEqual(init_out[1].layout, init_compiled.output_layouts()[1]) + self.assertEqual(init_out[0].layout, init_compiled.output_layouts[0]) + self.assertEqual(init_out[1].layout, init_compiled.output_layouts[1]) with jtu.count_aot_jit_cpp_cache_miss() as apply_count: apply_out = compiled_apply(*init_out) compiled_apply(*init_out) self.assertEqual(apply_count[0], 1) - self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts()[0]) - self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts()[1]) + self.assertEqual(apply_out[0].layout, compiled_apply.output_layouts[0]) + self.assertEqual(apply_out[1].layout, compiled_apply.output_layouts[1]) self.assertTupleEqual(apply_out[0].layout.device_local_layout.major_to_minor, init_out[0].layout.device_local_layout.major_to_minor[::-1]) @@ -112,9 +114,10 @@ def init(x, y): self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -130,10 +133,10 @@ def f(x): out = compiled(arr) self.assertTupleEqual( - compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled.output_layouts().device_local_layout.major_to_minor[::-1], + compiled.output_layouts.device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, NamedSharding(mesh, 
P(None, 'y', 'x'))) @@ -141,10 +144,10 @@ def f(x): compiled_auto = jax.jit(f, in_shardings=Layout(DLL.AUTO), out_shardings=Layout(DLL.AUTO)).lower(sds).compile() self.assertTupleEqual( - compiled_auto.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled_auto.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (2, 1, 0)) self.assertTupleEqual( - compiled_auto.output_layouts().device_local_layout.major_to_minor[::-1], + compiled_auto.output_layouts.device_local_layout.major_to_minor[::-1], (0, 1, 2)) with self.assertRaisesRegex( @@ -153,9 +156,10 @@ def f(x): out_shardings=DLL.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -167,21 +171,22 @@ def f(x): compiled = jax.jit(f, in_shardings=Layout(), out_shardings=Layout(DLL.AUTO)).lower(arr).compile() self.assertTupleEqual( - compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts().device_local_layout.major_to_minor[::-1], + compiled.output_layouts.device_local_layout.major_to_minor[::-1], (0, 1)) out = compiled(arr) self.assertArraysEqual(out, np_inp.T) - self.assertEqual(out.layout, compiled.output_layouts()) + self.assertEqual(out.layout, compiled.output_layouts) self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -190,10 +195,10 @@ def test_sharding_and_layouts(self): out_shardings=Layout(DLL.AUTO, s)).lower(np_inp).compile() out = compiled(np_inp) self.assertTupleEqual( - compiled.input_layouts()[0][0].device_local_layout.major_to_minor[::-1], + compiled.input_layouts[0][0].device_local_layout.major_to_minor[::-1], (1, 0)) self.assertTupleEqual( - compiled.output_layouts().device_local_layout.major_to_minor[::-1], + compiled.output_layouts.device_local_layout.major_to_minor[::-1], (0, 1)) self.assertArraysEqual(out, np_inp.T) self.assertEqual(out.sharding, s) @@ -206,13 +211,13 @@ def f(x, y, z, a, b, c): inps = [np.arange(math.prod(shape)).reshape(shape)] * 6 compiled = jax.jit(f, in_shardings=Layout(DLL.AUTO), out_shardings=Layout(DLL.AUTO)).lower(*inps).compile() - arg_layouts, _ = compiled.input_layouts() + arg_layouts, _ = compiled.input_layouts out1, out2 = compiled(*inps) compiled2 = jax.jit(f, in_shardings=arg_layouts).lower(*inps).compile() out3, out4 = compiled2(*inps) - for l1, l2 in safe_zip(arg_layouts, compiled2.input_layouts()[0]): + for l1, l2 in safe_zip(arg_layouts, 
compiled2.input_layouts[0]): self.assertEqual(l1, l2) self.assertArraysEqual(out1, out3) @@ -224,7 +229,7 @@ def f(x, y, z, a, b, c): self.assertArraysEqual(out2, out6) def test_no_error_dced_args(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) shape = (8, 2) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -238,14 +243,18 @@ def f(x, y): jf = jax.jit(f, in_shardings=Layout(DLL.AUTO, s), out_shardings=Layout(DLL.AUTO, s)) compiled = jf.lower(np_inp, np_inp).compile() - arg_layouts, _ = compiled.input_layouts() + arg_layouts, _ = compiled.input_layouts arrs = [jax.device_put(i, l) for i, l in zip(arrs, arg_layouts)] compiled(*arrs) def test_aot_layout_mismatch(self): if jtu.test_device_matches(["gpu"]): + # The test fails on GPU because the compilation with both input and + # output set to auto layout is underspecified. The GPU compiler chooses + # the default layout as the input layout and that choice does not + # raise an exception. self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (256, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x')) @@ -281,7 +290,7 @@ def test_cpu_default_backend_layout(self): out_cpu, out_cpu).compile() # doesn't crash def test_device_put_concrete_layout(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, P('x', 'y')) @@ -289,7 +298,7 @@ def test_device_put_concrete_layout(self): compiled = jax.jit( lambda x: x * 2, out_shardings=Layout(DLL.AUTO)).lower(arr).compile() - col = compiled.output_layouts() + col = compiled.output_layouts out = jax.device_put(np_inp, col) self.assertEqual(out.layout, col) @@ -321,19 +330,19 @@ def invalid_layout_spec(self): compiled = jax.jit(lambda x: x).lower(x).compile() with self.assertRaisesRegex( ValueError, 'Sharding has to be concrete when layout.*'): - Layout(compiled.output_layouts()[0], None) + Layout(compiled.output_layouts[0], None) def test_layout_on_sds(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) out_layout = jax.jit(jnp.sin, out_shardings=Layout(DLL.AUTO)).lower( - arr).compile().output_layouts() + arr).compile().output_layouts sds = jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=out_layout) - arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts() + arg_layout, _ = jax.jit(lambda x: x * 2).lower(sds).compile().input_layouts self.assertEqual(arg_layout[0], out_layout) with self.assertRaisesRegex( @@ -343,12 +352,12 @@ def test_layout_on_sds(self): jax.ShapeDtypeStruct(arr.shape, arr.dtype, sharding=Layout(DLL.AUTO)) def test_make_array_from_callback(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) - layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts() + layout = jax.jit(lambda x: x * 2).lower(sds).compile().output_layouts out = jax.make_array_from_callback(np_inp.shape, layout, lambda idx: np_inp[idx]) @@ -368,7 +377,7 @@ def 
test_make_array_from_callback(self): np_inp.shape, Layout(None, None), lambda idx: np_inp[idx]) def test_wsc_concrete_layout(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (16, 128) s = NamedSharding(mesh, P('x')) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -391,7 +400,7 @@ def f(x): self.assertArraysEqual(out, np_inp.T) def test_wsc_bfloat16_concrete_layout(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (16, 128) s = NamedSharding(mesh, P('x')) inp = jnp.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) @@ -414,8 +423,6 @@ def f(x): self.assertArraysEqual(out, inp.T) def test_device_put_user_concrete_layout(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") shape = (8, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -428,7 +435,7 @@ def test_device_put_user_concrete_layout(self): self.assertArraysEqual(out, np_inp) def test_concrete_layout_jit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (16, 128) s = NamedSharding(mesh, P('x')) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -470,9 +477,10 @@ def test_incompatible_aval_error_device_put(self): jax.device_put(inp, l) def test_concrete_layout_in_shardings(self): - if jtu.test_device_matches(["gpu"]): - self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + # Remove this condition when xla_extension_version >= 285 + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: + self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -480,7 +488,9 @@ def test_concrete_layout_in_shardings(self): custom_dll = DLL(major_to_minor=(0, 1)) - @partial(jax.jit, in_shardings=Layout(custom_dll, s)) + @partial(jax.jit, + in_shardings=Layout(custom_dll, s), + out_shardings=Layout(DLL.AUTO)) def f(x): return x.T @@ -500,6 +510,96 @@ def g(x): 'Layout passed to jit does not match the layout on the respective arg'): g(arr) + def test_in_layouts_jit_jnp_input(self): + major_last_layout = DLL(major_to_minor=(1, 0)) + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + + f = jax.jit(lambda x: x + 1, + in_shardings=Layout(major_last_layout, sharding)) + + arr = jnp.arange(8 * 128).reshape(8, 128) + out = f(arr) + self.assertArraysEqual(out, arr + 1) + + # cpp dispatch should call into shard_args from cpp. + out2 = f(arr) + self.assertArraysEqual(out2, arr + 1) + + np_inp = np.arange(8 * 128).reshape(8, 128) + out3 = f(np_inp) + self.assertArraysEqual(out3, np_inp + 1) + + # cpp dispatch should call into shard_args from cpp. 
+ out4 = f(np_inp) + self.assertArraysEqual(out4, np_inp + 1) + + def test_layout_donation(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16, 128) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = DLL(major_to_minor=(0, 1)) + arr = jax.device_put(np_inp, Layout(custom_dll, s)) + + @partial(jax.jit, in_shardings=Layout(custom_dll, s), donate_argnums=0) + def f(x): + return x + + f(arr) + self.assertTrue(arr.is_deleted()) + + def test_layout_donation_auto(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + arr = jax.device_put(np_inp, s) + + @partial(jax.jit, out_shardings=Layout(DLL.AUTO), donate_argnums=0) + def f(x): + return x * x + + f(arr) + self.assertTrue(arr.is_deleted()) + + def test_layout_donation_matching_in_and_out(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (128, 16) + np_inp = np.arange(math.prod(shape)).reshape(shape) + + custom_dll = DLL(major_to_minor=(0, 1)) + l = Layout(custom_dll, s) + arr = jax.device_put(np_inp, l) + + @partial(jax.jit, in_shardings=l, out_shardings=l, donate_argnums=0) + def f(x): + return x * x + + f(arr) + self.assertTrue(arr.is_deleted()) + + @jtu.skip_on_devices('cpu', 'gpu') + def test_layout_donation_mismatching_in_and_out_fails(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + shape = (16*2, 32016*2) + np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape) + + custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1))) + l1 = Layout(custom_dll1, s) + arr = jax.device_put(np_inp, s) + + @partial(jax.jit, out_shardings=l1, donate_argnums=0) + def f(x): + return x * x + + sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s) + f.lower(sds).compile()(arr) + self.assertFalse(arr.is_deleted()) + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/linalg_test.py b/tests/linalg_test.py index 1f4488fd5014..e52582eb7526 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -30,6 +30,7 @@ from jax import numpy as jnp from jax import scipy as jsp from jax._src import config +from jax._src import deprecations from jax._src.lax import linalg as lax_linalg from jax._src import test_util as jtu from jax._src import xla_bridge @@ -110,7 +111,7 @@ def testDetOfSingularMatrix(self): self.assertAllClose(np.float32(0), jsp.linalg.det(x)) @jtu.sample_product( - shape=[(1, 1), (3, 3), (2, 4, 4)], + shape=[(1, 1), (2, 2), (3, 3), (2, 2, 2), (2, 3, 3), (2, 4, 4), (5, 7, 7)], dtype=float_types, ) @jtu.skip_on_flag("jax_skip_slow_tests", True) @@ -268,10 +269,7 @@ def check_left_eigenvectors(a, w, vl): if compute_right_eigenvectors: check_right_eigenvectors(a, w, results[1 + compute_left_eigenvectors]) - # TODO(phawkins): we are seeing nondeterminism in LAPACK routines with - # avx enabled, because for Eigen BLAS nrm2 has an alignment dependence. 
- # self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, - # rtol=1e-3) + self._CompileAndCheck(partial(jnp.linalg.eig), args_maker, rtol=1e-3) @jtu.sample_product( shape=[(4, 4), (5, 5), (50, 50), (2, 6, 6)], @@ -324,11 +322,11 @@ def testEigvals(self, shape, dtype): a, = args_maker() w1, _ = jnp.linalg.eig(a) w2 = jnp.linalg.eigvals(a) - self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 1e-14}) + self.assertAllClose(w1, w2, rtol={np.complex64: 1e-5, np.complex128: 2e-14}) @jtu.run_on_devices("cpu") def testEigvalsInf(self): - # https://github.com/google/jax/issues/2661 + # https://github.com/jax-ml/jax/issues/2661 x = jnp.array([[jnp.inf]]) self.assertTrue(jnp.all(jnp.isnan(jnp.linalg.eigvals(x)))) @@ -488,7 +486,7 @@ def testEighRankDeficient(self, rank): with jax.numpy_rank_promotion("allow"): self.assertLessEqual( np.linalg.norm(np.matmul(a, v) - w * v), - 81 * eps * np.linalg.norm(a), + 85 * eps * np.linalg.norm(a), ) @jtu.sample_product( @@ -1003,7 +1001,7 @@ def qr_and_mul(a): @jtu.skip_on_devices("tpu") def testQrInvalidDtypeCPU(self, shape=(5, 6), dtype=np.float16): - # Regression test for https://github.com/google/jax/issues/10530 + # Regression test for https://github.com/jax-ml/jax/issues/10530 rng = jtu.rand_default(self.rng()) arr = rng(shape, dtype) if jtu.test_device_matches(['cpu']): @@ -1150,11 +1148,22 @@ def np_fn(a): a = (a + T(a.conj())) / 2 return np.linalg.pinv(a, hermitian=hermitian) self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, tol=1e-4) - self._CompileAndCheck(jnp_fn, args_maker) + self._CompileAndCheck(jnp_fn, args_maker, atol=1e-5) # TODO(phawkins): 6e-2 seems like a very loose tolerance. jtu.check_grads(jnp_fn, args_maker(), 1, rtol=6e-2, atol=1e-3) + def testPinvDeprecatedArgs(self): + msg = "The rcond argument for linalg.pinv is deprecated." + def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-linalg-pinv-rcond"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + x = jnp.ones((3, 3)) + with assert_warns_or_errors(msg): + jnp.linalg.pinv(x, rcond=1E-2) + def testPinvGradIssue2792(self): def f(p): a = jnp.array([[0., 0.],[-p, 1.]], jnp.float32) * 1 / (1 + p**2) @@ -1197,6 +1206,17 @@ def testMatrixRank(self, shape, dtype): self._CompileAndCheck(jnp.linalg.matrix_rank, args_maker, check_dtypes=False, rtol=1e-3) + def testMatrixRankDeprecatedArgs(self): + msg = "The tol argument for linalg.matrix_rank is deprecated." 
+ def assert_warns_or_errors(msg=msg): + if deprecations.is_accelerated("jax-numpy-linalg-matrix_rank-tol"): + return self.assertRaisesRegex(ValueError, msg) + else: + return self.assertWarnsRegex(DeprecationWarning, msg) + x = jnp.ones((3, 3)) + with assert_warns_or_errors(msg): + jnp.linalg.matrix_rank(x, tol=1E-2) + @jtu.sample_product( shapes=[ [(3, ), (3, 1)], # quick-out codepath @@ -1376,7 +1396,6 @@ def testBlockDiag(self, args): args_maker, check_dtypes=False) self._CompileAndCheck(jsp.linalg.block_diag, args_maker) - @jtu.sample_product( shape=[(1, 1), (4, 5), (10, 5), (50, 50)], dtype=float_types + complex_types, @@ -1399,7 +1418,7 @@ def testLuOfSingularMatrix(self): @parameterized.parameters(lax_linalg.lu, lax_linalg._lu_python) def testLuOnZeroMatrix(self, lu): - # Regression test for https://github.com/google/jax/issues/19076 + # Regression test for https://github.com/jax-ml/jax/issues/19076 x = jnp.zeros((2, 2), dtype=np.float32) x_lu, _, _ = lu(x) self.assertArraysEqual(x_lu, x) @@ -1759,7 +1778,6 @@ def sp_func(a): self._CheckAgainstNumpy(sp_func, jax_func, args_maker, rtol=1e-4, atol=1e-4, check_dtypes=False) - @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], dtype=float_types + complex_types, @@ -1795,7 +1813,6 @@ def args_maker(): self._CheckAgainstNumpy(osp.linalg.cho_solve, jsp.linalg.cho_solve, args_maker, tol=1e-3) - @jtu.sample_product( n=[1, 4, 5, 20, 50, 100], dtype=float_types + complex_types, @@ -1821,13 +1838,13 @@ def args_maker(): e = rng((n, n), dtype) return [a, e, ] - #compute_expm is True + # compute_expm is True osp_fun = lambda a,e: osp.linalg.expm_frechet(a,e,compute_expm=True) jsp_fun = lambda a,e: jsp.linalg.expm_frechet(a,e,compute_expm=True) self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, check_dtypes=False, tol=tol) self._CompileAndCheck(jsp_fun, args_maker, check_dtypes=False) - #compute_expm is False + # compute_expm is False osp_fun = lambda a,e: osp.linalg.expm_frechet(a,e,compute_expm=False) jsp_fun = lambda a,e: jsp.linalg.expm_frechet(a,e,compute_expm=False) self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker, @@ -1863,18 +1880,31 @@ def expm(x): jtu.check_grads(expm, (a,), modes=["fwd", "rev"], order=1, atol=tol, rtol=tol) + @jtu.sample_product( + shape=[(4, 4), (15, 15), (50, 50), (100, 100)], + dtype=float_types + complex_types, + ) + @jtu.run_on_devices("cpu") + def testSchur(self, shape, dtype): + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + + self._CheckAgainstNumpy(osp.linalg.schur, jsp.linalg.schur, args_maker) + self._CompileAndCheck(jsp.linalg.schur, args_maker) + @jtu.sample_product( shape=[(1, 1), (4, 4), (15, 15), (50, 50), (100, 100)], dtype=float_types + complex_types, ) @jtu.run_on_devices("cpu") def testRsf2csf(self, shape, dtype): - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] - tol = 3e-5 - self._CheckAgainstNumpy(osp.linalg.rsf2csf, jsp.linalg.rsf2csf, - args_maker, tol=tol) - self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker) + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype), rng(shape, dtype)] + tol = 3e-5 + self._CheckAgainstNumpy( + osp.linalg.rsf2csf, jsp.linalg.rsf2csf, args_maker, tol=tol + ) + self._CompileAndCheck(jsp.linalg.rsf2csf, args_maker) @jtu.sample_product( shape=[(1, 1), (5, 5), (20, 20), (50, 50)], @@ -1885,17 +1915,22 @@ def testRsf2csf(self, shape, dtype): # backend only, so tests on GPU and TPU backends are skipped here @jtu.run_on_devices("cpu") def testFunm(self, shape, 
dtype, disp): - def func(x): - return x**-2.718 - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng(shape, dtype)] - jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp) - scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp) - self._CheckAgainstNumpy( - jnp_fun, scp_fun, args_maker, check_dtypes=False, - tol={np.float32: 2e-3,np.complex64: 2e-3, np.complex128: 1e-6}) - # TODO(phawkins): nondeterminism due to alignment. - # self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5) + + def func(x): + return x**-2.718 + + rng = jtu.rand_default(self.rng()) + args_maker = lambda: [rng(shape, dtype)] + jnp_fun = lambda arr: jsp.linalg.funm(arr, func, disp=disp) + scp_fun = lambda arr: osp.linalg.funm(arr, func, disp=disp) + self._CheckAgainstNumpy( + jnp_fun, + scp_fun, + args_maker, + check_dtypes=False, + tol={np.complex64: 1e-5, np.complex128: 1e-6}, + ) + self._CompileAndCheck(jnp_fun, args_maker, atol=2e-5) @jtu.sample_product( shape=[(4, 4), (15, 15), (50, 50), (100, 100)], @@ -1910,9 +1945,9 @@ def testSqrtmPSDMatrix(self, shape, dtype): mat = arg @ arg.T args_maker = lambda : [mat] if dtype == np.float32 or dtype == np.complex64: - tol = 1e-4 + tol = 1e-4 else: - tol = 1e-8 + tol = 1e-8 self._CheckAgainstNumpy(osp.linalg.sqrtm, jsp.linalg.sqrtm, args_maker, @@ -2121,8 +2156,10 @@ def testSchur(self, shape, dtype): args = rng(shape, dtype) Ts, Ss = lax.linalg.schur(args) eps = np.finfo(dtype).eps - self.assertAllClose(args, Ss @ Ts @ jnp.conj(Ss.T), atol=eps * 600) - self.assertAllClose(np.eye(*shape, dtype=dtype), Ss @ jnp.conj(Ss.T), atol=eps * 100) + self.assertAllClose(args, Ss @ Ts @ jnp.conj(Ss.T), atol=600 * eps) + self.assertAllClose( + np.eye(*shape, dtype=dtype), Ss @ jnp.conj(Ss.T), atol=100 * eps + ) @jtu.sample_product( shape=[(2, 2), (4, 4), (15, 15), (50, 50), (100, 100)], @@ -2164,6 +2201,38 @@ def testHilbert(self, n): self._CheckAgainstNumpy(osp_fun, jsp_fun, args_maker) self._CompileAndCheck(jsp_fun, args_maker) + @jtu.sample_product( + shape=[(5, 1), (10, 4), (128, 12)], + dtype=float_types, + symmetrize_output=[True, False], + ) + @jtu.skip_on_devices("tpu") + def testSymmetricProduct(self, shape, dtype, symmetrize_output): + rng = jtu.rand_default(self.rng()) + batch_size = 10 + atol = 1e-6 if dtype == jnp.float64 else 1e-3 + + a_matrix = rng((batch_size,) + shape, dtype) + c_shape = a_matrix.shape[:-1] + (a_matrix.shape[-2],) + c_matrix = jnp.zeros(c_shape, dtype) + + old_product = jnp.einsum("...ij,...kj->...ik", a_matrix, a_matrix, + precision=lax.Precision.HIGHEST) + new_product = lax_linalg.symmetric_product( + a_matrix, c_matrix, symmetrize_output=symmetrize_output) + new_product_with_batching = jax.vmap( + lambda a, c: lax_linalg.symmetric_product( + a, c, symmetrize_output=symmetrize_output), + in_axes=(0, 0))(a_matrix, c_matrix) + + if not symmetrize_output: + old_product = jnp.tril(old_product) + new_product = jnp.tril(new_product) + new_product_with_batching = jnp.tril(new_product_with_batching) + self.assertAllClose(new_product, old_product, atol=atol) + self.assertAllClose( + new_product_with_batching, old_product, atol=atol) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/logging_test.py b/tests/logging_test.py index 5a495d47d31b..70f619de5ee6 100644 --- a/tests/logging_test.py +++ b/tests/logging_test.py @@ -19,6 +19,7 @@ import platform import subprocess import sys +import tempfile import textwrap import unittest @@ -74,37 +75,45 @@ def test_no_log_spam(self): if sys.executable is 
None:
       raise self.skipTest("test requires access to python binary")
-    program = textwrap.dedent("""
-      import jax
-      jax.device_count()
-      f = jax.jit(lambda x: x + 1)
-      f(1)
-      f(2)
-      jax.numpy.add(1, 1)
-    """)
-    python = sys.executable
-    assert "python" in python
-    env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
-    if os.getenv("PYTHONPATH"):
-      env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
-    if os.getenv("LD_LIBRARY_PATH"):
-      env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
-    # Make sure C++ logging is at default level for the test process.
-    proc = subprocess.run(
-        [python, "-c", program],
-        capture_output=True,
-        env=env_variables,
-    )
-
-    lines = proc.stdout.split(b"\n")
-    lines.extend(proc.stderr.split(b"\n"))
-    allowlist = [
-        b"",
-        b"An NVIDIA GPU may be present on this machine, but a CUDA-enabled "
-        b"jaxlib is not installed. Falling back to cpu.",
-    ]
-    lines = [l for l in lines if l not in allowlist]
-    self.assertEmpty(lines)
+    # Save script in file to fix the problem with
+    # `tsl::Env::Default()->GetExecutablePath()` not working properly with
+    # command flag.
+    with tempfile.NamedTemporaryFile(
+        mode="w+", encoding="utf-8", suffix=".py"
+    ) as f:
+      f.write(textwrap.dedent("""
+        import jax
+        jax.device_count()
+        f = jax.jit(lambda x: x + 1)
+        f(1)
+        f(2)
+        jax.numpy.add(1, 1)
+      """))
+      python = sys.executable
+      assert "python" in python
+      env_variables = {"TF_CPP_MIN_LOG_LEVEL": "1"}
+      if os.getenv("PYTHONPATH"):
+        env_variables["PYTHONPATH"] = os.getenv("PYTHONPATH")
+      if os.getenv("LD_LIBRARY_PATH"):
+        env_variables["LD_LIBRARY_PATH"] = os.getenv("LD_LIBRARY_PATH")
+      # Make sure C++ logging is at default level for the test process.
+      proc = subprocess.run(
+          [python, f.name],
+          capture_output=True,
+          env=env_variables,
+      )
+
+      lines = proc.stdout.split(b"\n")
+      lines.extend(proc.stderr.split(b"\n"))
+      allowlist = [
+          b"",
+          (
+              b"An NVIDIA GPU may be present on this machine, but a"
+              b" CUDA-enabled jaxlib is not installed. Falling back to cpu."
+          ),
+      ]
+      lines = [l for l in lines if l not in allowlist]
+      self.assertEmpty(lines)
 
   def test_debug_logging(self):
     # Warmup so we don't get "No GPU/TPU" warning later.
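For reference, the pattern adopted in the logging test above — writing the snippet to a real .py file and invoking the interpreter on that path instead of passing the code via `python -c`, so that `tsl::Env::Default()->GetExecutablePath()` resolves correctly — looks roughly like the standalone sketch below. It is illustrative only: the script body, environment handling, and the `f.flush()` call are placeholders of my own and not part of the patch.

    import os
    import subprocess
    import sys
    import tempfile
    import textwrap

    # Hypothetical sketch of the "run the script from a temp file" pattern;
    # the script contents here are a placeholder.
    script = textwrap.dedent("""
        print("hello from a temp file")
    """)

    with tempfile.NamedTemporaryFile(mode="w+", encoding="utf-8", suffix=".py") as f:
      f.write(script)
      f.flush()  # make sure the child process sees the full file contents
      env = {"TF_CPP_MIN_LOG_LEVEL": "1"}
      if os.getenv("PYTHONPATH"):
        env["PYTHONPATH"] = os.getenv("PYTHONPATH")
      # Invoke the same interpreter on the file path rather than using `-c`.
      proc = subprocess.run([sys.executable, f.name], capture_output=True, env=env)

    print(proc.stdout.decode())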
diff --git a/tests/memories_test.py b/tests/memories_test.py index 816fbeee3e3d..6959aa7535b8 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -35,6 +35,7 @@ TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on from jax.experimental.shard_map import shard_map +from jax._src.lib import xla_extension_version import numpy as np config.parse_flags_with_absl() @@ -47,14 +48,13 @@ def get_memory_kinds_from_executable(f, args): def _create_inputs(shape, pspec, mem_kind=None): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(math.prod(shape)).reshape(shape) s = NamedSharding(mesh, pspec, memory_kind=mem_kind) inp = jax.device_put(np_inp, s) return mesh, s, np_inp, inp -@jtu.with_config(jax_enable_memories=True) class ShardingMemoriesTest(jtu.JaxTestCase): def setUp(self): @@ -72,7 +72,7 @@ def setUp(self): ) def test_canonicalize_memory_kind(self, name): if name == "named_sharding": - mesh = jtu.create_global_mesh((1,), "x") + mesh = jtu.create_mesh((1,), "x") ns = NamedSharding(mesh, P("x")) self.assertEqual(ns.memory_kind, self._default_memory_kind) elif name == "positional_sharding": @@ -97,7 +97,7 @@ def test_wrong_memory_kind(self, name): with self.assertRaisesRegex( ValueError, "Could not find memory addressable by device.*" ): - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind="hbm") elif name == "positional_sharding": with self.assertRaisesRegex( @@ -129,7 +129,7 @@ def test_correct_tpu_memory_kind(self, name): self.skipTest("TPU memory kind test.") if name == "named_sharding": - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) elif name == "positional_sharding": PositionalSharding(jax.devices(), memory_kind=self._default_memory_kind) @@ -147,7 +147,7 @@ def test_correct_tpu_memory_kind(self, name): ) def test_sharding_eq(self, name): if name == "named_sharding": - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) s1 = NamedSharding(mesh, P("x")) s2 = NamedSharding(mesh, P("x"), memory_kind=self._default_memory_kind) self.assertEqual(s1, s2) @@ -165,7 +165,7 @@ def test_sharding_eq(self, name): self.assertEqual(s1, s2) def test_sharding_equivalent(self): - mesh = jtu.create_global_mesh((1,), ("x",)) + mesh = jtu.create_mesh((1,), ("x",)) ndim = 2 ns1 = NamedSharding(mesh, P("x")) gs1 = GSPMDSharding( @@ -186,7 +186,6 @@ def test_default_memory_kind(self): self.assertEqual(dev.default_memory().kind, self._default_memory_kind) -@jtu.with_config(jax_enable_memories=True) class DevicePutTest(jtu.JaxTestCase): def setUp(self): @@ -217,7 +216,7 @@ def test_error_transfer_to_memory_kind_outside_jit(self): def test_device_put_host_to_hbm(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": self.skipTest("unpinned_host does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind) np_inp = np.arange(16).reshape(8, 2) @@ -233,7 +232,7 @@ def test_device_put_host_to_hbm(self, host_memory_kind: str): def test_device_put_hbm_to_host(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": self.skipTest("unpinned_host does not work on GPU 
backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind=host_memory_kind) inp = jnp.arange(16).reshape(8, 2) @@ -316,7 +315,7 @@ def test_device_put_on_different_device_with_the_same_memory_kind( # TODO(yashkatariya): Enable this once we can compute on host. # def test_device_put_resharding(self): - # mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + # mesh = jtu.create_mesh((2, 2), ("x", "y")) # s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") # s_hbm = s_host.with_memory_kind("device") # np_inp = np.arange(16).reshape(8, 2) @@ -343,7 +342,7 @@ def test_device_put_on_different_device_with_the_same_memory_kind( def test_device_put_numpy_array(self, host_memory_kind: str): if jtu.test_device_matches(["gpu"]) and host_memory_kind == "unpinned_host": self.skipTest("unpinned_host does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_hbm = NamedSharding(mesh, P(("x", "y")), memory_kind="device") s_host = s_hbm.with_memory_kind(host_memory_kind) @@ -464,7 +463,7 @@ def f(a): def test_parameter_streaming_with_scalar_and_constant(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test does not work on GPU backend.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) scalar_inp = 1 s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") @@ -490,7 +489,7 @@ def f(scalar_input): def test_parameter_and_output_streaming_with_array(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") inp_host = jax.device_put(np_inp, s_host) @@ -542,7 +541,7 @@ def f(x): ) def test_identity_jit_host_to_device_and_vice_versa(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_host = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') s_dev = s_host.with_memory_kind('device') @@ -562,7 +561,7 @@ def test_identity_jit_host_to_device_and_vice_versa(self): self.assertEqual(out_host.sharding, s_host) def test_parameter_streaming_inside_scan(self): - mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) + mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096.0).reshape(16, 16, 16) s_host = NamedSharding(mesh, P("x", "y", "z"), memory_kind="pinned_host") arr_host = jax.device_put(np_inp, s_host) @@ -584,7 +583,7 @@ def body(carry, x): def test_output_streaming(self): if jtu.test_device_matches(["gpu"]): self.skipTest("This test is flaky on GPU backend.") - mesh = jtu.create_global_mesh((1, 1), ("x", "y")) + mesh = jtu.create_mesh((1, 1), ("x", "y")) np_inp = np.arange(16.0).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device") s_host = NamedSharding(mesh, P("x", "y"), memory_kind="pinned_host") @@ -621,7 +620,7 @@ def test_output_streaming_inside_scan(self): self.skipTest("This test does not work on GPU backend.") if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((1, 1, 2), ("x", "y", "z")) + 
mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) np_inp = np.arange(4096).reshape(16, 16, 16) s_hbm = NamedSharding(mesh, P(None, "y", "z"), memory_kind="device") arr_hbm = jax.device_put(np_inp, s_hbm) @@ -668,7 +667,6 @@ def f(): self._check_device_put_addressable_shards(out, np_inp * 2, s_dev, 'device') -@jtu.with_config(jax_enable_memories=True) class ComputeOffload(jtu.BufferDonationTestCase): def setUp(self): @@ -683,7 +681,7 @@ def _check_mem_kind(self, executable_kind, out_sharding, expected_kind): self.assertEqual(executable_kind, expected_kind) def test_compute_no_inputs(self): - mesh = jtu.create_global_mesh((4,), ('data')) + mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) cpu_sharding = NamedSharding(mesh, P('data'), memory_kind='pinned_host') @@ -701,7 +699,7 @@ def init(): def test_compute_no_inputs_host_replicated(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 3: self.skipTest("This test requires an xla_version >= 3.") - mesh = jtu.create_global_mesh((4,), ('data')) + mesh = jtu.create_mesh((4,), ('data')) tpu_sharding = NamedSharding(mesh, P('data')) cpu_sharding = NamedSharding(mesh, P(), memory_kind='pinned_host') @@ -745,6 +743,29 @@ def h(x): self.assertArraysEqual(out2, inp * 6) self.assertEqual(out2.sharding.memory_kind, 'pinned_host') + def test_compute_on_basic_inline(self): + @compute_on('device_host') + @jax.jit + def g(x): + return x * 2 + + @functools.partial(jax.jit, inline=True) + def h(x): + y = g(x) + return y * 3 + + @jax.jit + def f(x): + return h(x) + + inp = jnp.arange(8) + out = f(inp) + self.assertArraysEqual(out, inp * 6) + + lowered_text = f.lower(jnp.arange(8)).as_text('hlo') + self.assertRegex(lowered_text, + 'to_apply=g.*frontend_attributes={_xla_compute_type="host"}') + def test_compute_on_reduction(self): out_s = SingleDeviceSharding(jax.devices()[0], memory_kind='pinned_host') @@ -846,7 +867,7 @@ def f(x): self.assertLen(out, 2) def test_nested_no_op_compute(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -871,7 +892,7 @@ def f2(x): self.assertEqual(out.sharding, s) def test_sharded_compute_on_host(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -924,7 +945,7 @@ def f_bwd(res, tx): def test_host_offload_in_custom_vjp_sharded(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) s = NamedSharding(mesh, P('x')) @jax.custom_vjp @@ -1008,7 +1029,7 @@ def f(x): def test_pure_host_data_and_compute(self): if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y'), memory_kind='pinned_host') np_inp = np.arange(16).reshape(8, 2) arr_host = jax.device_put(np_inp, s) @@ -1035,7 +1056,7 @@ def test_eager_compute(self): self.assertArraysAllClose(out, jnp.sin(inp * 2)) def test_compute_per_annotation(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 
2), ("x", "y")) s = NamedSharding(mesh, P("x", "y")) np_inp = np.arange(16.).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -1178,6 +1199,23 @@ def test_jit_cpp_cache_hit(self): self.assertArraysEqual(out, np_inp @ np_inp.T) self.assertArraysEqual(out2, np_inp @ np_inp.T) + def test_jit_compilation_cache_hit(self): + mesh, s, np_inp, inp = _create_inputs((8, 2), P("x", "y")) + inp2 = jax.device_put( + np_inp, GSPMDSharding(tuple(mesh.devices.flat), + s._to_xla_hlo_sharding(inp.ndim), + memory_kind="device") + ) + + f = jax.jit(lambda x: x @ x.T) + + with (jtu.count_pjit_cpp_cache_miss() as cpp_count, + jtu.count_jit_and_pmap_lowerings() as compile_count): + f(inp) + f(inp2) + self.assertEqual(cpp_count[0], 2) + self.assertEqual(compile_count[0], 1) + def test_jit_cpp_cache_output_hit(self): _, _, _, inp = _create_inputs((8, 2), P("x"), mem_kind="device") @@ -1200,7 +1238,7 @@ def mul(x): f = jax.jit(mul, in_shardings=s) g = jax.jit(mul, in_shardings=s2) - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: out = f(np_inp) out2 = g(np_inp2) self.assertEqual(count[0], 1) @@ -1209,7 +1247,7 @@ def mul(x): self.assertArraysEqual(out2, np_inp2 @ np_inp2.T) def test_sharding_devices_indices_map_cache_hit(self): - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) + mesh = jtu.create_mesh((2, 2), ("x", "y")) shape = (8, 2) s1 = NamedSharding(mesh, P("x", "y")) s2 = NamedSharding(mesh, P("x", "y"), memory_kind="device") @@ -1224,7 +1262,7 @@ def test_sharding_devices_indices_map_cache_hit(self): def test_no_donation_across_memory_kinds(self): if xb.using_pjrt_c_api(): raise unittest.SkipTest("GetOutputShardings not supported in PJRT C API") - mesh = jtu.create_global_mesh((2, 1), ("x", "y")) + mesh = jtu.create_mesh((2, 1), ("x", "y")) np_inp = np.arange(16).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x")) s_host = s_hbm.with_memory_kind("pinned_host") @@ -1243,7 +1281,7 @@ def f(x): self.assertNotDeleted(inp) def test_single_mem_kind_donation_default_mem_kind(self): - mesh = jtu.create_global_mesh((2,), "x") + mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P()) @functools.partial(jax.jit, out_shardings=s, donate_argnums=0) @@ -1259,7 +1297,7 @@ def f(inp1): self.assertDeleted(x) def test_compute_offload_inside_shmap(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, s) @@ -1313,7 +1351,7 @@ def h(x): self.assertArraysAllClose(out, expected_out, rtol=1e-3) def test_mem_kind_donation_pinned_host(self): - mesh = jtu.create_global_mesh((2,), "x") + mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P(), memory_kind='pinned_host') s_dev = s.with_memory_kind('device') @@ -1335,7 +1373,7 @@ def f(inp1, inp2): @parameterized.parameters("pinned_host", "device") def test_identity_mem_kind_donation(self, mem_kind): - mesh = jtu.create_global_mesh((2,), "x") + mesh = jtu.create_mesh((2,), "x") s = NamedSharding(mesh, P(), memory_kind=mem_kind) @functools.partial(jax.jit, out_shardings=s, donate_argnums=0) @@ -1353,7 +1391,7 @@ def f(inp): @jtu.run_on_devices('tpu') def test_aot_device_implicit_transfer(self): - mesh = jtu.create_global_mesh((1,), 'x') + mesh = jtu.create_mesh((1,), 'x') np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -1374,8 +1412,55 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P())) 
self.assertEqual(out.sharding.memory_kind, 'device') + def test_compute_offload_with_donation(self): + sharding = jax.sharding.SingleDeviceSharding(jax.devices()[0]) + p_sharding = jax.sharding.SingleDeviceSharding( + jax.devices()[0], memory_kind="pinned_host" + ) + + @compute_on("device_host") + @jax.jit + def host_fn(x_in, y_in): + return x_in * x_in, y_in + y_in + + def test_fn(x_in, y_in): + x_out, y_out = host_fn(x_in, y_in) + return x_out, y_out + + x = jnp.arange(0, 1024, dtype=jnp.float32) + y = jnp.arange(0, 1024, dtype=jnp.float32) + y = jax.device_put(y, p_sharding) + + x1 = jnp.arange(0, 1024, dtype=jnp.float32) + y1 = jnp.arange(0, 1024, dtype=jnp.float32) + + jit_fn = jax.jit( + test_fn, + in_shardings=(sharding, p_sharding), + out_shardings=(sharding, p_sharding), + donate_argnums=(0, 1), + ) + x_out, y_out = jit_fn(x, y) + self.assertArraysEqual(x_out, x1 * x1) + self.assertArraysEqual(y_out, y1 + y1) + + def test_compute_on_cache_miss(self): + @jax.jit + def f(x): + return x * 2 + + inp = jnp.arange(10) + with jtu.count_jit_tracing_cache_miss() as count: + with compute_on('device_host'): + f(inp) + + with compute_on('device'): + f(inp) + + # 2 for `f` and `2` for `mul` (compute type changes for `mul`) + self.assertEqual(count[0], 4) + -@jtu.with_config(jax_enable_memories=True) class ActivationOffloadingTest(jtu.JaxTestCase): def setUp(self): @@ -1384,7 +1469,7 @@ def setUp(self): super().setUp() def test_remat_jaxpr_offloadable(self): - mesh = jtu.create_global_mesh((2,), ("x",)) + mesh = jtu.create_mesh((2,), ("x",)) inp = jax.device_put(np.arange(16.), NamedSharding(mesh, P("x"))) def policy(prim, *avals, **params): @@ -1427,7 +1512,7 @@ def f(x): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_jaxpr_offloadable(self): - mesh = jtu.create_global_mesh((2,), ("x",)) + mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) s = NamedSharding(mesh, P("x")) @@ -1480,13 +1565,12 @@ def g(ys, _): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if jtu.pjrt_c_api_version_at_least(0, 43): - self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Remat scan does not work on GPU backend.") - mesh = jtu.create_global_mesh((2,), ("x",)) + if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: + self.skipTest("Requires xla_extension_version >= 289") + mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) s = NamedSharding(mesh, P("x")) @@ -1519,11 +1603,14 @@ def g(ys, _): self.assertIn('S(5)', compiled_text) self.assertNotRegex(compiled_text, r"copy-start.*S\(5\)") self.assertNotRegex(compiled_text, r"copy-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-update-slice-done.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-start.*S\(5\)") + self.assertRegex(compiled_text, r"dynamic-slice-done.*S\(5\)") compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if jtu.pjrt_c_api_version_at_least(0, 43): - self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def 
test_remat_checkpoint_dots_with_no_batch_dims(self): policy = jax.checkpoint_policies.offload_dot_with_no_batch_dims( @@ -1554,8 +1641,7 @@ def f(x): compiled_stats = compiled_f.memory_analysis() if compiled_stats is not None: - if jtu.pjrt_c_api_version_at_least(0, 43): - self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) + self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mesh_utils_test.py b/tests/mesh_utils_test.py index 6916c2c37e56..b182caf2dcd2 100644 --- a/tests/mesh_utils_test.py +++ b/tests/mesh_utils_test.py @@ -24,7 +24,7 @@ from jax._src import mesh as mesh_lib from jax._src import test_util from jax._src.sharding_impls import NamedSharding, PartitionSpec, local_to_global_shape -from jax.experimental import mesh_utils +from jax._src import mesh_utils from jax.sharding import Mesh # pylint: disable=g-importing-member import numpy as np diff --git a/tests/mosaic/BUILD b/tests/mosaic/BUILD index d182c99be7b1..3d1348371f07 100644 --- a/tests/mosaic/BUILD +++ b/tests/mosaic/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -28,60 +28,50 @@ package( jax_generate_backend_suites() -DISABLED_BACKENDS = [ - "cpu", - "tpu", -] - -DISABLED_CONFIGS = [ - "gpu", - "gpu_a100", - "gpu_p100", - "gpu_p100_x32", - "gpu_x32", - "gpu_pjrt_c_api", -] - -jax_test( +jax_multiplatform_test( name = "gpu_test", srcs = ["gpu_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = [ + "gpu_h100", + "gpu_h100_2gpu", + ], shard_count = 4, + tags = ["multiaccelerator"], deps = [ "//jax:mosaic_gpu", ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "matmul_test", srcs = ["matmul_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, - shard_count = 16, + enable_backends = [], + enable_configs = ["gpu_h100"], + shard_count = 5, deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:matmul", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "flash_attention", - srcs = ["//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, - main = "//third_party/py/jax/experimental/mosaic/gpu/examples:flash_attention.py", + srcs = ["//jax/experimental/mosaic/gpu/examples:flash_attention.py"], + enable_backends = [], + enable_configs = ["gpu_h100"], + main = "//jax/experimental/mosaic/gpu/examples:flash_attention.py", tags = ["notap"], deps = [ "//jax:mosaic_gpu", ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "flash_attention_test", srcs = ["flash_attention_test.py"], - disable_backends = DISABLED_BACKENDS, - disable_configs = DISABLED_CONFIGS, + enable_backends = [], + enable_configs = ["gpu_h100"], deps = [ "//jax:mosaic_gpu", "//jax/experimental/mosaic/gpu/examples:flash_attention", diff --git a/tests/mosaic/gpu_test.py b/tests/mosaic/gpu_test.py index f4fb6761ce41..f949b63c7844 100644 --- a/tests/mosaic/gpu_test.py +++ b/tests/mosaic/gpu_test.py @@ -15,10 +15,10 @@ """Tests for Mosaic GPU DSL functions and utilities.""" import enum -from functools import partial import itertools import math import operator +import unittest 
from absl.testing import absltest, parameterized import jax @@ -42,8 +42,8 @@ class Dimension(enum.IntEnum): # Just to make parameterized tests expand ok y = 1 z = 2 else: - from jax.experimental.mosaic import gpu as mosaic_gpu - from jax.experimental.mosaic.gpu import dsl as mgpu + import jax.experimental.mosaic.gpu as mgpu + from jax.experimental.mosaic.gpu import utils as utils from jax.experimental.mosaic.gpu import profiler from jax.experimental.mosaic.gpu.utils import * # noqa: F403 from jax._src.lib.mlir.dialects import gpu @@ -120,7 +120,7 @@ def body(*idx): nvvm.fence_proxy(nvvm.ProxyKind.async_) -def iota_tensor(m, n, mlir_dtype): +def iota_tensor(m, n, dtype: jax.typing.DTypeLike): assert m % 64 == 0 assert n % 8 == 0 def c(i): @@ -144,8 +144,12 @@ def c(i): value = arith.index_cast(i32, value) vec = vector.insertelement(value, vec, position=c(col_offset)) registers[row_tile, col_tile, row_subtile, 0] = vec - t = mgpu.FragmentedArray(_registers=registers, _layout=mgpu.WGMMA_LAYOUT) - return t.astype(mlir_dtype) + t = mgpu.FragmentedArray( + _registers=registers, _layout=mgpu.WGMMA_LAYOUT, _is_signed=True + ) + return t.astype( + utils.dtype_to_ir_type(dtype), is_signed=utils.is_signed(dtype) + ) class TestCase(parameterized.TestCase): @@ -169,14 +173,14 @@ def test_copy_basic(self): def kernel(ctx, src, dst, _): copy(src, dst) x = jnp.arange(2 * 3 * 5).reshape(2, 5, 3) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) np.testing.assert_array_equal(y, x) def test_copy_swizzle(self): def kernel(ctx, src, dst, _): copy(src, dst, swizzle=128) x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, ())(x) expected = np.zeros_like(y) for i in range(8): for j in range(8): @@ -190,7 +194,7 @@ def kernel(ctx, src, dst, smem): copy(src, smem, swizzle=128) copy(smem, dst, swizzle=128) x = jnp.arange(8 * 32, dtype=jnp.float32).reshape(8, 32) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) def test_iota_tensor(self): @@ -198,7 +202,7 @@ def test_iota_tensor(self): def kernel(ctx, dst, _): f32 = ir.F32Type.get() index = ir.IndexType.get() - registers = iota_tensor(m, n, f32).registers + registers = iota_tensor(m, n, jnp.float32).registers assert registers.size == 16, registers.size for i, vec_reg in enumerate(registers.flat): for j in range(2): @@ -207,7 +211,7 @@ def kernel(ctx, dst, _): reg, dst, [gpu.thread_id(gpu.Dimension.x), c(2 * i + j, index)] ) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) - regs = mosaic_gpu.as_gpu_kernel( + regs = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() thread_ids = np.arange(128) @@ -246,7 +250,7 @@ def kernel(ctx, inp, out, _): out_shape = list(x.shape) out_shape.insert(dim, 1) out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_shape)) @@ -274,7 +278,7 @@ def kernel(ctx, inp, out, _): out_shape = list(in_shape) out_shape[dim:dim + 1] = [2, 2, out_shape[dim] // 4] out_ty = jax.ShapeDtypeStruct(out_shape, jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 
1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) @@ -288,7 +292,7 @@ def kernel(ctx, inp, out, _): x = np.arange(8 * 2 * 8, dtype=jnp.float32).reshape(8, 2, 8) out_ty = jax.ShapeDtypeStruct((16, 8) if dim == 0 else (8, 16), jnp.float32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, out_ty, () )(x) np.testing.assert_array_equal(y, x.reshape(out_ty.shape)) @@ -307,7 +311,7 @@ def test_fold_strided( expanded_shape = get_packed_shape(strides, shape) total_size = np.prod(expanded_shape) np_inp = np.arange(total_size, dtype=jnp.float32).reshape(expanded_shape) - index = tuple([slice(0, s) for s in shape]) + index = tuple(slice(0, s) for s in shape) # Reference implementation def np_fold(inp, dim, fold_rank): @@ -327,7 +331,7 @@ def kernel(ctx, inp, out, _): copy(memref_fold(memref_slice(inp, index), dim, fold_rank), out) out = np_fold(np_inp[index], dim, fold_rank) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), np_inp, out, () )(np_inp) assert ( @@ -360,32 +364,43 @@ def get_packed_shape(strides, shape): class WGMMATest(TestCase): + @parameterized.named_parameters(("f32", jnp.float32), ("f16", jnp.float16)) + def test_store_untiled(self, dtype): + def kernel(ctx, out, _): + del ctx + iota_tensor(64, 64, dtype).store_untiled(out) + expected = np.arange(64 * 64, dtype=dtype).reshape(64, 64) + iota = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, () + )() + np.testing.assert_array_equal(iota, expected) + @parameterized.named_parameters( - ("f32", ir.F32Type, jnp.float32), ("f16", ir.F16Type, jnp.float16) + ("f32", jnp.float32, 256), + ("f16", jnp.float16, 256), + ("f16_small", jnp.float16, 128), ) - def test_store_untiled(self, mlir_dtype_cls, jax_dtype): - mlir_dtype = mlir_dtype_cls.get() + def test_store_untiled_splat(self, jax_dtype, size): + mlir_dtype = utils.dtype_to_ir_type(jax_dtype) def kernel(ctx, out, _): del ctx - iota_tensor(64, 64, mlir_dtype).store_untiled(out) - expected = np.arange(64 * 64, dtype=jax_dtype).reshape(64, 64) - iota = mosaic_gpu.as_gpu_kernel( + arr = mgpu.FragmentedArray.splat( + c(1.0, mlir_dtype), (size,), is_signed=utils.is_signed(jax_dtype) + ) + arr.store_untiled(out) + expected = np.ones((size,), jax_dtype) + mosaic_ones = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, () )() - np.testing.assert_array_equal(iota, expected) + np.testing.assert_array_equal(mosaic_ones, expected) @parameterized.product( - dtypes=( - (ir.F32Type.get, jnp.float32), - (ir.F16Type.get, jnp.float16), - (partial(ir.IntegerType.get_signless, 8), jnp.int8), - ), + dtype=[jnp.float32, jnp.float16, jnp.int8], swizzle=(32, 64, 128), num_col_tiles=(1, 2, 3), ) - def test_store_tiled(self, dtypes, swizzle, num_col_tiles): - mlir_dtype_cls, jax_dtype = dtypes - mlir_dtype = mlir_dtype_cls() + def test_store_tiled(self, dtype, swizzle, num_col_tiles): + mlir_dtype = utils.dtype_to_ir_type(dtype) if bytewidth(mlir_dtype) > 2 and swizzle == 32: self.skipTest("Not implemented") col_tiling = swizzle // bytewidth(mlir_dtype) @@ -394,42 +409,63 @@ def test_store_tiled(self, dtypes, swizzle, num_col_tiles): tiling = (64, col_tiling) def kernel(ctx, out, smem): del ctx - iota_tensor(m, n, mlir_dtype).store_tiled(smem, swizzle=swizzle) + iota_tensor(m, n, dtype).store_tiled(smem, swizzle=swizzle) copy(smem, out, swizzle=swizzle) expected = ( - np.arange(m * n, dtype=jax_dtype) + np.arange(m * n, dtype=dtype) .reshape(m // tiling[0], tiling[0], 
n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), expected, expected )() np.testing.assert_array_equal(iota, expected) + @parameterized.product( + dtype=[jnp.float16, jnp.int8], + swizzle=(32, 64, 128), + ) + def test_store_tiled_short_n(self, dtype, swizzle): + mlir_dtype = utils.dtype_to_ir_type(dtype) + col_tiling = swizzle // bytewidth(mlir_dtype) + m = 128 + n = 16 // bytewidth(mlir_dtype) + tiling = (64, col_tiling) + def kernel(ctx, out, smem): + iota_tensor(m, n, dtype).store_tiled(smem, swizzle=swizzle) + ctx.async_copy( + src_ref=smem, + dst_ref=out, + swizzle=swizzle, + gmem_slice=(ds(0, m), ds(0, col_tiling)), + gmem_transform=mgpu.TileTransform(tiling), + ) + ctx.await_async_copy(0) + smem_shape = jax.ShapeDtypeStruct((m // tiling[0], 1, *tiling), dtype) + expected = np.arange(m * n, dtype=dtype).reshape(m, n) + iota = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), expected, smem_shape + )() + np.testing.assert_array_equal(iota, expected) + @parameterized.named_parameters( - ("bf16_i8", - ir.BF16Type.get, jnp.bfloat16, - lambda: ir.IntegerType.get_signless(8), jnp.int8), - ("i8_bf16", - lambda: ir.IntegerType.get_signless(8), jnp.int8, - ir.BF16Type.get, jnp.bfloat16), - ("i8_i8", - lambda: ir.IntegerType.get_signless(8), jnp.int8, - lambda: ir.IntegerType.get_signless(8), jnp.int8), + ("bf16_i8", jnp.bfloat16, jnp.int8), + ("i8_bf16", jnp.int8, jnp.bfloat16), + ("i8_i8", jnp.int8, jnp.int8), ) - def test_convert_tiled(self, - mlir_dtype_cls_from, jax_dtype_from, - mlir_dtype_cls_to, jax_dtype_to): - mlir_dtype_from = mlir_dtype_cls_from() - mlir_dtype_to = mlir_dtype_cls_to() + def test_convert_tiled(self, jax_dtype_from, jax_dtype_to): + mlir_dtype_from = utils.dtype_to_ir_type(jax_dtype_from) + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) m = 128 n = 256 // bytewidth(mlir_dtype_from) def kernel(ctx, inp, out, smem): del ctx smem_from, smem_to = smem copy(inp, smem_from, swizzle=128) - t = mgpu.FragmentedArray.load_tiled(smem_from, swizzle=128) - t = t.astype(mlir_dtype_to) + t = mgpu.FragmentedArray.load_tiled( + smem_from, swizzle=128, is_signed=utils.is_signed(jax_dtype_from) + ) + t = t.astype(mlir_dtype_to, is_signed=utils.is_signed(jax_dtype_to)) t.store_tiled(smem_to, swizzle=128) copy(smem_to, out, swizzle=128) @@ -444,7 +480,7 @@ def kernel(ctx, inp, out, smem): expected_from = expected(jax_dtype_from, from_tiling) expected_to = expected(jax_dtype_to, to_tiling) - res = mosaic_gpu.as_gpu_kernel( + res = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), @@ -455,12 +491,12 @@ def kernel(ctx, inp, out, smem): np.testing.assert_array_equal(res, expected_to) @parameterized.named_parameters( - ("f32", ir.F32Type.get, jnp.float32), - ("f16", ir.F16Type.get, jnp.float16), - ("i8", partial(ir.IntegerType.get_signless, 8), jnp.int8), + ("f32", jnp.float32), + ("f16", jnp.float16), + ("i8", jnp.int8), ) - def test_load_tiled(self, mlir_dtype_cls, jax_dtype): - mlir_dtype = mlir_dtype_cls() + def test_load_tiled(self, jax_dtype): + mlir_dtype = utils.dtype_to_ir_type(jax_dtype) m = 128 n = 256 // bytewidth(mlir_dtype) tiling = (64, 128 // bytewidth(mlir_dtype)) @@ -468,7 +504,9 @@ def kernel(ctx, in_, out, smem): del ctx smem1, smem2 = smem copy(in_, smem1, swizzle=128) - t = mgpu.FragmentedArray.load_tiled(smem1, swizzle=128) + t = mgpu.FragmentedArray.load_tiled( + smem1, swizzle=128, is_signed=utils.is_signed(jax_dtype) + ) t.store_tiled(smem2, 
swizzle=128) copy(smem2, out, swizzle=128) expected = ( @@ -476,7 +514,7 @@ def kernel(ctx, in_, out, smem): .reshape(m // tiling[0], tiling[0], n // tiling[1], tiling[1]) .transpose(0, 2, 1, 3) ) - iota = mosaic_gpu.as_gpu_kernel( + iota = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), expected, expected, (expected,) * 2 )(expected) np.testing.assert_array_equal(iota, expected) @@ -512,7 +550,7 @@ def test_wgmma_basic( raise self.skipTest("Copy with non-128B swizzles not implemented") in_mlir_dtype = in_mlir_dtype_cls.get() - out_mlir_dtype = mlir.dtype_to_ir_type(jnp.dtype(jax_out_dtype)) + out_mlir_dtype = utils.dtype_to_ir_type(jax_out_dtype) if ir.F32Type.isinstance(in_mlir_dtype): # We actually use tf32 instead in_jax_dtype = jnp.float32 if lhs_transpose or not rhs_transpose: @@ -544,13 +582,13 @@ def test_wgmma_basic( def kernel(ctx, lhs, rhs, out, scratch): lhs_smem, rhs_smem, barriers = scratch if tma_inputs: - lhs_transform = (mosaic_gpu.TileTransform((64, nk_tile)),) + lhs_transform = (mgpu.TileTransform((64, nk_tile)),) if lhs_transpose: assert nk_tile == 64 # Make sure we didn't have to transpose tiling. - lhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) - rhs_transform = (mosaic_gpu.TileTransform((nk_tile, nk_tile)),) + lhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform = (mgpu.TileTransform((nk_tile, nk_tile)),) if rhs_transpose: - rhs_transform += (mosaic_gpu.TransposeTransform((1, 0, 2, 3)),) + rhs_transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) ctx.async_copy( src_ref=lhs, dst_ref=lhs_smem, @@ -617,7 +655,7 @@ def quantize(x): ), mgpu.TMABarrier(2), ] - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (x, y), out_shape, scratch_shape )(x, y) x32, y32 = x.astype(np.float32), y.astype(np.float32) @@ -632,11 +670,9 @@ def quantize(x): k_steps=(1, 2), rhs_transpose=(False, True), swizzle=(32, 64, 128), - mlir_dtype_cls=(ir.F16Type, ir.BF16Type), + dtype=[jnp.float16, jnp.bfloat16], ) - def test_wgmma_reg_lhs( - self, m, n, k_steps, rhs_transpose, swizzle, mlir_dtype_cls - ): + def test_wgmma_reg_lhs(self, m, n, k_steps, rhs_transpose, swizzle, dtype): index = ir.IndexType.get() row_major = mgpu.WGMMALayout.ROW_MAJOR @@ -662,29 +698,84 @@ def kernel(ctx, rhs, out, rhs_smem): swizzle=swizzle, ) init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) - lhs_regs = iota_tensor(m, k, mlir_dtype_cls.get()) + lhs_regs = iota_tensor(m, k, dtype) acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle) nvvm.wgmma_commit_group_sync_aligned() nvvm.wgmma_wait_group_sync_aligned(0) acc.value.store_untiled(out) - jax_dtype = jnp.float16 if mlir_dtype_cls == ir.F16Type else jnp.bfloat16 y_shape = (n, k) if rhs_transpose else (k, n) - y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype) + y = self.prng.uniform(-1, 1, y_shape).astype(dtype) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) scratch_shape = jax.ShapeDtypeStruct( - (k_steps, n // nk_tile, nk_tile, nk_tile), jax_dtype + (k_steps, n // nk_tile, nk_tile, nk_tile), dtype ) - z = mosaic_gpu.as_gpu_kernel( + z = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), y, out_shape, scratch_shape )(y) - x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) + x = np.arange(m * k, dtype=dtype).reshape(m, k) ref = jax.lax.dot( x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 ) rtol = 5e-4 np.testing.assert_allclose(z, ref, rtol=rtol, atol=0) + @parameterized.product( + rhs_transpose=(False, True), + swizzle=(32, 64, 
128), + ) + def test_narrow_n(self, rhs_transpose, swizzle): + m, n, k_steps = 64, 8, 2 + + row_major = mgpu.WGMMALayout.ROW_MAJOR + col_major = mgpu.WGMMALayout.COL_MAJOR + rhs_order = col_major if rhs_transpose else row_major + bytewidth = 2 + nk_tile = swizzle // bytewidth + k = nk_tile * k_steps + + def kernel(ctx, rhs, out, smem): + rhs_smem, barrier = smem + gmem_slice = (ds(0, k), ds(0, nk_tile)) + smem_slice = (slice(None), slice(None), slice(None), ds(0, n)) + transform = (mgpu.TileTransform((nk_tile, nk_tile)),) + if rhs_transpose: + gmem_slice = gmem_slice[::-1] + smem_slice = (slice(None), slice(None), ds(0, n), slice(None)) + transform += (mgpu.TransposeTransform((1, 0, 2, 3)),) + ctx.async_copy( + src_ref=rhs, + dst_ref=rhs_smem, + swizzle=swizzle, + gmem_slice=gmem_slice, + gmem_transform=transform, + barrier=barrier, + ) + barrier.wait() + init_acc = mgpu.WGMMAAccumulator.zero(m=m, n=n) + lhs_regs = iota_tensor(m, k, jnp.float16) + rhs_smem = memref_slice(rhs_smem, smem_slice) + acc = mgpu.wgmma(init_acc, lhs_regs, rhs_smem, b_order=rhs_order, swizzle=swizzle) + nvvm.wgmma_commit_group_sync_aligned() + nvvm.wgmma_wait_group_sync_aligned(0) + acc.value.store_untiled(out) + + jax_dtype = jnp.float16 + y_shape = (n, k) if rhs_transpose else (k, n) + y = self.prng.uniform(-1, 1, y_shape).astype(jax_dtype) + out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) + rhs_scratch_shape = jax.ShapeDtypeStruct( + (k_steps, 1, nk_tile, nk_tile), jax_dtype + ) + z = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), y, out_shape, (rhs_scratch_shape, mgpu.TMABarrier()), + )(y) + x = np.arange(m * k, dtype=jax_dtype).reshape(m, k) + ref = jax.lax.dot( + x, (y.T if rhs_transpose else y), preferred_element_type=jnp.float32 + ) + np.testing.assert_allclose(z, ref, rtol=5e-4, atol=0) + class BarrierTest(TestCase): @@ -700,17 +791,22 @@ def kernel(ctx, dst, scratch): arith.addi(wg_idx, c(1, i32)), (128,), mgpu.WGStridedFragLayout((128,), 1), + is_signed=False, ) with ir.InsertionPoint(scf.IfOp(is_first_wg).then_block): arr.store_untiled(tmp) barriers[0].arrive() # Signal that tmp is ready. barriers[1].wait() # Wait for the other warp to produce tmp. - final_arr = arr + mgpu.FragmentedArray.load_strided(tmp) + final_arr = arr + mgpu.FragmentedArray.load_strided( + tmp, is_signed=False + ) final_arr.store_untiled(memref_slice(dst, 0)) scf.yield_([]) with ir.InsertionPoint(scf.IfOp(is_second_wg).then_block): barriers[0].wait() - final_arr = arr + mgpu.FragmentedArray.load_strided(tmp) + final_arr = arr + mgpu.FragmentedArray.load_strided( + tmp, is_signed=False + ) barriers[2].arrive() barriers[2].wait() # Synchronize this warpgroup before we overwrite tmp. 
arr.store_untiled(tmp) @@ -718,7 +814,7 @@ def kernel(ctx, dst, scratch): final_arr.store_untiled(memref_slice(dst, 1)) scf.yield_([]) out_shape = jax.ShapeDtypeStruct((2, 128), jnp.int32) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (2 * 128, 1, 1), @@ -733,10 +829,11 @@ def kernel(ctx, dst, scratch): @parameterized.named_parameters( ( - f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}{'_group' if group_dims else ''}", collective_dims, noncollective_dims, collective_size, + group_dims, ) for collective_dims in itertools.chain.from_iterable( itertools.combinations(Dimension, n) for n in range(1, 4) @@ -745,9 +842,10 @@ def kernel(ctx, dst, scratch): itertools.combinations(Dimension, n) for n in range(3) ) for collective_size in (1, 2, 4) + for group_dims in (False,) + ((True,) if len(collective_dims) > 1 else ()) if all(d not in noncollective_dims for d in collective_dims) ) - def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size): + def test_collective_arrive(self, collective_dims, noncollective_dims, collective_size, group_dims): i32 = ir.IntegerType.get_signless(32) index = ir.IndexType.get() cluster = [1, 1, 1] @@ -757,9 +855,21 @@ def test_collective_arrive(self, collective_dims, noncollective_dims, collective cluster[d] = 2 if math.prod(cluster) > 16: self.skipTest("Cluster too big") - def kernel(ctx, dst, collective_barrier): + is_trivial = math.prod(cluster[d] for d in collective_dims) == 1 + def kernel(ctx, dst, mask, collective_barrier): + memref.store(arith.constant(i32, 1 << 17), mask, [c(0, index)]) + gpu.barrier() collective_barrier.arrive() collective_barrier.wait() + if not is_trivial: + llvm.atomicrmw( + llvm.AtomicBinOp.min, + utils.memref_ptr(mask), + collective_barrier.cluster_mask, + llvm.AtomicOrdering.monotonic, + ) + else: + assert collective_barrier.cluster_mask is None tid = thread_idx() linear_idx = arith.index_cast(index, tid) stride = c(128, index) @@ -768,13 +878,30 @@ def kernel(ctx, dst, collective_barrier): stride = arith.muli(stride, gpu.grid_dim(d)) memref.store(arith.index_cast(i32, linear_idx), dst, [linear_idx]) out_shape = jax.ShapeDtypeStruct((math.prod(cluster) * 128,), jnp.int32) - scratch = mgpu.ClusterBarrier(collective_dims) - y = mosaic_gpu.as_gpu_kernel( - kernel, cluster, (128, 1, 1), (), out_shape, scratch, cluster=cluster, + mask_shape = jax.ShapeDtypeStruct((1,), jnp.int32) + barrier_dims = collective_dims + if group_dims: + barrier_dims = (collective_dims[:2], *collective_dims[2:]) + scratch = mgpu.ClusterBarrier(barrier_dims) + y, mask = mgpu.as_gpu_kernel( + kernel, cluster, (128, 1, 1), (), (out_shape, mask_shape), scratch, cluster=cluster, )() np.testing.assert_array_equal( y, np.arange(math.prod(cluster) * 128, dtype=np.int32) ) + if not is_trivial: + # Verify that the mask is correct. Blocks are column-major, hence the transpose. 
+ block_bits = 1 << np.arange(math.prod(cluster), dtype=np.int32).reshape(cluster[::-1]).T + expected_mask = 0 + for bd in barrier_dims: + if isinstance(bd, gpu.Dimension): + bd = (bd,) + least_significant_slice = tuple( + slice(None) if d in bd else 0 for d in gpu.Dimension + ) + mask_bits = block_bits[least_significant_slice] + expected_mask |= np.bitwise_or.reduce(mask_bits, axis=None) + self.assertEqual(mask, expected_mask) class TMATest(TestCase): @@ -795,35 +922,41 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=swizzle) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem = (x, mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) @parameterized.named_parameters( ( - f"_{collective_dim}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", - collective_dim, + f"_{''.join(map(str, collective_dims))}={collective_size}{'_' + ''.join(map(str, noncollective_dims)) if noncollective_dims else ''}", + collective_dims, noncollective_dims, collective_size, ) - for collective_dim in Dimension + for collective_dims in itertools.chain.from_iterable( + itertools.combinations(Dimension, n) for n in range(1, 4) + ) for noncollective_dims in itertools.chain.from_iterable( itertools.combinations(Dimension, n) for n in range(3) ) for collective_size in (1, 2, 4) - if collective_dim not in noncollective_dims + if all(d not in noncollective_dims for d in collective_dims) ) - def test_tma_load_multicast(self, collective_dim, noncollective_dims, collective_size): + def test_tma_load_multicast(self, collective_dims, noncollective_dims, collective_dim_size): index = ir.IndexType.get() swizzle = 128 dtype = jnp.float16 cluster = [1, 1, 1] - cluster[collective_dim] = collective_size + for d in collective_dims: + cluster[d] = collective_dim_size for d in noncollective_dims: cluster[d] = 2 - noncollective_size = math.prod(cluster) // cluster[collective_dim] + if math.prod(cluster) > 16: + self.skipTest("Cluster too big") + collective_size = math.prod(cluster[d] for d in collective_dims) + noncollective_size = math.prod(cluster) // collective_size # We use the 2 dimension to exercise splitting the collective over # multiple dimensions when the cluster is large. - shape = (noncollective_size, 2, 16 * cluster[collective_dim], 64) + shape = (noncollective_size, 2, 16 * collective_size, 64) minor_size = 64 if swizzle is None else swizzle // jnp.dtype(dtype).itemsize shape = (*shape[:-1], minor_size) # Note that this kernel does not use the non-collective dimensions in any @@ -845,11 +978,20 @@ def kernel(ctx, src, dst, scratch): gmem_slice=(noncollective_idx,), swizzle=swizzle, barrier=barrier, - collective=collective_dim, + collective=collective_dims, ) barrier.wait() + # This is _not_ the real cluster block idx, because it does not consider + # the column-major ordering of the grid dimensions. 
+ idx = c(0, index) + stride = 1 + for d in collective_dims: + idx = arith.addi( + idx, arith.muli(gpu.cluster_block_id(d), c(stride, index)) + ) + stride *= cluster[d] slc = ds( - arith.muli(gpu.cluster_block_id(collective_dim), c(16, index)), 16 + arith.muli(idx, c(16, index)), 16 ) copy( memref_slice(tmp, (slice(None), slc)), @@ -858,7 +1000,7 @@ def kernel(ctx, src, dst, scratch): ) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem_shape = (jax.ShapeDtypeStruct(shape[1:], dtype), mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, cluster, (128, 1, 1), x, x, smem_shape, cluster=cluster )(x) np.testing.assert_array_equal(y, x) @@ -882,7 +1024,7 @@ def kernel(ctx, src, dst, scratch): dst_ref=tmp, swizzle=swizzle, barrier=barrier, - gmem_transform=mosaic_gpu.TileTransform(tiling), + gmem_transform=mgpu.TileTransform(tiling), ) barrier.wait_parity(c(0, i1)) for idxs in np.ndindex(tiled_shape): @@ -897,7 +1039,7 @@ def kernel(ctx, src, dst, scratch): jax.ShapeDtypeStruct(tile_shape(shape, tiling), dtype), mgpu.TMABarrier(), ) - f = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) + f = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem) y = f(x) np.testing.assert_array_equal(y, x) @@ -924,7 +1066,7 @@ def kernel(ctx, src, dst, smem): copy(tmp, dst, swizzle=swizzle) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) smem = (x, mgpu.TMABarrier()) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, smem)(x) np.testing.assert_array_equal(y, x) def test_parity_tracking(self): @@ -940,7 +1082,7 @@ def kernel(ctx, src, dst, smem): barrier.wait() copy(tmp, memref_slice(dst, s)) x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) - y = mosaic_gpu.as_gpu_kernel( + y = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), x, x, (x[0:1], mgpu.TMABarrier()) )(x) np.testing.assert_array_equal(y, x) @@ -958,9 +1100,90 @@ def kernel(ctx, src, dst, tmp): ctx.async_copy(src_ref=tmp, dst_ref=dst, swizzle=swizzle) ctx.await_async_copy(0) x = np.arange(np.prod(shape), dtype=dtype).reshape(shape) - y = mosaic_gpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + y = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) np.testing.assert_array_equal(y, x) + @parameterized.parameters(0, 1) + def test_tma_small_tile_load(self, small_dim): + if small_dim == 0: + shape = (4, 128) + elif small_dim == 1: + shape = (128, 8) + else: + raise ValueError("small_dim must be 0 or 1") + tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64) + padded_shape = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2])) + def kernel(ctx, src, dst, smem): + tmp, barrier = smem + ctx.async_copy( + src_ref=src, + dst_ref=tmp, + swizzle=128, + gmem_transform=mgpu.TileTransform((64, 64)), + gmem_slice=(ds(0, padded_shape[0]), ds(0, padded_shape[1])), + barrier=barrier, + ) + barrier.wait() + copy(tmp, dst, swizzle=128) + x = np.arange(np.prod(shape), dtype=jnp.float16).reshape(shape) + tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) + y_tiled = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, tiled, (tiled, mgpu.TMABarrier()), + )(x) + y = y_tiled.swapaxes(1, 2).reshape(padded_shape) + # y should contain x and zero everywhere else. 
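# Illustrative sketch (NumPy only, assuming the same shape=(4, 128) case) of
# the (64, 64) tiling used above: the copy pads the source up to a (64, 128)
# tile grid, so after untiling only the first 4 rows hold data and everything
# else must be zero, which is what the checks below verify.
import numpy as np
x = np.arange(4 * 128, dtype=np.float16).reshape(4, 128)
padded = np.zeros((64, 128), dtype=np.float16)
padded[:4, :] = x
tiled = padded.reshape(1, 64, 2, 64).swapaxes(1, 2)  # (1, 2, 64, 64), like tiled_shape
untiled = tiled.swapaxes(1, 2).reshape(64, 128)      # same swapaxes+reshape as the test
assert (untiled[:4, :] == x).all() and not untiled[4:, :].any()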
+ np.testing.assert_array_equal(y[:shape[0], :shape[1]], x) + y_mut = np.asarray(y).copy() + y_mut[:shape[0], :shape[1]] = 0 + np.testing.assert_array_equal(y_mut, np.zeros_like(y_mut)) + + @parameterized.parameters(0, 1) + def test_tma_small_tile_store(self, small_dim): + if small_dim == 0: + shape = (4, 128) + elif small_dim == 1: + shape = (128, 8) + else: + raise ValueError("small_dim must be 0 or 1") + tiled_shape = ((shape[0] + 63) // 64, (shape[1] + 63) // 64, 64, 64) + m, n = (math.prod(tiled_shape[0::2]), math.prod(tiled_shape[1::2])) + def kernel(ctx, dst, tmp): + vals = iota_tensor(m, n, jnp.float16) + vals.store_tiled(tmp, swizzle=128) + ctx.async_copy( + src_ref=tmp, + dst_ref=dst, + swizzle=128, + gmem_transform=mgpu.TileTransform((64, 64)), + gmem_slice=(ds(0, m), ds(0, n)), + ) + ctx.await_async_copy(0) + tiled = jax.ShapeDtypeStruct(tiled_shape, jnp.float16) + out = jax.ShapeDtypeStruct(shape, jnp.float16) + y = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out, tiled, + )() + iota = np.arange(m * n, dtype=jnp.float16).reshape([m, n]) + np.testing.assert_array_equal(y, iota[:shape[0], :shape[1]]) + + def test_tma_invalid(self): + def kernel(ctx, src, dst, tmp): + copy(src, tmp) + ctx.async_copy(src_ref=tmp, dst_ref=dst) + ctx.await_async_copy(0) + + def run_kernel(shape): + x = np.arange(np.prod(shape)).reshape(shape) + _ = mgpu.as_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), x, x, x)(x) + + with self.assertRaisesRegex(ValueError, "only support striding up to 5"): + run_kernel([1] * 6) + + with self.assertRaisesRegex( + ValueError, "last dimension to be divisible by 16" + ): + run_kernel([23]) + class FragmentedArrayTest(TestCase): @@ -970,63 +1193,154 @@ class FragmentedArrayTest(TestCase): operator.mul, operator.sub, operator.truediv, + operator.mod, (lambda x, y: mgpu.FragmentedArray.max(x, y), np.maximum), ), + dtype=[jnp.float32, jnp.int32, jnp.uint32], m=(64, 128), n=(8, 16, 32, 64, 80, 128, 256), ) - def test_binary(self, op, m=64, n=32): + @jtu.ignore_warning(message="(invalid value|divide by zero)", + category=RuntimeWarning) + def test_binary(self, op, dtype, m=64, n=32): if isinstance(op, tuple): op, np_op = op else: np_op = op + if jnp.issubdtype(dtype, jnp.integer) and op is operator.truediv: + self.skipTest("Unsupported for integer types") + if jnp.issubdtype(dtype, jnp.floating) and op is operator.mod: + self.skipTest("Unsupported for floating types") + for scalar_rhs in [None, 2]: def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) - rhs = iota if scalar_rhs is None else c(scalar_rhs, iota.mlir_dtype) + mlir_dtype = utils.dtype_to_ir_type(dtype) + iota = iota_tensor(m, n, dtype) + rhs = iota if scalar_rhs is None else c(scalar_rhs, mlir_dtype) op(iota, rhs).store_untiled(dst) - out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() - ref_x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + ref_x = np.arange(m * n, dtype=dtype).reshape(m, n) ref_rhs = scalar_rhs or ref_x - if op == operator.truediv: + if op is operator.truediv: np.testing.assert_allclose(result, np_op(ref_x, ref_rhs), atol=2e-7) else: np.testing.assert_array_equal(result, np_op(ref_x, ref_rhs)) + @parameterized.product( + op=[ + operator.lt, + operator.le, + operator.gt, + operator.ge, + operator.eq, + operator.ne, + ], + dtype=[jnp.float32, jnp.int32, jnp.uint32], + 
) + def test_comparison(self, op, dtype, m=64, n=32): + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, dtype) + op(iota, iota + 1).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), jnp.bool) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + iota = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_array_equal(result, op(iota, iota + 1)) + @parameterized.product( ops=( - (lambda x: mgpu.FragmentedArray.exp(x), np.exp, False), - (lambda x: mgpu.FragmentedArray.exp(x, approx=True), np.exp, True), - (lambda x: mgpu.FragmentedArray.sin(x), np.sin, False), - (lambda x: mgpu.FragmentedArray.sin(x, approx=True), np.sin, True), - (lambda x: mgpu.FragmentedArray.cos(x), np.cos, False), - (lambda x: mgpu.FragmentedArray.cos(x, approx=True), np.cos, True), - (lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt, False), - (lambda x: mgpu.FragmentedArray.rsqrt(x, approx=True), jax.lax.rsqrt, True), + (lambda x: -x, jax.lax.neg), + (lambda x: x + 42, lambda x: x + 42), ), - m=(64, 128), - n=(8, 16, 32, 64, 80, 128, 256), + dtype=[jnp.float32, jnp.int32, jnp.uint32], ) - def test_unary(self, ops, m=64, n=32): - op, np_op, is_approx = ops + def test_unary(self, ops, dtype, m=64, n=32): + op, np_op = ops + def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota = iota_tensor(m, n, dtype) + op(iota).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), dtype) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_allclose(result, np_op(x), atol=2e-7, rtol=2e-7) + + def test_select(self, m=64, n=32): + + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, jnp.int32) + (iota < 16).select(iota * 2, iota * 3).store_untiled(dst) + + out_shape = jax.ShapeDtypeStruct((m, n), jnp.int32) + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () + )() + x = np.arange(m * n, dtype=jnp.int32).reshape(m, n) + np.testing.assert_array_equal(result, np.where(x < 16, x * 2, x * 3)) + + @parameterized.product( + ops=[ + (lambda x: mgpu.FragmentedArray.exp(x), np.exp), + (lambda x: mgpu.FragmentedArray.sin(x), np.sin), + (lambda x: mgpu.FragmentedArray.cos(x), np.cos), + (lambda x: mgpu.FragmentedArray.rsqrt(x), jax.lax.rsqrt), + ], + approx=[False, True], + ) + @jtu.ignore_warning(message="overflow encountered", category=RuntimeWarning) + def test_math(self, ops, approx, m=64, n=32): + op, np_op = ops + def kernel(ctx, dst, _): + iota = iota_tensor(m, n, jnp.float32) op(iota).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) - atol = 5e-3 if is_approx else 2e-7 - rtol = 4e-6 if is_approx else 2e-7 + atol = 5e-3 if approx else 2e-7 + rtol = 4e-6 if approx else 2e-7 np.testing.assert_allclose(result, np_op(x), atol=atol, rtol=rtol) + @parameterized.product( + dtype=[jnp.float32, jnp.int32], + m=[128], + n=[32, 64], + ) + def test_reduce_sum(self, dtype, m, n): + def kernel(ctx, src, dst, scratch): + src = mgpu.FragmentedArray.load_strided( + src, is_signed=utils.is_signed(dtype) + ) + acc = mgpu.FragmentedArray.splat( + src.reduce_sum(scratch), + (m,), + is_signed=src.is_signed + ) + acc.store_untiled(dst) + + in_shape = jax.ShapeDtypeStruct((m, n), dtype) + out_shape = 
jax.ShapeDtypeStruct((m,), dtype) + kernel_fn = mgpu.as_gpu_kernel( + kernel, + (1, 1, 1), + (128, 1, 1), + in_shape, + out_shape, + smem_scratch_shape=jax.ShapeDtypeStruct((4,), dtype), + ) + x = np.arange(m * n, dtype=dtype).reshape(m, n) + np.testing.assert_array_equal(kernel_fn(x), jnp.full((m,), x.sum())) + @parameterized.product( op=(arith.addf, arith.maximumf), m=(64, 128), @@ -1034,11 +1348,10 @@ def kernel(ctx, dst, _): ) def test_reduce(self, op, m=64, n=32): def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota = iota_tensor(m, n, jnp.float32) iota.reduce(op, axis=1).broadcast_minor(n).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() x = np.arange(m * n, dtype=jnp.float32).reshape(m, n) @@ -1053,14 +1366,13 @@ def kernel(ctx, dst, _): def test_splat_layout(self): m, n = 64, 8 def kernel(ctx, dst, _): - f32 = ir.F32Type.get() - iota = iota_tensor(m=m, n=n, mlir_dtype=f32) + iota = iota_tensor(m, n, jnp.float32) cte = c(1, iota.mlir_dtype) cte_arr = mgpu.FragmentedArray.splat(cte, ()) cte_arr = cte_arr.reshape((1, 1)).broadcast((m, n)) (iota + cte_arr).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((m, n), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() expected = np.arange(m * n, dtype=jnp.float32).reshape(m, n) + 1 @@ -1070,10 +1382,12 @@ def test_splat(self): def kernel(ctx, dst, _): f32 = ir.F32Type.get() v = arith.constant(f32, ir.FloatAttr.get(f32, 3.14)) - t = mgpu.FragmentedArray.splat(v, (128,), mgpu.WGMMA_ROW_LAYOUT) + t = mgpu.FragmentedArray.splat( + v, (128,), mgpu.WGMMA_ROW_LAYOUT + ) t.broadcast_minor(32).store_untiled(dst) out_shape = jax.ShapeDtypeStruct((128, 32), jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), out_shape, () )() np.testing.assert_array_equal(result, np.full((128, 32), 3.14, np.float32)) @@ -1088,7 +1402,7 @@ def kernel(ctx, *args): copy(smem_output, gmem_output) inp = out = self.prng.uniform(-1, 1, in_shape).astype(jnp.float32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (inp,), out, [inp, out], )(inp) np.testing.assert_array_equal(inp, result) @@ -1103,7 +1417,7 @@ def kernel(ctx, out, *_): memref.store(grp, out, [tid]) x = np.arange(128, dtype=jnp.int32) - result = mosaic_gpu.as_gpu_kernel( + result = mgpu.as_gpu_kernel( kernel, (1, 1, 1), (128, 1, 1), (), x, [], )() for i in range(0, 128, 4): @@ -1111,6 +1425,26 @@ def kernel(ctx, out, *_): np.testing.assert_array_equal(result, x) + @parameterized.named_parameters( + ("_bf16", jnp.bfloat16) + ) + def test_fast_i8_convert(self, jax_dtype_to): + jax_dtype_to = jnp.dtype(jax_dtype_to) + jax_dtype_from = jnp.dtype(jnp.int8) + mlir_dtype_to = utils.dtype_to_ir_type(jax_dtype_to) + def kernel(ctx, inp, out, smem): + del ctx, smem + arr = mgpu.FragmentedArray.load_strided(inp, is_signed=True) + arr.astype(mlir_dtype_to).store_untiled(out) + + x = jnp.arange(-128, 128, dtype=jax_dtype_from) + reference = x.astype(jax_dtype_to) + + result = mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, reference, None, + )(x) + np.testing.assert_array_equal(result, reference) + class ProfilerTest(TestCase): @@ -1118,6 +1452,42 @@ def test_measure(self): x = jnp.arange(1024 * 1024) 
profiler.measure(lambda x, y: x + y, x, x) # This is just a smoke test + def test_multigpu(self): + if len(jax.devices()) < 2: + self.skipTest("Need at least 2 devices") + def kernel(ctx, src, dst, _): + mgpu.FragmentedArray.load_strided(src).store_untiled(dst) + x = np.arange(64 * 64, dtype=jnp.float32).reshape(64, 64) + f = jax.jit(mgpu.as_gpu_kernel( + kernel, (1, 1, 1), (128, 1, 1), x, x, () + )) + # Make sure we can invoke the same program on different devices. + for xd in (jax.device_put(x, d) for d in jax.devices()[:2]): + jax.block_until_ready(f(xd)) + + +class TorchTest(TestCase): + + @classmethod + def setUpClass(cls): + try: + import torch + except ImportError: + raise unittest.SkipTest("Test requires PyTorch") + cls.torch = torch + + def test_basic(self): + def kernel(ctx, i_gmem, o_gmem, _): + x = mgpu.FragmentedArray.load_strided(i_gmem) + (x + x).store_untiled(o_gmem) + + ty = jax.ShapeDtypeStruct((128, 128), jnp.float32) + x = self.torch.randn((128, 128), dtype=self.torch.float, device='cuda') + f = mgpu.as_torch_gpu_kernel(kernel, (1, 1, 1), (128, 1, 1), ty, ty, ()) + y = f(x) + np.testing.assert_allclose(y.cpu(), x.cpu() * 2) + del y # Make sure the destructor runs successfully. + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/mosaic/matmul_test.py b/tests/mosaic/matmul_test.py index fe29615ced0d..27ce4e3f02d7 100644 --- a/tests/mosaic/matmul_test.py +++ b/tests/mosaic/matmul_test.py @@ -15,6 +15,7 @@ """Test different parameterizations of a matmul.""" import os +import unittest from absl.testing import absltest, parameterized from jax._src import config @@ -27,13 +28,25 @@ matmul = None else: from jax.experimental.mosaic.gpu.examples import matmul +try: + import hypothesis as hp + import hypothesis.strategies as hps +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("these tests require hypothesis") config.parse_flags_with_absl() +jtu.setup_hypothesis() os.environ["XLA_FLAGS"] = ( os.environ.get("XLA_FLAGS", "") + " --xla_gpu_autotune_level=0") +def seed_hypothesis(f): + def wrapper(self, seed): + return hp.seed(seed)(f)(self) + return wrapper + + @jtu.with_config(jax_traceback_filtering="off") class MatmulTestCase(jtu.JaxTestCase): @@ -45,109 +58,69 @@ def setUp(self): not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("Only works on GPU with capability >= sm90") - @parameterized.product( - m=(128, 256, 512, 2048), - n=(128, 256, 512, 2048), - k=(128, 256, 512, 2048), - stages=(2, 4), - tile_m=(64, 128, 256), - tile_n=(64, 128, 256), - in_dtype=(jnp.float16, jnp.bfloat16), # f32 tested separately - rhs_transpose=(False, True), + @parameterized.named_parameters( + (f"_shard{i}", i) for i in range(5) ) - def test_matmul(self, m, k, n, stages, tile_m, tile_n, in_dtype, rhs_transpose): - if stages * (128 // jnp.dtype(in_dtype).itemsize) > k: - self.skipTest("Too many stages.") - - if m < tile_m: - self.skipTest(f"No use in running a test with {m=} < {tile_m=}.") - - if n < tile_n: - self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") - - try: - matmul.verify( - m, - k, - n, - stages, - tile_m=tile_m, - tile_n=tile_n, - lhs_dtype=in_dtype, - rhs_dtype=in_dtype, - rhs_transpose=rhs_transpose, + @seed_hypothesis + @hp.settings(max_examples=100) # Add verbosity=hp.Verbosity.verbose to debug + @hp.given(hps.data()) + def test_matmul(self, data): + in_dtype = data.draw( + hps.sampled_from([jnp.float16, jnp.bfloat16, jnp.float32]), + label="in_dtype", + ) + out_dtype = jnp.float32 + if 
in_dtype != jnp.float32: + out_dtype = data.draw( + hps.sampled_from([in_dtype, jnp.float32]), + label="out_dtype", ) - except ValueError as e: - if "Mosaic GPU kernel exceeds available shared memory" in str(e): - self.skipTest("Not enough shared memory for test, skipping.") - raise e - - @parameterized.product( - m=(128, 256, 512, 2048), - n=(128, 256, 512, 2048), - k=(128, 256, 512, 2048), - stages=(2, 4), - tile_m=(64, 128, 256), - tile_n=(64, 128, 256), - ) - def test_matmul_f32(self, m, k, n, stages, tile_m, tile_n): - if stages * (128 // jnp.dtype(jnp.float32).itemsize) > k: - self.skipTest("Too many stages.") - - if m < tile_m: - self.skipTest(f"No use in running a test with {m=} < {tile_m=}.") - - if n < tile_n: - self.skipTest(f"No use in running a test with {n=} < {tile_n=}.") + bytewidth = jnp.dtype(in_dtype).itemsize + m, n, k = ( + data.draw(hps.sampled_from([128, 256, 512, 2048]), label=d) + for d in "mnk" + ) + stages = data.draw(hps.integers(2, 5), label="stages") + swizzle = data.draw(hps.sampled_from([32, 64, 128]), label="swizzle") + tile_m = data.draw( + hps.sampled_from([t for t in [64, 128, 256] if t <= m]), label="tile_m" + ) + tile_n = data.draw( + hps.sampled_from([t for t in [64, 128, 256] if t <= n]), label="tile_n" + ) + grid_m, grid_n = m // tile_m, n // tile_n + grid_tile_n = data.draw(hps.sampled_from([1, 2, 4, 8, 16]), label="grid_tile_n") + hp.assume(grid_n % grid_tile_n == 0) + cluster_m = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_m") + hp.assume(grid_m % cluster_m == 0) + cluster_n = data.draw(hps.sampled_from([1, 2, 4]), label="cluster_n") + hp.assume(grid_n % cluster_n == 0) + # TODO(apaszke): Non-portable clusters (16 blocks) sometimes deadlock. + hp.assume(cluster_m * cluster_n <= 8) + if bytewidth == 4: + rhs_transpose = True + else: + rhs_transpose = data.draw(hps.booleans(), label="rhs_transpose") try: matmul.verify( m, k, n, - stages, - tile_m=tile_m, - tile_n=tile_n, - lhs_dtype=jnp.float32, - rhs_dtype=jnp.float32, - rhs_transpose=True, - ) - except ValueError as e: - if "Mosaic GPU kernel exceeds available shared memory" in str(e): - self.skipTest("Not enough shared memory for test, skipping.") - raise e - - @parameterized.product( - m=(512, 2048), - n=(512, 2048), - k=(512, 2048), - stages=(2, 4), - tile_m=(64, 128), - tile_n=(64, 128), - cluster_m=(1, 2, 4), - cluster_n=(1, 2, 4), - ) - def test_matmul_clusters(self, m, k, n, stages, tile_m, tile_n, cluster_m, cluster_n): - if cluster_m * cluster_n > 8: - # TODO(apaszke): Investigate - self.skipTest("Tests sometimes fail with non-portable cluster sizes.") - try: - matmul.verify( - m, - k, - n, - stages, + stages=stages, tile_m=tile_m, tile_n=tile_n, + in_dtype=in_dtype, + out_dtype=out_dtype, cluster_m=cluster_m, cluster_n=cluster_n, - lhs_dtype=jnp.float32, - rhs_dtype=jnp.float32, - rhs_transpose=True, + grid_tile_n=grid_tile_n, + swizzle=swizzle, + rhs_transpose=rhs_transpose, ) except ValueError as e: if "Mosaic GPU kernel exceeds available shared memory" in str(e): - self.skipTest("Not enough shared memory for test, skipping.") + hp.assume(False) raise e diff --git a/tests/multi_device_test.py b/tests/multi_device_test.py index 5daa0e0e5b84..a3b6b1efaa76 100644 --- a/tests/multi_device_test.py +++ b/tests/multi_device_test.py @@ -174,7 +174,7 @@ def test_device_put(self): def test_closed_over_values_device_placement(self): - # see https://github.com/google/jax/issues/1431 + # see https://github.com/jax-ml/jax/issues/1431 devices = self.get_devices() def f(): return 
lax.add(3., 4.) diff --git a/tests/multibackend_test.py b/tests/multibackend_test.py index 4f2e36c64f4b..4697ba8b2858 100644 --- a/tests/multibackend_test.py +++ b/tests/multibackend_test.py @@ -148,7 +148,7 @@ def get_arr(scale): @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def test_closed_over_values_device_placement(self): - # see https://github.com/google/jax/issues/1431 + # see https://github.com/jax-ml/jax/issues/1431 def f(): return jnp.add(3., 4.) self.assertNotEqual(jax.jit(f)().devices(), {jax.devices('cpu')[0]}) @@ -186,7 +186,7 @@ def my_sin(x): return jnp.sin(x) @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends def test_indexing(self): - # https://github.com/google/jax/issues/2905 + # https://github.com/jax-ml/jax/issues/2905 cpus = jax.devices("cpu") x = jax.device_put(np.ones(2), cpus[0]) @@ -195,7 +195,7 @@ def test_indexing(self): @jtu.skip_on_devices("cpu") # test only makes sense on non-cpu backends def test_sum(self): - # https://github.com/google/jax/issues/2905 + # https://github.com/jax-ml/jax/issues/2905 cpus = jax.devices("cpu") x = jax.device_put(np.ones(2), cpus[0]) diff --git a/tests/multiprocess_gpu_test.py b/tests/multiprocess_gpu_test.py index 0235ba89293f..5c84f8c69b62 100644 --- a/tests/multiprocess_gpu_test.py +++ b/tests/multiprocess_gpu_test.py @@ -46,7 +46,7 @@ @unittest.skipIf(not portpicker, "Test requires portpicker") class DistributedTest(jtu.JaxTestCase): - # TODO(phawkins): Enable after https://github.com/google/jax/issues/11222 + # TODO(phawkins): Enable after https://github.com/jax-ml/jax/issues/11222 # is fixed. @unittest.SkipTest def testInitializeAndShutdown(self): @@ -354,7 +354,7 @@ def test_gpu_multi_node_transparent_initialize_and_psum(self): def test_pjit_gda_multi_input_multi_output(self): jax.distributed.initialize() - global_mesh = jtu.create_global_mesh((8, 2), ("x", "y")) + global_mesh = jtu.create_mesh((8, 2), ("x", "y")) global_input_shape = (16, 2) global_input_data = np.arange( util.prod(global_input_shape)).reshape(global_input_shape) @@ -558,7 +558,7 @@ def test_pjit_gda_non_contiguous_mesh_2d_aot(self): def test_pjit_gda_eval_shape(self): jax.distributed.initialize() - with jtu.create_global_mesh((16,), ("x")): + with jtu.create_mesh((16,), ("x")): @functools.partial(pjit.pjit, in_shardings=jax.sharding.PartitionSpec(None), diff --git a/tests/nn_test.py b/tests/nn_test.py index 455f04e5fd12..d6153d32c63e 100644 --- a/tests/nn_test.py +++ b/tests/nn_test.py @@ -38,112 +38,130 @@ config.parse_flags_with_absl() -def _is_required_cudnn_version_satisfied(): +def _is_required_cudnn_version_satisfied(min_cudnn_version): return ( jtu.is_cuda_compute_capability_at_least("8.0") and cuda_versions is not None and - cuda_versions.cudnn_get_version() >= 8904 + cuda_versions.cudnn_get_version() >= min_cudnn_version ) -def _get_causal_mask(T, S): - causal_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) - return causal_mask[jnp.newaxis, jnp.newaxis, :, :] +def _check_cudnn_backend(fn, *args, **kwargs): + lowered = jax.jit(fn).lower(*args, **kwargs) + hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) + return '__cudnn$fmha' in hlo @jtu.with_config(jax_legacy_prng_key="allow", jax_numpy_dtype_promotion="standard") class NNFunctionsTest(jtu.JaxTestCase): @parameterized.product( - dtype=[jnp.float32, jnp.bfloat16, jnp.float16], - use_bias=[False, True], - causal_mode=[None, 'is_causal', 'is_mask'], + dtype=[jnp.bfloat16, jnp.float16], group_num=[1, 2, 4], - 
impl=['xla', 'cudnn'], + use_vmap=[False, True], + impl=['cudnn', 'xla'], ) - def testDotProductAttentionInfer(self, dtype, use_bias, causal_mode, - group_num, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): + def testDotProductAttention(self, dtype, group_num, use_vmap, impl): + if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(8904): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") if impl == 'cudnn' and dtype == jnp.float32: raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") - sdpa = nn.dot_product_attention B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num - keys = random.split(random.PRNGKey(0), 4) + keys = random.split(random.PRNGKey(0), 5) Q = random.normal(keys[0], (B, T, N, H), dtype) K = random.normal(keys[1], (B, S, N // G, H), dtype) V = random.normal(keys[2], (B, S, N // G, H), dtype) - if use_bias: - bias = random.normal(keys[3], (1, N, T, S), dtype) - else: - bias = None - - is_causal = causal_mode == 'is_causal' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None + grad = random.normal(keys[3], (B, T, N, H), dtype) + bias, mask = None, None - sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) + sdpa = nn.dot_product_attention + sdpa_ref = partial(sdpa, implementation=None) + sdpa_ans = partial(sdpa, implementation=impl) + if use_vmap: + sdpa_ans = jax.vmap(sdpa_ans, in_axes=(0, 0, 0, None, None), out_axes=0) + + # For testing purposes, we call the non-GQA version without vmap in the + # reference code + K_ref = jnp.repeat(K, G, axis=2) + V_ref = jnp.repeat(V, G, axis=2) + out_ref, sdpa_vjp_ref = jax.vjp(sdpa_ref, Q, K_ref, V_ref, bias, mask) + out_ans, sdpa_vjp_ans = jax.vjp(sdpa_ans, Q, K, V, bias, mask) + + dQ_ref, dK_ref, dV_ref = sdpa_vjp_ref(grad)[:3] + dQ_ans, dK_ans, dV_ans = sdpa_vjp_ans(grad)[:3] + dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) + dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) if impl == 'cudnn': - lowered = jax.jit(sdpa_ans).lower(Q, K, V, bias=bias, mask=causal_mask) - hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) - self.assertIn('__cudnn$fmha', hlo) - - K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K - V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V - out_ref = sdpa_ref(Q, K_ref, V_ref, bias=bias, mask=causal_mask) + self.assertTrue(_check_cudnn_backend(sdpa_ans, Q, K, V, bias, mask)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) - out_ans = sdpa_ans(Q, K, V, bias=bias, mask=causal_mask) self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) + self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) + self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) + self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) @parameterized.product( - dtype=[jnp.float32, jnp.bfloat16, jnp.float16], - use_bias=[False, True], - causal_mode=[None, 'is_causal', 'is_mask'], - group_num=[1, 2, 4], - impl=['xla', 'cudnn'], + mask_mode=['bias', 'causal', 'padding', 'custom', ('causal', 'padding'), + ('custom', 'padding'), ('bias', 'causal'), + ('causal', 'sliding_window')], ) - def testDotProductAttentionTrain(self, dtype, use_bias, causal_mode, - group_num, impl): - if impl == 'cudnn' and not _is_required_cudnn_version_satisfied(): + def testDotProductAttentionMask(self, mask_mode): + if isinstance(mask_mode, str): + mask_mode = (mask_mode,) + min_cudnn_version = 90200 if 'sliding_window' in mask_mode else 8904 + if not 
_is_required_cudnn_version_satisfied(min_cudnn_version): raise unittest.SkipTest("CUDA or cuDNN versions are not compatible.") - if impl == 'cudnn' and dtype == jnp.float32: - raise unittest.SkipTest("cuDNN only supports fp16 or bf16.") - sdpa = nn.dot_product_attention - B, S, T, N, H, G = 2, 128, 128, 4, 32, group_num - keys = random.split(random.PRNGKey(0), 5) + dtype = jnp.bfloat16 + B, S, T, N, H = 2, 128, 128, 4, 32 + keys = random.split(random.PRNGKey(0), 4) Q = random.normal(keys[0], (B, T, N, H), dtype) - K = random.normal(keys[1], (B, S, N // G, H), dtype) - V = random.normal(keys[2], (B, S, N // G, H), dtype) + K = random.normal(keys[1], (B, S, N, H), dtype) + V = random.normal(keys[2], (B, S, N, H), dtype) grad = random.normal(keys[3], (B, T, N, H), dtype) - if use_bias: + bias, mask = None, None + q_seqlen, kv_seqlen = None, None + window_size = None + + is_causal = 'causal' in mask_mode + if 'padding' in mask_mode: + q_seqlen = jnp.array([T // 2, T // 4], dtype=jnp.int32) + kv_seqlen = jnp.array([S // 4, S // 2], dtype=jnp.int32) + if 'custom' in mask_mode: + # Use a generated causal mask as the custom mask. + custom_mask = jnp.tril(jnp.ones((T, S), dtype=jnp.bool_)) + mask = custom_mask[None, None, :, :] + if 'bias' in mask_mode: bias = random.normal(keys[4], (1, N, T, S), dtype) - else: - bias = None - - is_causal = causal_mode == 'is_causal' - causal_mask = _get_causal_mask(T, S) if causal_mode == 'is_mask' else None + if 'sliding_window' in mask_mode: + window_size = (3, 2) if is_causal else (3, 0) - K_ref = jnp.repeat(K, G, axis=2) if G != 1 else K - V_ref = jnp.repeat(V, G, axis=2) if G != 1 else V + sdpa = nn.dot_product_attention sdpa_ref = partial(sdpa, is_causal=is_causal, implementation=None) - fn_ref = lambda q, k, v, b, m: sdpa_ref(q, k, v, bias=b, mask=m) - _, sdpa_vjp_ref = jax.vjp(fn_ref, Q, K_ref, V_ref, bias, causal_mask) - dQ_ref, dK_ref, dV_ref, dbias_ref, _ = sdpa_vjp_ref(grad) - if G != 1: - dK_ref = dK_ref.reshape(B, S, N // G, G, H).sum(axis=3) - dV_ref = dV_ref.reshape(B, S, N // G, G, H).sum(axis=3) - - sdpa_ans = partial(sdpa, is_causal=is_causal, implementation=impl) - fn_ans = lambda q, k, v, b, m: sdpa_ans(q, k, v, bias=b, mask=m) - _, sdpa_vjp_ans = jax.vjp(fn_ans, Q, K, V, bias, causal_mask) - dQ_ans, dK_ans, dV_ans, dbias_ans, _ = sdpa_vjp_ans(grad) + sdpa_ans = partial(sdpa, is_causal=is_causal, implementation='cudnn') - if impl == 'cudnn': - lowered = jax.jit(sdpa_vjp_ans).lower(grad) - hlo = mlir.module_to_string(lowered.compiler_ir('stablehlo')) - self.assertRegex(hlo, r'__cudnn\$fmha.*Backward\(') + args = (Q, K, V, bias, mask) + kwargs = {'query_seq_lengths': q_seqlen, 'key_value_seq_lengths': kv_seqlen} + + # Convert the kargs to positional args for the jax.vjp. + fn_ref = lambda q, k, v, b, m, qs, kvs: sdpa_ref( + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + local_window_size=window_size, + ) + fn_ans = lambda q, k, v, b, m, qs, kvs: sdpa_ans( + q, k, v, b, m, query_seq_lengths=qs, key_value_seq_lengths=kvs, + local_window_size=window_size, + ) + out_ref, sdpa_vjp_ref = jax.vjp(fn_ref, *args, q_seqlen, kv_seqlen) + out_ans, sdpa_vjp_ans = jax.vjp(fn_ans, *args, q_seqlen, kv_seqlen) + dQ_ref, dK_ref, dV_ref, dbias_ref = sdpa_vjp_ref(grad)[:4] + dQ_ans, dK_ans, dV_ans, dbias_ans = sdpa_vjp_ans(grad)[:4] + # Check if cudnn backend is called. 
+ self.assertTrue(_check_cudnn_backend(sdpa_ans, *args, **kwargs)) + self.assertTrue(_check_cudnn_backend(sdpa_vjp_ans, grad)) + + self.assertAllClose(out_ref, out_ans, atol=.01, rtol=.01) self.assertAllClose(dQ_ref, dQ_ans, rtol=.01, atol=.01) self.assertAllClose(dK_ref, dK_ans, rtol=.02, atol=.02) self.assertAllClose(dV_ref, dV_ans, rtol=.02, atol=.02) @@ -290,11 +308,18 @@ def testGeluIntType(self, approximate): def testGelu(self, approximate): def gelu_reference(x): return x * scipy.stats.norm.cdf(x) - rng = jtu.rand_default(self.rng()) - args_maker = lambda: [rng((4, 5, 6), jnp.float32)] + args_maker = lambda: [jnp.linspace(-12, 5, 10000, dtype=jnp.float32)] + rtol = 2e-5 + atol = 1e-3 if approximate else 0 self._CheckAgainstNumpy( - gelu_reference, partial(nn.gelu, approximate=approximate), args_maker, - check_dtypes=False, tol=1e-3 if approximate else None) + gelu_reference, + partial(nn.gelu, approximate=approximate), + args_maker, + check_dtypes=False, + tol=0, + rtol=rtol, + atol=atol, + ) @parameterized.parameters(*itertools.product( (jnp.float32, jnp.bfloat16, jnp.float16), @@ -307,12 +332,12 @@ def testDtypeMatchesInput(self, dtype, fn): self.assertEqual(out.dtype, dtype) def testEluMemory(self): - # see https://github.com/google/jax/pull/1640 + # see https://github.com/jax-ml/jax/pull/1640 with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.elu(jnp.ones((10 ** 12,)))) # don't oom def testHardTanhMemory(self): - # see https://github.com/google/jax/pull/1640 + # see https://github.com/jax-ml/jax/pull/1640 with jax.enable_checks(False): # With checks we materialize the array jax.make_jaxpr(lambda: nn.hard_tanh(jnp.ones((10 ** 12,)))) # don't oom @@ -349,7 +374,7 @@ def testSoftmaxWhereMask(self, fn): @parameterized.parameters([nn.softmax, nn.log_softmax]) def testSoftmaxWhereGrad(self, fn): - # regression test for https://github.com/google/jax/issues/19490 + # regression test for https://github.com/jax-ml/jax/issues/19490 x = jnp.array([36., 10000.]) mask = x < 1000 @@ -425,7 +450,7 @@ def testOneHotCustomDtype(self): self.assertAllClose(actual, expected) def testOneHotConcretizationError(self): - # https://github.com/google/jax/issues/3654 + # https://github.com/jax-ml/jax/issues/3654 msg = r"in jax.nn.one_hot argument `num_classes`" with self.assertRaisesRegex(core.ConcretizationTypeError, msg): jax.jit(nn.one_hot)(3, 5) @@ -445,7 +470,7 @@ def testTanhExists(self): nn.tanh # doesn't crash def testCustomJVPLeak(self): - # https://github.com/google/jax/issues/8171 + # https://github.com/jax-ml/jax/issues/8171 @jax.jit def fwd(): a = jnp.array(1.) @@ -461,7 +486,7 @@ def f(hx, _): fwd() # doesn't crash def testCustomJVPLeak2(self): - # https://github.com/google/jax/issues/8171 + # https://github.com/jax-ml/jax/issues/8171 # The above test uses jax.nn.sigmoid, as in the original #8171, but that # function no longer actually has a custom_jvp! So we inline the old def. 
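# Illustrative sketch of the GQA reference trick used in
# testDotProductAttention above: repeating the K/V heads G times turns
# grouped-query attention into ordinary multi-head attention, and the
# reference gradients are recovered by summing over the repeated group axis
# (the reshape(...).sum(axis=3) step). A plain softmax attention stands in
# for nn.dot_product_attention here; all shapes and helper names are
# illustrative assumptions, not part of the patch.
import jax
import jax.numpy as jnp

B, S, N, H, G = 2, 16, 4, 8, 2  # batch, seq len, query heads, head dim, group size

def attention(q, k, v):
  # q, k, v: (B, S, N, H) once the key/value heads have been expanded.
  logits = jnp.einsum("bqnh,bknh->bnqk", q, k) / (H ** 0.5)
  return jnp.einsum("bnqk,bknh->bqnh", jax.nn.softmax(logits, axis=-1), v)

def gqa(q, k_small, v_small):
  # k_small, v_small carry only N // G heads; repeat them to match q.
  return attention(q, jnp.repeat(k_small, G, axis=2), jnp.repeat(v_small, G, axis=2))

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (B, S, N, H))
k_small = jax.random.normal(kk, (B, S, N // G, H))
v_small = jax.random.normal(kv, (B, S, N // G, H))

# Gradient w.r.t. the expanded K, summed over the group axis, matches the
# gradient w.r.t. the un-expanded GQA K.
v_big = jnp.repeat(v_small, G, axis=2)
dk_big = jax.grad(lambda k: attention(q, k, v_big).sum())(jnp.repeat(k_small, G, axis=2))
dk_gqa = jax.grad(lambda k: gqa(q, k, v_small).sum())(k_small)
assert jnp.allclose(dk_big.reshape(B, S, N // G, G, H).sum(axis=3), dk_gqa, atol=1e-5)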
diff --git a/tests/notebooks/colab_cpu.ipynb b/tests/notebooks/colab_cpu.ipynb index 1540b3d20892..f5dcff837838 100644 --- a/tests/notebooks/colab_cpu.ipynb +++ b/tests/notebooks/colab_cpu.ipynb @@ -20,7 +20,7 @@ "colab_type": "text" }, "source": [ - "\"Open" + "\"Open" ] }, { @@ -88,15 +88,6 @@ "height": 68 } }, - "source": [ - "from jaxlib import xla_extension\n", - "import jax\n", - "key = jax.random.PRNGKey(1701)\n", - "arr = jax.random.normal(key, (1000,))\n", - "device = arr.device()\n", - "print(f\"JAX device type: {device}\")\n", - "assert device.platform == \"cpu\", f\"unexpected JAX device type: {device.platform}\"" - ], "execution_count": 2, "outputs": [ { diff --git a/tests/notebooks/colab_gpu.ipynb b/tests/notebooks/colab_gpu.ipynb index 8352bdaf71bc..2335455e6cf2 100644 --- a/tests/notebooks/colab_gpu.ipynb +++ b/tests/notebooks/colab_gpu.ipynb @@ -7,7 +7,7 @@ "id": "view-in-github" }, "source": [ - "\"Open" + "\"Open" ] }, { diff --git a/tests/ode_test.py b/tests/ode_test.py index 834745e1cf1c..acdfa1fc6cef 100644 --- a/tests/ode_test.py +++ b/tests/ode_test.py @@ -139,7 +139,7 @@ def swoop(_np, y, t, arg1, arg2): @jtu.skip_on_devices("tpu", "gpu") def test_odeint_vmap_grad(self): - # https://github.com/google/jax/issues/2531 + # https://github.com/jax-ml/jax/issues/2531 def dx_dt(x, *args): return 0.1 * x @@ -169,7 +169,7 @@ def g(x): @jtu.skip_on_devices("tpu", "gpu") def test_disable_jit_odeint_with_vmap(self): - # https://github.com/google/jax/issues/2598 + # https://github.com/jax-ml/jax/issues/2598 with jax.disable_jit(): t = jnp.array([0.0, 1.0]) x0_eval = jnp.zeros((5, 2)) @@ -178,7 +178,7 @@ def test_disable_jit_odeint_with_vmap(self): @jtu.skip_on_devices("tpu", "gpu") def test_grad_closure(self): - # simplification of https://github.com/google/jax/issues/2718 + # simplification of https://github.com/jax-ml/jax/issues/2718 def experiment(x): def model(y, t): return -x * y @@ -188,7 +188,7 @@ def model(y, t): @jtu.skip_on_devices("tpu", "gpu") def test_grad_closure_with_vmap(self): - # https://github.com/google/jax/issues/2718 + # https://github.com/jax-ml/jax/issues/2718 @jax.jit def experiment(x): def model(y, t): @@ -209,7 +209,7 @@ def model(y, t): @jtu.skip_on_devices("tpu", "gpu") def test_forward_mode_error(self): - # https://github.com/google/jax/issues/3558 + # https://github.com/jax-ml/jax/issues/3558 def f(k): return odeint(lambda x, t: k*x, 1., jnp.linspace(0, 1., 50)).sum() @@ -219,7 +219,7 @@ def f(k): @jtu.skip_on_devices("tpu", "gpu") def test_closure_nondiff(self): - # https://github.com/google/jax/issues/3584 + # https://github.com/jax-ml/jax/issues/3584 def dz_dt(z, t): return jnp.stack([z[0], z[1]]) @@ -232,8 +232,8 @@ def f(z): @jtu.skip_on_devices("tpu", "gpu") def test_complex_odeint(self): - # https://github.com/google/jax/issues/3986 - # https://github.com/google/jax/issues/8757 + # https://github.com/jax-ml/jax/issues/3986 + # https://github.com/jax-ml/jax/issues/8757 def dy_dt(y, t, alpha): return alpha * y * jnp.exp(-t).astype(y.dtype) diff --git a/tests/optimizers_test.py b/tests/optimizers_test.py index b7710d9b94c2..c4eca070798c 100644 --- a/tests/optimizers_test.py +++ b/tests/optimizers_test.py @@ -260,7 +260,7 @@ def testUtilityClipGrads(self): self.assertAllClose(ans, expected, check_dtypes=False) def testIssue758(self): - # code from https://github.com/google/jax/issues/758 + # code from https://github.com/jax-ml/jax/issues/758 # this is more of a scan + jacfwd/jacrev test, but it lives here to use the # optimizers.py code diff 
--git a/tests/package_structure_test.py b/tests/package_structure_test.py index e9944ec084af..71d48c2b121c 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. - _mod("jax.errors"), + _mod("jax.errors", exclude=["JaxRuntimeError"]), _mod("jax.nn.initializers"), _mod( "jax.tree_util", diff --git a/tests/pallas/BUILD b/tests/pallas/BUILD index 3e1fd863a3ab..044f82067510 100644 --- a/tests/pallas/BUILD +++ b/tests/pallas/BUILD @@ -15,7 +15,7 @@ load( "//jaxlib:jax.bzl", "jax_generate_backend_suites", - "jax_test", + "jax_multiplatform_test", "py_deps", ) @@ -28,7 +28,7 @@ package( jax_generate_backend_suites() -jax_test( +jax_multiplatform_test( name = "pallas_test", srcs = [ "pallas_test.py", @@ -38,11 +38,9 @@ jax_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_p100", - "gpu_p100_x32", + enable_backends = [ + "cpu", + "tpu", ], enable_configs = [ "gpu_a100_x32", @@ -62,7 +60,27 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( + name = "pallas_jumble_test", + srcs = [ + "pallas_jumble_test.py", + ], + disable_configs = [ + "gpu_v100", + "gpu_x32", + "gpu_a100", + "gpu_p100", + "gpu_p100_x32", + "gpu_h100", + ], + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("numpy"), +) + +jax_multiplatform_test( name = "ops_test", srcs = [ "ops_test.py", @@ -73,40 +91,56 @@ jax_test( }, }, disable_configs = [ - "gpu", + "gpu_v100", "gpu_x32", - "gpu_a100", "gpu_p100", "gpu_p100_x32", - "gpu_h100", ], enable_configs = [ + "gpu_a100", "gpu_a100_x32", + "gpu_h100", "gpu_h100_x32", ], + shard_count = { + "cpu": 8, + "gpu": 8, + "tpu": 8, + }, + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep "//jax:pallas_tpu", "//jax:pallas_tpu_ops", - ] + py_deps("absl/testing") + py_deps("numpy"), + ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "indexing_test", srcs = [ "indexing_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", "tpu", ], + tags = [ + "noasan", # Times out. + "nomsan", # Times out. + "notsan", # Times out. + ], deps = [ "//jax:pallas", + "//jax:pallas_tpu", ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "pallas_vmap_test", srcs = [ "pallas_vmap_test.py", @@ -116,14 +150,7 @@ jax_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -138,37 +165,23 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "mosaic_gpu_test", srcs = [ "mosaic_gpu_test.py", ], config_tags_overrides = { - # TODO(slebedev): Switch to False once Mosaic GPU is unconditionally enabled. "gpu_h100_x32": { - "ondemand": True, # Include in presubmit. + "ondemand": False, # Include in presubmit. 
}, }, - disable_backends = [ - "cpu", - "tpu", - ], - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_a100_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_h100", - ], + enable_backends = [], enable_configs = [ "gpu_h100_x32", ], env = { "JAX_PALLAS_USE_MOSAIC_GPU": "1", }, - tags = ["notap"], deps = [ "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep @@ -176,7 +189,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "export_back_compat_pallas_test", srcs = ["export_back_compat_pallas_test.py"], config_tags_overrides = { @@ -184,15 +197,7 @@ jax_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -203,10 +208,11 @@ jax_test( "//jax:internal_export_back_compat_test_util", "//jax:pallas", "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu_ops", # build_cleaner: keep ], ) -jax_test( +jax_multiplatform_test( name = "export_pallas_test", srcs = ["export_pallas_test.py"], config_tags_overrides = { @@ -214,15 +220,7 @@ jax_test( "ondemand": False, # Include in presubmit. }, }, - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - "gpu_pjrt_c_api", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", ], @@ -234,7 +232,7 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "pallas_shape_poly_test", srcs = ["pallas_shape_poly_test.py"], config_tags_overrides = { @@ -261,29 +259,38 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( + name = "pallas_error_handling_test", + srcs = [ + "pallas_error_handling_test.py", + ], + enable_backends = ["tpu"], + deps = [ + "//jax:pallas", + "//jax:pallas_tpu", + "//jax/_src/pallas/mosaic:random", + "//third_party/py/absl/testing:absltest", + "//third_party/py/absl/testing:parameterized", + ] + py_deps("numpy"), +) + +jax_multiplatform_test( name = "tpu_all_gather_test", srcs = [ "tpu_all_gather_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_gmm_test", srcs = [ "tpu_gmm_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 50, tags = [ "noasan", # Times out. @@ -300,15 +307,12 @@ jax_test( ]), ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_test", srcs = ["tpu_pallas_test.py"], # The flag is necessary for ``pl.debug_print`` tests to work on TPU. 
args = ["--logtostderr"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -316,13 +320,27 @@ jax_test( ], ) -jax_test( - name = "tpu_pallas_distributed_test", - srcs = ["tpu_pallas_distributed_test.py"], - disable_backends = [ +jax_multiplatform_test( + name = "tpu_ops_test", + srcs = [ + "tpu_ops_test.py", + ], + enable_backends = [ "cpu", - "gpu", + "tpu", ], + deps = [ + "//jax:pallas", + "//jax:pallas_gpu", # build_cleaner: keep + "//jax:pallas_tpu", + "//jax:pallas_tpu_ops", + ] + py_deps("absl/testing") + py_deps("hypothesis") + py_deps("numpy"), +) + +jax_multiplatform_test( + name = "tpu_pallas_distributed_test", + srcs = ["tpu_pallas_distributed_test.py"], + enable_backends = ["tpu"], deps = [ "//jax:extend", "//jax:pallas_tpu", @@ -330,13 +348,10 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_pipeline_test", srcs = ["tpu_pallas_pipeline_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 5, tags = [ "noasan", # Times out. @@ -350,13 +365,21 @@ jax_test( ] + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( + name = "tpu_pallas_async_test", + srcs = ["tpu_pallas_async_test.py"], + enable_backends = ["tpu"], + tags = [ + ], + deps = [ + "//jax:pallas_tpu", + ], +) + +jax_multiplatform_test( name = "tpu_pallas_mesh_test", srcs = ["tpu_pallas_mesh_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], tags = [ "noasan", "nomsan", @@ -368,15 +391,12 @@ jax_test( ], ) -jax_test( +jax_multiplatform_test( name = "tpu_pallas_random_test", srcs = [ "tpu_pallas_random_test.py", ], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], deps = [ "//jax:pallas", "//jax:pallas_tpu", @@ -386,13 +406,10 @@ jax_test( ] + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_paged_attention_kernel_test", srcs = ["tpu_paged_attention_kernel_test.py"], - disable_backends = [ - "cpu", - "gpu", - ], + enable_backends = ["tpu"], shard_count = 5, tags = [ "noasan", # Times out. @@ -404,15 +421,12 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "tpu_splash_attention_kernel_test", srcs = [ "tpu_splash_attention_kernel_test.py", ], - disable_backends = [ - "gpu", - "cpu", - ], + enable_backends = ["tpu"], shard_count = 24, tags = [ "noasan", # Times out. @@ -424,20 +438,21 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "tpu_splash_attention_mask_test", srcs = [ "tpu_splash_attention_mask_test.py", ], - disable_backends = [ - "gpu", + enable_backends = [ + "cpu", + "tpu", ], deps = [ "//jax:pallas_tpu_ops", ] + py_deps("absl/testing") + py_deps("numpy") + py_deps("hypothesis"), ) -jax_test( +jax_multiplatform_test( name = "gpu_attention_test", srcs = [ "gpu_attention_test.py", @@ -447,17 +462,7 @@ jax_test( "ondemand": False, # Include in presubmit. }, }, - disable_backends = [ - "tpu", - ], - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_p100", - "gpu_p100_x32", - "gpu_a100", - "gpu_h100", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", @@ -470,7 +475,7 @@ jax_test( ] + py_deps("absl/testing") + py_deps("numpy"), ) -jax_test( +jax_multiplatform_test( name = "gpu_ops_test", srcs = [ "gpu_ops_test.py", @@ -480,17 +485,7 @@ jax_test( "ondemand": False, # Include in presubmit. 
}, }, - disable_backends = [ - "tpu", - ], - disable_configs = [ - "gpu", - "gpu_x32", - "gpu_a100", - "gpu_h100", - "gpu_p100", - "gpu_p100_x32", - ], + enable_backends = ["cpu"], enable_configs = [ "gpu_a100_x32", "gpu_h100_x32", diff --git a/tests/pallas/export_back_compat_pallas_test.py b/tests/pallas/export_back_compat_pallas_test.py index 8cf3f9708e38..9e9935884b3a 100644 --- a/tests/pallas/export_back_compat_pallas_test.py +++ b/tests/pallas/export_back_compat_pallas_test.py @@ -17,15 +17,21 @@ update these tests. """ -from absl.testing import absltest +import math +from absl.testing import absltest import jax -import jax.numpy as jnp from jax._src import config from jax._src import test_util as jtu from jax._src.internal_test_util import export_back_compat_test_util as bctu -from jax._src.internal_test_util.export_back_compat_test_data.pallas import cuda_add_one +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_matmul +from jax._src.internal_test_util.export_back_compat_test_data.pallas import mosaic_semaphore_dma +from jax._src.internal_test_util.export_back_compat_test_data.pallas import triton_add_one from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from jax.experimental.pallas.ops.tpu import matmul +import jax.numpy as jnp + config.parse_flags_with_absl() @@ -36,14 +42,12 @@ class CompatTest(bctu.CompatTestBase): def setUp(self): if jax.config.x64_enabled: self.skipTest("Only works in 32-bit") - if not jtu.test_device_matches(["gpu"]): - self.skipTest("Only works on GPU") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") super().setUp() - def test_cuda_add_one(self): + def test_triton_add_one(self): def func(x): def add_one(x_ref, o_ref): o_ref[0] = x_ref[0] + 1 @@ -52,8 +56,51 @@ def add_one(x_ref, o_ref): in_specs=[pl.BlockSpec((1,), lambda i: i)], out_specs=pl.BlockSpec((1,), lambda i: i), grid=8)(x) - data = self.load_testdata(cuda_add_one.data_2024_05_02) + data = self.load_testdata(triton_add_one.data_2024_05_02) + + self.run_one_test(func, data) + + @jax.default_matmul_precision("bfloat16") + def test_mosaic_matmul(self): + dtype = jnp.float32 + def func(): + # Build the inputs here, to reduce the size of the golden inputs. + x_shape = (1024, 512) + bias = 1.0 + scale = 1e-3 + x = bias + scale * jnp.arange( + math.prod(x_shape), dtype=dtype).reshape(x_shape) + y = x[:512, :256] + res = matmul.matmul(x, y, block_shape=(256, 256)) + # Keep only slices of the output, to reduce the size of the goldens. + return res[::16, ::16] + + data = self.load_testdata(mosaic_matmul.data_2024_09_24) + self.run_one_test(func, data, rtol=2e-7) + + def test_mosaic_semaphore_dma(self): + if not (jtu.test_device_matches(["tpu"]) and + jtu.is_device_tpu_at_least(4)): + # TODO: crashes during compilation on TPU v4 + self.skipTest("Only works on TPU v5+") + + # The signatures of TPU ops for semaphore and DMA have changed. + # This test ensures that the new signatures are backwards compatible. 
+ def func(): + def dma_kernel(x, y): + def body(dma_sem, sem): + pltpu.async_copy(x, y, dma_sem).wait() + pltpu.semaphore_signal(sem) + pltpu.semaphore_wait(sem) + pl.run_scoped( + body, pltpu.SemaphoreType.DMA, pltpu.SemaphoreType.REGULAR + ) + x = jnp.arange(128 * 128, dtype=jnp.float32).reshape(128, 128) + y = pl.pallas_call(dma_kernel, out_shape=x)(x) + return jnp.array_equal(x, y).astype(jnp.float32) + data = self.load_testdata( + mosaic_semaphore_dma.semaphore_and_dma_2024_04_22) self.run_one_test(func, data) diff --git a/tests/pallas/gpu_attention_test.py b/tests/pallas/gpu_attention_test.py index 571e1348a7a8..ed059c235329 100644 --- a/tests/pallas/gpu_attention_test.py +++ b/tests/pallas/gpu_attention_test.py @@ -21,7 +21,10 @@ from jax import random from jax._src import config from jax._src import test_util as jtu -from jax.experimental.pallas.ops.gpu import decode_attention +if sys.platform != "win32": + from jax.experimental.pallas.ops.gpu import decode_attention +else: + decode_attention = None import jax.numpy as jnp import numpy as np @@ -48,7 +51,7 @@ def setUp(self): if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: + if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") super().setUp() @@ -148,7 +151,7 @@ def test_gqa( o_ref = decode_attention.gqa_reference(q, k, v) np.testing.assert_allclose(o, o_ref, atol=0.05) -class DecodeAttentionInterpreterTest(DecodeAttentionTest): +class DecodeAttentionInterpretTest(DecodeAttentionTest): INTERPRET = True diff --git a/tests/pallas/gpu_ops_test.py b/tests/pallas/gpu_ops_test.py index e7b7a4daac3d..7692294cd6df 100644 --- a/tests/pallas/gpu_ops_test.py +++ b/tests/pallas/gpu_ops_test.py @@ -28,10 +28,16 @@ from jax._src.lax.control_flow.for_loop import for_loop from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl -from jax.experimental.pallas.ops.gpu import attention -from jax.experimental.pallas.ops.gpu import layer_norm -from jax.experimental.pallas.ops.gpu import rms_norm -from jax.experimental.pallas.ops.gpu import softmax +if sys.platform != "win32": + from jax.experimental.pallas.ops.gpu import attention + from jax.experimental.pallas.ops.gpu import layer_norm + from jax.experimental.pallas.ops.gpu import rms_norm + from jax.experimental.pallas.ops.gpu import softmax +else: + attention = None + layer_norm = None + rms_norm = None + softmax = None import jax.numpy as jnp import numpy as np @@ -125,7 +131,7 @@ def setUp(self): if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPU with capability >= sm80") - if sys.platform == "win32" and not self.INTERPRET: + if sys.platform == "win32": self.skipTest("Only works on non-Windows platforms") super().setUp() @@ -172,7 +178,7 @@ def setUp(self): (1, 384, 8, 64, True, True, True, {}), (1, 384, 8, 64, True, True, False, {}), (2, 384, 8, 64, True, True, True, {}), - # regression test: https://github.com/google/jax/pull/17314 + # regression test: https://github.com/jax-ml/jax/pull/17314 (1, 384, 8, 64, True, False, False, {'block_q': 128, 'block_k': 64}), ] ] @@ -252,8 +258,8 @@ def impl(q, k, v): (1, 384, 1, 32, False, False), (2, 384, 2, 32, False, True), (2, 384, 2, 32, False, False), - # TODO(b/283035396): (1, 384, 1, 32, True, True), - # TODO(b/283035396): (2, 384, 2, 32, True, 
True), + (1, 384, 1, 32, True, True), + (2, 384, 2, 32, True, True), ] ] ) @@ -292,7 +298,7 @@ def f_ref(q, k, v): np.testing.assert_allclose(dv, dv_ref, atol=0.05) -class FusedAttentionInterpreterTest(FusedAttentionTest): +class FusedAttentionInterpretTest(FusedAttentionTest): INTERPRET = True @@ -340,7 +346,7 @@ def f_ref(x, w, b): np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2) -class FusedLayerNormInterpreterTest(FusedLayerNormTest): +class FusedLayerNormInterpretTest(FusedLayerNormTest): INTERPRET = True @@ -388,7 +394,7 @@ def f_ref(x, w, b): np.testing.assert_allclose(db, db_ref, rtol=1e-2, atol=1e-2) -class RmsNormInterpreterTest(RmsNormTest): +class RmsNormInterpretTest(RmsNormTest): INTERPRET = True @@ -413,7 +419,7 @@ def test_softmax(self, shape, dtype): }[dtype] # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. + # properly. See https://github.com/jax-ml/jax/issues/11014. np.testing.assert_allclose( softmax.softmax(x, axis=-1).astype(jnp.float32), jax.nn.softmax(x, axis=-1).astype(jnp.float32), @@ -422,7 +428,7 @@ def test_softmax(self, shape, dtype): ) -class SoftmaxInterpreterTest(SoftmaxTest): +class SoftmaxInterpretTest(SoftmaxTest): INTERPRET = True diff --git a/tests/pallas/indexing_test.py b/tests/pallas/indexing_test.py index 11402ed99741..d49b83fe160b 100644 --- a/tests/pallas/indexing_test.py +++ b/tests/pallas/indexing_test.py @@ -15,12 +15,13 @@ """Tests for Pallas indexing logic and abstractions.""" from __future__ import annotations - +import sys import unittest from absl.testing import absltest from absl.testing import parameterized import jax +from jax import random from jax._src import test_util as jtu from jax._src import util from jax._src.state import indexing @@ -28,6 +29,11 @@ import jax.numpy as jnp from jax.experimental import pallas as pl +if sys.platform != "win32": + from jax.experimental.pallas import tpu as pltpu +else: + pltpu = None + try: import hypothesis as hp except (ModuleNotFoundError, ImportError): @@ -46,6 +52,26 @@ ds = indexing.ds +_INDEXING_TEST_CASES = [ + ((4, 8, 128), (...,), (4, 8, 128)), + ((4, 8, 128), (0,), (8, 128)), + ((4, 8, 128), (pl.ds(1, 2),), (2, 8, 128)), + ((4, 8, 128), (pl.ds(2, 2),), (2, 8, 128)), + ((4, 8, 128), (pl.ds(2, 2), 0), (8, 128)), + ((4, 8, 128), (pl.ds(2, 2), 1), (8, 128)), + ((4, 8, 128), (slice(2, 4), 1), (8, 128)), + ((4, 8, 128), (slice(2, 4), slice(0, 1), 0), (8, 128)), + ((4, 8, 128), ((0, pl.ds(0, 8), pl.ds(0, 128)), ...), (8, 128)), + ((4, 8, 128), (..., (0, pl.ds(0, 8), pl.ds(0, 128)), ...), (8, 128)), +] + + +def _maybe_ds_to_slice(x: int | slice | indexing.Slice) -> int | slice: + if isinstance(x, indexing.Slice): + return slice(x.start, x.start + x.size) + return x + + def int_indexer_strategy(dim) -> hps.SearchStrategy[int]: return hps.integers(min_value=np.iinfo(np.int32).min, max_value=dim - 1) @@ -88,7 +114,23 @@ def nd_indexer_strategy(draw, shape) -> NDIndexer: return NDIndexer.from_indices_shape(indices, shape) +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if not self.INTERPRET: + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only interpret mode supported on non-TPU") + + super().setUp() + + @classmethod + def pallas_call(cls, *args, **kwargs): + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + + class IndexerTest(jtu.JaxTestCase): + """These are unit tests for the indexer logic, not using pallas_call.""" def 
test_simple_ndindexer(self): indices = (0, 0) @@ -206,8 +248,12 @@ def test_ndindexer(self, data): indexer.get_indexer_shape()) +class IndexerOpsTest(PallasBaseTest): + def test_multi_indexing_interpreter_only(self): - # Interpreter only test! YMMV actually compiling this. + if not self.INTERPRET: + self.skipTest("Only supported in interpret mode") + # Interpret only test! YMMV actually compiling this. def permute(left, right, left_out_ref, right_out_ref): left_out = jnp.zeros_like(left) left_out = left_out.at[:, 0].set(left[:, 0]) @@ -253,8 +299,27 @@ def invoke_permutes(x_ref, y_ref, x_out_ref, y_out_ref): interpret=True, )(x, y) + def test_multi_indexing_destination_ref(self): + if not self.INTERPRET: + self.skipTest("Only supported in interpret mode") + def kernel(x_ref, o_ref): + o_ref[...] = jnp.zeros_like(o_ref) + new_o_ref = o_ref.at[pl.ds(0, 8)].at[0].at[pl.ds(0, 4), pl.ds(0, 4)] + new_o_ref[...] = x_ref[...] + + x = jax.random.normal(jax.random.key(0), shape=(4, 4)) + result = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 16, 16), x.dtype), + interpret=True, + )(x) + expected = jnp.zeros((16, 16, 16)).at[0, 0:4, 0:4].set(x) + np.testing.assert_array_equal(result, expected) + def test_ellipsis_indexing_iterpret_only(self): - # Interpreter only test! YMMV actually compiling this. + if not self.INTERPRET: + self.skipTest("Only supported in interpret mode") + # Interpret only test! YMMV actually compiling this. def permute_columns_in_row_kernel(left, right, new_left, new_right): shape = left.shape k = shape[-1] @@ -296,18 +361,372 @@ def permute_columns_in_row_kernel(left, right, new_left, new_right): interpret=True, )(left, right) - import numpy as np # noqa: F811 left_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) right_np = np.array([[7, 8, 9], [10, 11, 12]], dtype=np.float32) left_out_np = left_np.copy() right_out_np = right_np.copy() - permute_columns_in_row_kernel(left_np, right_np, left_out_np, right_out_np) np.testing.assert_array_equal(left_out_np, left_out) np.testing.assert_array_equal(right_out_np, right_out) + @hp.given(hps.data()) + def test_vmap_nd_indexing(self, data): + self.skipTest("TODO(necula): enable this test; was in jax_triton.") + vmap_shape = data.draw(hnp.array_shapes(min_dims=1, max_dims=3, min_side=2), + label="vmap_shape") + el_shape = data.draw(hnp.array_shapes(min_dims=2), label="el_shape") + # TODO(sharadmv,apaszke): enable rank 0 and rank 1 Refs + # hp.assume(len(el_shape) >= 2) + nd_indexer = data.draw(nd_indexer_strategy(el_shape), label="nd_indexer") + expected_shape = jax.eval_shape(lambda x: x[nd_indexer], + jax.ShapeDtypeStruct(el_shape, jnp.float32)) + + ref = lambda x: x[nd_indexer] + def kernel(x_ref, y_ref): + x = pl.load(x_ref, nd_indexer) + pl.store(y_ref, (slice(None),) * len(y_ref.shape), x) + func = pl.pallas_call(kernel, out_shape=expected_shape) + + shape = el_shape + for vmap_dim in vmap_shape[::-1]: + index = data.draw(hps.integers(min_value=0, + max_value=max(0, len(shape) - 2)), + label="index") + # hp.assume(index <= max(0, len(shape) - 2)) + # TODO(sharadmv,apaszke): enable vmapping over batch axes in 2 minormost + # dimensions + shape = (*shape[:index], vmap_dim, *shape[index:]) + ref = jax.vmap(ref, in_axes=index, out_axes=0) + func = jax.vmap(func, in_axes=index, out_axes=0) + key = random.PRNGKey(0) + x = random.normal(key, shape, dtype=jnp.float32) + expected = ref(x) + y = func(x) + np.testing.assert_array_equal(y, expected) + + @parameterized.product( + indexer_type=["state", "pallas"], + 
case=_INDEXING_TEST_CASES, + ) + def test_can_load_with_ref_at(self, indexer_type, case): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + in_shape, indexers, out_shape = case + dtype = jnp.float32 + def body(x_ref, y_ref): + for indexer in indexers[:-1]: + x_ref = x_ref.at[indexer] + if indexer_type == "state": + x = x_ref[indexers[-1]] + y_ref[...] = x + elif indexer_type == "pallas": + x = pl.load(x_ref, indexers[-1]) + pl.store(y_ref, ..., x) + + x = random.normal(random.key(0), in_shape, dtype=dtype) + y = x + for indexer in indexers: + if not isinstance(indexer, tuple): + indexer = (indexer,) + indexer = tuple(map(_maybe_ds_to_slice, indexer)) + y = y[indexer] + assert y.shape == out_shape + out = self.pallas_call(body, out_shape=y)(x) + self.assertAllClose(out, y) + + @parameterized.product( + indexer_type=["state", "pallas"], + case=_INDEXING_TEST_CASES, + ) + def test_can_store_with_ref_at(self, indexer_type, case): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + in_shape, indexers, val_shape = case + dtype = jnp.float32 + def body(x_ref, y_ref): + y_ref[...] = jnp.zeros_like(y_ref) + for indexer in indexers[:-1]: + y_ref = y_ref.at[indexer] + if indexer_type == "state": + x = x_ref[...] + y_ref[indexers[-1]] = x + elif indexer_type == "pallas": + x = pl.load(x_ref, ...) + pl.store(y_ref, indexers[-1], x) + + val = random.normal(random.key(0), val_shape, dtype=dtype) + # Use NumPy arrays to do nested indexing and mutation. This is really + # annoying to do in vanilla JAX. + x = np.zeros(in_shape, dtype=dtype) + y = x + for indexer in indexers: + if not isinstance(indexer, tuple): + indexer = (indexer,) + indexer = tuple(map(_maybe_ds_to_slice, indexer)) + y = y[indexer] + assert y.shape == val_shape + y[...] = val + out = self.pallas_call(body, out_shape=x)(val) + self.assertAllClose(out, x) + + @parameterized.product( + indexer_type=["state", "pallas"], + slice_type=["slice", "ds"], + ) + @hp.given( + ref_shape=hps.sampled_from(((8, 8, 32), (7, 7, 33))), + indices=hps.tuples( + hps.integers(0, 6), hps.integers(0, 6), hps.integers(0, 31) + ), + strides=hps.tuples( + hps.integers(1, 10), hps.integers(1, 10), hps.integers(1, 10) + ), + ) + def test_strided_load_and_store( + self, indexer_type, slice_type, ref_shape, indices, strides + ): + if self.INTERPRET: + self.skipTest("TODO: fails in interpret mode.") + ref_shape = (*ref_shape, 128) + indices = (*indices, 0) + strides = (*strides, 1) + vec_shape = [ + (l - i + s - 1) // s for l, i, s in zip(ref_shape, indices, strides) + ] + dtype = jnp.float32 + + def body(x_ref, y_ref1, y_ref2): + if slice_type == "slice": + slices = tuple( + slice(i, rs, s) for i, rs, s in zip(indices, ref_shape, strides) + ) + else: + slices = tuple( + pl.ds(i, vs, s) for i, vs, s in zip(indices, vec_shape, strides) + ) + if indexer_type == "state": + y_ref1[...] = x_ref[slices] + y_ref2[slices] = y_ref1[...] 
+ elif indexer_type == "pallas": + pl.store(y_ref1, ..., pl.load(x_ref, slices)) + pl.store(y_ref2, slices, pl.load(y_ref1, ...)) + + x = random.normal(random.key(0), ref_shape, dtype=dtype) + y1, y2 = self.pallas_call( + body, + out_shape=[ + jax.ShapeDtypeStruct(vec_shape, dtype), + jax.ShapeDtypeStruct(ref_shape, dtype), + ], + )(x) + slices = tuple( + slice(i, l, s) for l, i, s in zip(ref_shape, indices, strides) + ) + expected = x[slices] + self.assertAllClose(y1, expected, err_msg="Strided Load Error") + self.assertAllClose( + y2[slices], expected, err_msg="Strided Store Error" + ) + + def test_load_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") + # We can take any dynamic index on the 2nd minor dimension as long as + # the minormost dimsize is vreg lane count. + m, n = 32, 128 + k = 10 + start = 2 + + def kernel(x_ref, indices, y_ref): + y_ref[...] = pl.load(x_ref, pl.ds(indices[0], k)) + + x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((k, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res, x[start : start + k, :], atol=0., rtol=0.) + + def test_store_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") + # We can take any dynamic index on the 2nd minor dimension as long as + # the minormost dimsize is vreg lane count. + m, n = 10, 128 + k = 32 + start = 2 + + def kernel(x_ref, indices, y_ref): + pl.store(y_ref, pl.ds(indices[0], m), x_ref[...]) + + x = jnp.arange(m * n, dtype=jnp.int32).reshape((m, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((k, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res[start : start + m, :], x, atol=0., rtol=0.) + + def test_load_one_row_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") + # This test triggers strided load. We can take any dynamic index on the + # 2nd minor dimension as long as we load one row on the 2nd minor dim. + b, m, n = 4, 16, 256 + start = 3 + + def kernel(x_ref, indices, y_ref): + y_ref[...] = x_ref[:, pl.ds(indices[0], 1), :] + + x = jnp.arange(b * m * n, dtype=jnp.int32).reshape((b, m, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((b, 1, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res, x[:, start : start + 1, :], atol=0., rtol=0.) + + def test_store_one_row_with_dynamic_2nd_minor_index(self): + if pltpu is None: + self.skipTest("No TPU module available.") + # This test triggers strided store. We can take any dynamic index on the + # 2nd minor dimension as long as we store one row on the 2nd minor dim. + b, m, n = 4, 16, 256 + start = 3 + + def kernel(x_ref, indices, y_ref): + y_ref[:, pl.ds(indices[0], 1), :] = x_ref[...] 
+ + x = jnp.arange(b * 1 * n, dtype=jnp.int32).reshape((b, 1, n)) + indices = jnp.array([start]) + + res = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((b, m, n), jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + ), + )(x, indices) + self.assertAllClose(res[:, start : start + 1, :], x, atol=0., rtol=0.) + + +class IndexerOpsInterpretTest(IndexerOpsTest): + INTERPRET = True + + +# TODO(ayx): Fix all test cases here +_ADVANCED_INDEXER_TEST_CASES = [ + # integer + ((3, 2), lambda arr, a, b, c, d: arr[2]), + # slice + ((12, 12), lambda arr, a, b, c, d: arr[::4, ::4]), + ((16, 16), lambda arr, a, b, c, d: arr[1:14:2, 2:13:4]), + ((8, 2), lambda arr, a, b, c, d: arr[1::3, :]), + # array + ((4, 3), lambda arr, a, b, c, d: arr[a]), + ((4, 3, 2), lambda arr, a, b, c, d: arr[c, c]), + # integer + 1-D array + ((4, 3), lambda arr, a, b, c, d: arr[2, a]), + ((4, 3), lambda arr, a, b, c, d: arr[a, 2]), + # slice + 1-D array + ((4, 3), lambda arr, a, b, c, d: arr[a, :]), + # ((4, 3), lambda arr, a, b, c, d: arr[:, a]), + ((6, 8, 3), lambda arr, a, b, c, d: arr[c, ::3]), + # ((8, 6, 3), lambda arr, a, b, c, d: arr[::3, c]), + # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, ::2, a]), + # ((8, 8, 3), lambda arr, a, b, c, d: arr[::4, a, ::2]), + ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[b, ::4, a, ::2]), + ((3, 8, 8, 7), lambda arr, a, b, c, d: arr[b, a, ::4, ::2]), + # ((8, 8, 3, 7), lambda arr, a, b, c, d: arr[::4, b, a, ::2]), + ((16, 3, 6, 2), lambda arr, a, b, c, d: arr[::4, a, 1::2, b]), + ((8, 8, 3, 6), lambda arr, a, b, c, d: arr[b, ::4, a, a]), + # slice + array w/ broadcasting + ((8, 8, 3, 6), lambda arr, a, b, c, d: \ + arr[b[:, None], ::4, a[None], a[:, None]]), + # integer + slice + 1-D array + ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, ::2, a]), + ((5, 8, 8, 3), lambda arr, a, b, c, d: arr[2, ::4, a, ::2]), + # boolean + # ((6, 2), lambda arr, a, b, c, d: arr[d]), + # ((8, 6), lambda arr, a, b, c, d: arr[::4, d]), +] + + +class AdvancedIndexerOpsTest(PallasBaseTest): + + def setUp(self): + if jtu.test_device_matches(["tpu"]): + self.skipTest("Advanced indexers are not supported on TPU") + + # 4 arrays that are used in test cases of advanced indexing + self.a = jnp.array([1, 1, 1, 1, 1], dtype=jnp.int32) + self.b = jnp.array([1, 2, 2, 2, 2], dtype=jnp.int32) + self.c = jnp.array([1, 0, 2, 2, -1, 1], dtype=jnp.int32) + self.d = jnp.array([1, 0, 0, 0, 0, 1], dtype=jnp.bool_) + + super().setUp() + + @parameterized.parameters(_ADVANCED_INDEXER_TEST_CASES) + def test_advanced_indexer(self, in_shape: tuple[int, ...], indexing_func): + a, b, c, d = self.a, self.b, self.c, self.d + + x = jnp.arange(np.prod(in_shape), dtype=jnp.float32).reshape(in_shape) + y = indexing_func(x, a, b, c, d) + + # `a_ref`, `b_ref`, `c_ref` and `d_ref` are for testing purposes. + # We have them here because we need to have a unified function signature + # for all test cases, even if the arrays are actually not used in any + # computation. + def kernel(x_ref, a_ref, b_ref, c_ref, d_ref, o_ref): + a = a_ref[...] + b = b_ref[...] + c = c_ref[...] + d = d_ref[...] + o = indexing_func(x_ref, a, b, c, d) + o_ref[...] 
= o + + y_ = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(y.shape, jnp.float32), + )(x, a, b, c, d) + + np.testing.assert_array_equal(y_, y) + + +class AdvancedIndexerOpsInterpretTest(AdvancedIndexerOpsTest): + INTERPRET = True + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/mosaic_gpu_test.py b/tests/pallas/mosaic_gpu_test.py index 7d072678fb4c..b35658ed4845 100644 --- a/tests/pallas/mosaic_gpu_test.py +++ b/tests/pallas/mosaic_gpu_test.py @@ -13,13 +13,14 @@ # limitations under the License. import functools +import math from absl.testing import absltest from absl.testing import parameterized import jax from jax._src import config from jax._src import test_util as jtu -import jax._src.pallas.mosaic_gpu.core as plgpu +import jax._src.pallas.mosaic_gpu as plgpu from jax.experimental import pallas as pl import jax.numpy as jnp import numpy as np @@ -46,22 +47,134 @@ def test_add_one(self): pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) - def add_one(x_ref, o_ref): + def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] + 1.0 x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(add_one(x), x + 1.0) + np.testing.assert_array_equal(kernel(x), x + 1.0) - def test_add_doubled_sum(self): + def test_add_xy(self): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), ) - def add_one(x_ref, o_ref): - o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...]) + def kernel(x_ref, y_ref, o_ref): + o_ref[...] = x_ref[...] + y_ref[...] + + x = jnp.arange(256).astype(jnp.float32) + y = x + 1 + np.testing.assert_array_equal(kernel(x, y), x + y) + + def test_add_one_grid(self): + @functools.partial( + pl.pallas_call, + in_specs=[pl.BlockSpec((128,), lambda *i: i)], + out_specs=pl.BlockSpec((128,), lambda *i: i), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), + grid=2, + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + 1.0 + + x = jnp.arange(128 * 2).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_add_one_grid_with_scratch(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.float32), + in_specs=[pl.BlockSpec((128,), lambda *i: i)], + out_specs=pl.BlockSpec((128,), lambda *i: i), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + grid=2, + ) + def kernel(x_ref, o_ref, scratch_ref): + scratch_ref[...] = x_ref[...] + 1 + o_ref[...] = scratch_ref[...] x = jnp.arange(256).astype(jnp.float32) - np.testing.assert_array_equal(add_one(x), x + x.sum()*2) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + @parameterized.product(max_concurrent_steps=[1, 2, 3, 4]) + def test_add_one_grid_pipelined(self, max_concurrent_steps): + + @functools.partial( + pl.pallas_call, + in_specs=[pl.BlockSpec((128, 16), lambda i, j: (i, j))], + out_specs=pl.BlockSpec((128, 16), lambda i, j: (i, j)), + out_shape=jax.ShapeDtypeStruct([128 * 2, 64], jnp.float32), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "sequential"], + max_concurrent_steps=max_concurrent_steps, + ), + grid=(2, 4), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] 
+ 1.0 + + x = jnp.arange(128 * 2 * 64).reshape((128 * 2, 64)).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_add_one_with_async_copy_smem_to_gmem(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + out_specs=pl.BlockSpec(memory_space=plgpu.GMEM), + scratch_shapes=[plgpu.SMEM((128,), jnp.float32)], + ) + def kernel(x_ref, o_ref_gmem, scratch_ref): + scratch_ref[...] = x_ref[...] + 1 + plgpu.async_copy_smem_to_gmem(scratch_ref, o_ref_gmem) + plgpu.wait_smem_to_gmem(0) + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_add_one_with_async_copy_gmem_to_smem(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + in_specs=(pl.BlockSpec(memory_space=plgpu.GMEM),), + scratch_shapes=[ + plgpu.SMEM((128,), jnp.float32), + plgpu.Barrier(num_arrivals=1), + ], + ) + def kernel(x_ref_gmem, o_ref, scratch_ref, barrier_ref): + plgpu.async_copy_gmem_to_smem( + x_ref_gmem, scratch_ref, barrier=barrier_ref + ) + plgpu.wait_barrier(barrier_ref) + o_ref[...] = scratch_ref[...] + 1 + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 1.0) + + def test_add_doubled_sum(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + jnp.sum(x_ref[...]) + jnp.sum(x_ref[...]) + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + x.sum()*2) + + @parameterized.parameters(False, True) + def test_rsqrt(self, approx_math): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([128], jnp.float32), + compiler_params=plgpu.GPUCompilerParams(approx_math=approx_math), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jax.lax.rsqrt(x_ref[...]) + + x = jnp.arange(128).astype(jnp.float32) + np.testing.assert_allclose(kernel(x), jax.lax.rsqrt(x)) @parameterized.product(input_factor=[0.001, 1, 10, 100, 100]) def test_layer_norm(self, input_factor): @@ -72,7 +185,6 @@ def test_layer_norm(self, input_factor): @functools.partial( pl.pallas_call, out_shape=jax.ShapeDtypeStruct([256], jnp.float32), - compiler_params={"smem_scratch_bytes": 4 * 4}, ) def layer_norm(x_ref, o_ref): x_mean = jnp.mean(x_ref[...]) @@ -115,19 +227,52 @@ def kernel(x_ref, o_ref): self.assertEqual(output(), "It works!\n") - def test_print_with_values(self): + def test_print_scalar(self): @functools.partial( pl.pallas_call, - out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), ) def kernel(x_ref, o_ref): del o_ref - pl.debug_print("x[0] = {}", x_ref[0]) + pl.debug_print("x.sum() = {}", x_ref[...].sum()) - x = jnp.arange(256).astype(jnp.float32) - with self.assertRaises(Exception): - # TODO(slebedev): Remove assertRaises() once we support indexing. 
- kernel(x) + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn(f"x.sum() = {x.sum()}", output()) + + def test_print_scalar_array(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + del o_ref + pl.debug_print("x.sum() = {}", x_ref[...].sum() + 1) + + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn(f"x.sum() = {x.sum() + 1}", output()) + + def test_print_array(self): + in_shape = [2, 1, 64, 64] + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct(in_shape, jnp.int32), + ) + def kernel(x_ref, o_ref): + del o_ref + pl.debug_print("x: {}", x_ref[...]) + + x = jnp.arange(math.prod(in_shape)).reshape(in_shape) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn(f"x: [1, 0, 43, 23]/{in_shape}: 6871\n", output()) def test_scoped_allocation(self): def kernel(x_ref, o_ref): @@ -148,6 +293,222 @@ def body(tmp_ref): o = f(inp) np.testing.assert_array_equal(o, inp + 1.0) + def test_program_id(self): + @functools.partial( + pl.pallas_call, + in_specs=(), + out_specs=pl.BlockSpec((128,), lambda *i: i), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), + grid=2, + ) + def kernel(o_ref): + o_ref[...] = jnp.full(o_ref.shape, pl.program_id(0)) + + np.testing.assert_array_equal( + kernel(), + jnp.array([0] * 128 + [1] * 128, dtype=jnp.int32), + ) + + def test_num_programs(self): + @functools.partial( + pl.pallas_call, + in_specs=(), + out_specs=pl.BlockSpec((128,), lambda *i: i), + out_shape=jax.ShapeDtypeStruct([128 * 2], jnp.int32), + grid=2, + ) + def kernel(o_ref): + o_ref[...] = jnp.full(o_ref.shape, pl.num_programs(0)) + + np.testing.assert_array_equal( + kernel(), + jnp.full([256], 2, dtype=jnp.int32), + ) + + def test_swizzled_blockspec_shapes(self): + + @functools.partial( + pl.pallas_call, + in_specs=[ + plgpu.GPUBlockSpec( + (128, 64), + lambda *i: i, + transforms=plgpu.TilingTransform((64, 64)), + swizzle=128, + ), + ], + out_specs=pl.BlockSpec((2, 1, 64, 64), lambda i, j: (i, j, 64, 64)), + out_shape=jax.ShapeDtypeStruct((4, 2, 64, 64), jnp.float16), + grid=(2, 2), + ) + def kernel(x_ref, o_ref): + assert x_ref.shape == (2, 1, 64, 64), x_ref.shape + o_ref[...] = x_ref[...] + + x = jnp.zeros((256, 128), dtype=jnp.float16) + result = kernel(x) + self.assertEqual(result.shape, (4, 2, 64, 64)) + + def test_fori_loop_array(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(x_ref, o_ref): + # Equivalent to x_ref[...] + 2 + 3. + o_ref[...] = jax.lax.fori_loop(2, 4, lambda i, x: x + i, x_ref[...]) + + x = jnp.arange(256).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x + 2.0 + 3.0) + + def test_fori_loop_scalar(self): + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.float32), + ) + def kernel(o_ref): + # Equivalent to 2 + 3. + o_ref[...] 
= jax.lax.broadcast( + jax.lax.fori_loop(2, 4, lambda i, x: x + i, 0.0), o_ref.shape + ) + + np.testing.assert_array_equal( + kernel(), jnp.full([256], 5.0, dtype=jnp.float32) + ) + + def test_cond(self): + + @functools.partial( + pl.pallas_call, + out_shape=jax.ShapeDtypeStruct([256], jnp.int32), + ) + def kernel(x_ref, o_ref): + acc = x_ref[...].sum() + jax.lax.cond( + acc % 2 == 0, + lambda: pl.debug_print("acc * 2: {}", acc * 2), + lambda: pl.debug_print("acc: {}", acc), + ) + o_ref[...] = jnp.broadcast_to(acc, o_ref.shape) + + x = jnp.arange(256) + with jtu.capture_stdout() as output: + jax.block_until_ready(kernel(x)) + + self.assertIn("acc * 2:", output()) + + @parameterized.parameters(jnp.float16, jnp.float32) + def test_wgmma(self, dtype): + # TensorCores can only fuse transposes of 16-bit values, and RHS + # is expected to be column major by default. + rhs_transpose = jnp.dtype(dtype).itemsize != 2 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + def kernel(a_ref, b_ref, o_ref): + def scope(acc_ref): + plgpu.wgmma(acc_ref, a_ref, b_ref, rhs_transpose=rhs_transpose) + return acc_ref[...] + + o_ref[...] = pl.run_scoped(scope, plgpu.ACC((64, 128), jnp.float32)) + + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(64, 128), dtype=dtype) + b = jax.random.uniform(key2, shape=(128, 128), dtype=dtype) + + rhs_transforms = (plgpu.TilingTransform((elems_128b, elems_128b)),) + if rhs_transpose: + rhs_transforms += (plgpu.TransposeTransform((1, 0, 2, 3)),) + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (64, 128), + lambda i, j: (i, j), + transforms=plgpu.TilingTransform((64, elems_128b)), + swizzle=128, + ), + plgpu.GPUBlockSpec( + (128, 128), + lambda *i: i, + transforms=rhs_transforms, + swizzle=128, + ), + ], + out_specs=plgpu.GPUBlockSpec((64, 128), lambda *i: i), + out_shape=jax.ShapeDtypeStruct((64, 128), jnp.float32), + grid=(1, 1), + )(a, b) + np.testing.assert_allclose( + res, a @ (b.T if rhs_transpose else b), rtol=1e-3 + ) + + def test_input_output_aliases(self): + # Note that we're writing to the input pointer, which should alias b_ptr. + def kernel(a_ref, b_ref): + del b_ref + a_ref[...] = jnp.ones_like(a_ref) + + a = np.zeros((64, 64), dtype=jnp.float32) + b = pl.pallas_call( + kernel, + in_specs=[plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM)], + out_specs=plgpu.GPUBlockSpec(memory_space=plgpu.GPUMemorySpace.GMEM), + input_output_aliases={0: 0}, + out_shape=a, + )(a) + np.testing.assert_array_equal(b, np.ones_like(a)) + + def test_realistic_matmul(self): + dtype = jnp.float16 + swizzle = 128 + elems_128b = swizzle // jnp.dtype(dtype).itemsize + grid_m, grid_k, grid_n = 132, 10, 4 + tile_m = tile_n = 128 + tile_k = elems_128b + m, k, n = grid_m * tile_m, grid_k * tile_k, grid_n * tile_n + def kernel(a_ref, b_ref, o_ref, acc_ref): + plgpu.wgmma(acc_ref, a_ref, b_ref) + plgpu.wgmma_wait(0) # TODO(apaszke): Delay the pipeline to avoid memory races + # TODO(apaszke): Only store in the last step. It doesn't work because we + # don't have partial discharge for control flow. + # is_last_step = pl.program_id(2) == grid_k - 1 + # @pl.when(is_last_step) + # def _epilogue(): + # pl.debug_print("{}", acc_ref[...]) + # TODO(apaszke): This is an untiled store! It's slow!! + o_ref[...] = acc_ref[...] 
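+      # The `acc_ref` scratch persists across the sequential `k` steps of the
+      # grid, so each step's wgmma accumulates into it; the store above is
+      # re-issued on every step and the write from the final `k` step leaves
+      # the finished (tile_m, tile_n) block in `o_ref`.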
+ + key1, key2 = jax.random.split(jax.random.key(42), 2) + a = jax.random.uniform(key1, shape=(m, k), dtype=dtype) + b = jax.random.uniform(key2, shape=(k, n), dtype=dtype) + + res = pl.pallas_call( + kernel, + in_specs=[ + plgpu.GPUBlockSpec( + (tile_m, tile_k), + lambda m, n, k: (m, k), + transforms=plgpu.TilingTransform((64, elems_128b)), + swizzle=128, + ), + plgpu.GPUBlockSpec( + (tile_k, tile_n), + lambda m, n, k: (k, n), + transforms=plgpu.TilingTransform((elems_128b, elems_128b)), + swizzle=128, + ), + ], + out_specs=plgpu.GPUBlockSpec((tile_m, tile_n), lambda m, n, k: (m, n)), + out_shape=jax.ShapeDtypeStruct((m, n), jnp.float32), + scratch_shapes=[plgpu.ACC((tile_m, tile_n), jnp.float32)], + grid=(grid_m, grid_n, grid_k), + compiler_params=plgpu.GPUCompilerParams( + dimension_semantics=["parallel", "parallel", "sequential"], + max_concurrent_steps=2, + ), + )(a, b) + np.testing.assert_allclose(res, a @ b, rtol=1e-3) + if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/ops_test.py b/tests/pallas/ops_test.py index ee0cd3531e90..7bba9f01bec9 100644 --- a/tests/pallas/ops_test.py +++ b/tests/pallas/ops_test.py @@ -14,10 +14,13 @@ """Tests for common JAX operations within pallas_call.""" +from collections.abc import Sequence import contextlib import functools import itertools import sys +from typing import Any +import unittest import numpy as np from absl.testing import absltest @@ -28,6 +31,7 @@ from jax import lax from jax import random from jax._src import config +from jax._src import dtypes from jax._src import linear_util as lu from jax._src import state from jax._src import test_util as jtu @@ -41,10 +45,23 @@ plgpu = None pltpu = None +try: + import hypothesis as hp +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("tests depend on hypothesis library") + +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as hps + # There are many inherited redefinitions of _ # ruff: noqa: F811 jax.config.parse_flags_with_absl() +jtu.setup_hypothesis(max_examples=50) + + +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) def smem_on_tpu(): @@ -54,15 +71,188 @@ def smem_on_tpu(): return None +def _random_value(key: jax.Array, shape_dtype: jax.ShapeDtypeStruct + ) -> jax.Array: + if jnp.issubdtype(shape_dtype.dtype, jnp.floating): + return random.normal(key, shape_dtype.shape, dtype=shape_dtype.dtype) + elif jnp.issubdtype(shape_dtype.dtype, jnp.integer): + return random.randint( + key, shape_dtype.shape, minval=-4, maxval=4, dtype=shape_dtype.dtype + ) + raise NotImplementedError(shape_dtype) + + +_DTYPES = ( + "float32", + "bfloat16", + "int32", + "int16", + "int8", + "bool", +) + + +@hps.composite +def make_shape_dtype_strategy( + draw, *, + min_rank: int, + max_rank: int, + min_size_exp: int, + max_size_exp: int, + valid_dtypes: Sequence[jnp.dtype], + max_bytes: int = 2**16, +) -> jax.ShapeDtypeStruct: + dtype = draw(hps.sampled_from(valid_dtypes)) + # To generate shapes with power-of-two sizes, we draw the exponents of the + # sizes, and then generate the sizes from the exponents. 
+ shape_exponents = tuple( + draw(hps.lists( + hps.integers(min_value=min_size_exp, max_value=max_size_exp), + min_size=min_rank, max_size=max_rank)) + ) + shape = tuple(2**exp for exp in shape_exponents) + size = np.prod(shape) * dtype.itemsize + hp.assume(size <= max_bytes) # Make sure we don't take more than 4K VMEM + return jax.ShapeDtypeStruct(shape, dtype) + + +@hps.composite +def arrays( + draw, shape: tuple[int, ...], dtype: np.dtype, + *, elements: hps.SearchStrategy[Any] | None = None, +) -> np.ndarray: + cast_to_bf16 = False + if dtype == np.dtype(jnp.bfloat16): + dtype = np.dtype('float32') + cast_to_bf16 = True + arr = draw(hnp.arrays(shape=shape, dtype=dtype, elements=elements)) + if cast_to_bf16: + arr = arr.astype(np.dtype(jnp.bfloat16)) + return arr + + +@hps.composite +def select_n_strategy( + draw, *, max_cases: int = 4, + min_rank: int = 0, max_rank: int = 2, + min_size_exp: int = 0, max_size_exp: int = 8, +) -> tuple[np.ndarray, ...]: + n_cases = draw(hps.integers(min_value=1, max_value=max_cases)) + case_shape_dtype = draw( + make_shape_dtype_strategy( + min_rank=min_rank, max_rank=max_rank, + min_size_exp=min_size_exp, max_size_exp=max_size_exp, + valid_dtypes=[ + np.dtype("int32"), + np.dtype("float32"), + # TODO(sharadmv,apaszke): enable bf16 + # np.dtype(jnp.bfloat16), + ], + ) + ) + allowed_elements = hps.integers(min_value=0, max_value=n_cases - 1) + pred_shape = draw(hps.sampled_from([(), case_shape_dtype.shape])) + # TODO(sharadmv,apaszke): enable passing bool arrays into Pallas kernels + if n_cases == 2 and not pred_shape: + pred_dtype = draw(hps.sampled_from([np.dtype(np.bool_), + np.dtype(np.int32)])) + allowed_elements = hps.booleans() + else: + pred_dtype = np.int32 + pred = draw(arrays(shape=pred_shape, dtype=pred_dtype, + elements=allowed_elements)) + cases = ( + draw( + arrays(shape=case_shape_dtype.shape, dtype=case_shape_dtype.dtype) + ) + for _ in range(n_cases) + ) + return pred, *cases + + +UNARY_PRIMITIVES = [ + # TODO(sharadmv,apaszke): enable zero rank + # TODO(sharadmv,apaszke): enable one rank + # TODO(sharadmv,apaszke): enable zero dim sizes + # TODO(sharadmv,apaszke): enable one dim sizes + ( + lax.neg_p, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("float32"), jnp.dtype("int32")], + ), + ), + ( + lax.not_p, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("int32")], + ), + ), + *[ + ( + prim, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("float32")], + ), + ) + for prim in [ + lax.exp_p, + lax.tanh_p, + lax.logistic_p, + lax.rsqrt_p, + lax.log_p, + lax.exp2_p, + lax.abs_p, + lax.log1p_p, + lax.sin_p, + lax.sqrt_p, + ] + ], +] + +UNARY_FUNCTIONS = [ + (prim.name, prim.bind, strategy) for prim, strategy in UNARY_PRIMITIVES +] + [ + ( + name, + func, + make_shape_dtype_strategy( + min_rank=2, + max_rank=3, + min_size_exp=1, + max_size_exp=6, + valid_dtypes=[jnp.dtype("float32")], + ), + ) + for name, func in [ + ("relu", jax.nn.relu), + ("pow2", lambda x: jnp.power(2, x)), + ("square", jnp.square), + ("reciprocal", jnp.reciprocal), + ("round", jnp.round), + ("rint", jnp.rint), + ] +] + + class PallasBaseTest(jtu.JaxTestCase): INTERPRET = False def setUp(self): - if jax.config.x64_enabled: - self.skipTest("Only works in 32-bit") if not self.INTERPRET: if jtu.device_under_test() == "cpu": - self.skipTest("Only interpreter mode 
supported on CPU") + self.skipTest("Only interpret mode supported on CPU") if (jtu.test_device_matches(["cuda"]) and not jtu.is_cuda_compute_capability_at_least("8.0")): self.skipTest("Only works on GPUs with capability >= sm80") @@ -76,11 +266,6 @@ def pallas_call(cls, *args, **kwargs): class OpsTest(PallasBaseTest): - def setUp(self): - super().setUp() - if jax.config.x64_enabled: - self.skipTest("Only works in 32-bit") - @parameterized.named_parameters( (fn.__name__, fn, dtype) for fn, dtype in [ (lax.pow, jnp.float32), @@ -153,7 +338,7 @@ def kernel(x_ref, y_ref, o_ref): result = self.pallas_call( kernel, - out_shape=jax.ShapeDtypeStruct([1, 128], jnp.int32), + out_shape=jax.ShapeDtypeStruct([1, 128], intx), in_specs=[ pl.BlockSpec(memory_space=smem_on_tpu()), pl.BlockSpec(memory_space=smem_on_tpu()), @@ -248,13 +433,15 @@ def kernel(x_ref, ones_ref, o_ref): float_value = jnp.where(reduced_as_bool, 1.0, 0.0) o_ref[0, 0] = float_value[0, 0] - if input_type == 'all_true': + if input_type == "all_true": x = jnp.ones((8, 128), dtype=jnp.float32) - elif input_type == 'all_false': + elif input_type == "all_false": x = jnp.zeros((8, 128), dtype=jnp.float32) - elif input_type == 'one_false': + elif input_type == "one_false": x = jnp.ones((8, 128), dtype=jnp.float32) x = x.at[0, 0].set(0.0) + else: + raise ValueError(f"Unknown input type: {input_type}") ones = jnp.ones_like(x) result = self.pallas_call( @@ -264,7 +451,7 @@ def kernel(x_ref, ones_ref, o_ref): pl.BlockSpec((8, 128), lambda *_: (0, 0)), ], out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), + out_shape=jax.ShapeDtypeStruct([1, 1], floatx), grid=(1,), )(x, ones) np.testing.assert_array_equal(result[0, 0], float(expected_result)) @@ -286,14 +473,288 @@ def kernel(x_ref, o_ref): pl.BlockSpec((8, 128), lambda *_: (0, 0)), ], out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), - out_shape=jax.ShapeDtypeStruct([1, 1], jnp.float32), + out_shape=jax.ShapeDtypeStruct([1, 1], floatx), grid=(1,), )(x) np.testing.assert_allclose(result[0, 0], reduction_op(x), atol=1e-5) + # TODO(sharadmv): test rank < 2, size < 2 + @hp.given(select_n_strategy(max_cases=2, min_rank=2, max_rank=4, + min_size_exp=1)) + def test_select_n(self, args): + if jtu.test_device_matches(["gpu"]): + self.skipTest("TODO: error on GPU, lowering bug for select_n") + pred, *cases = args + scalar_pred = not pred.shape + + def kernel(*refs): + if scalar_pred: + *case_refs, o_ref = refs + pred_ = pred + else: + pred_ref, *case_refs, o_ref = refs + pred_ = pred_ref[...] + vals = [case_ref[...] for case_ref in case_refs] + o_ref[...] = lax.select_n(pred_, *vals) + out_ref = lax.select_n(pred, *cases) + if scalar_pred: + args = cases + else: + args = [pred, *cases] + out = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(out_ref.shape, out_ref.dtype), + )(*args) + if out.dtype == jnp.bfloat16: + out, out_ref = out.astype(jnp.float32), out_ref.astype(jnp.float32) + np.testing.assert_allclose(out, out_ref) + + @parameterized.named_parameters( + (name, name, func, strategy) + for name, func, strategy in UNARY_FUNCTIONS + ) + @hp.given(hps.data()) + def test_unary_primitives(self, name, func, shape_dtype_strategy, data): + if self.INTERPRET: + self.skipTest("This hypothesis test is slow, even more so in interpret mode.") + # We want exact equality here to match how JAX lowers to XLA + tol = 0. 
+ if jtu.test_device_matches(["gpu"]): + if func == jnp.round or func == jnp.rint: + self.skipTest("TODO: not implemented on GPU") + if name == "tanh": + tol = 1e-6 + elif name == "exp2": + tol = 1e-6 + elif jtu.test_device_matches(["tpu"]): + if not jtu.is_device_tpu_at_least(version=5) and False: + self.skipTest("TODO: not implemented on TPU v{3,4}") + + def kernel(x_ref, y_ref): + y_ref[...] = func(x_ref[...]) + x_shape_dtype = data.draw(shape_dtype_strategy) + key = random.key(0) + x = _random_value(key, x_shape_dtype) + out = self.pallas_call(kernel, out_shape=x_shape_dtype)(x) + self.assertAllClose(out, func(x), atol=tol, rtol=tol) + + @parameterized.product(from_dtype=_DTYPES, to_dtype=_DTYPES) + @hp.given(hps.data()) + def test_cast(self, from_dtype, to_dtype, data): + if from_dtype == to_dtype: + self.skipTest("Unnecessary test") + if jtu.is_device_tpu(version=4): + if from_dtype in {"int16", "int8"} or to_dtype in {"int16", "int8"}: + self.skipTest( + "Not supported: TPU generation doesn't support this cast." + ) + if jtu.test_device_matches(["tpu"]) and jtu.get_tpu_version() < 4: + if from_dtype in {"int32", "float32", "bfloat16"} and to_dtype in {"int16", "int8"}: + self.skipTest( + "Not supported: TPU generation doesn't support this cast." + ) + + # TODO(sharadmv,apaszke): add support for the following casts + if from_dtype == "int16" and to_dtype == "int8": + self.skipTest("Not supported: bad canonicalization") + if from_dtype == "int8" and to_dtype == "int16": + self.skipTest("Not supported: bad canonicalization") + if from_dtype == "bool" and to_dtype in {"int16", "int8"}: + self.skipTest("Not supported: cannot extend to sub-32 bit types") + if from_dtype in {"bfloat16", "float32"} and to_dtype == "bool": + self.skipTest("Not supported: unsupported relayout") + if from_dtype == "bool" and to_dtype in {"int32", "bfloat16", "float32"}: + self.skipTest("Not supported: unsupported relayout") + if from_dtype in {"int16", "int8"} and to_dtype == "bool": + self.skipTest("Not supported: cannot truncate from sub-32 bit types") + if jtu.test_device_matches(["gpu"]): + if (from_dtype in {"bfloat16", "float32"} and + to_dtype in {"int8", "int16", "int32"}): + self.skipTest("TODO: wrong result on GPU") + + if from_dtype == "bfloat16": + from_dtype = jnp.bfloat16 + if to_dtype == "bfloat16": + to_dtype = jnp.bfloat16 + + if from_dtype == jnp.bfloat16: + x = jnp.asarray(data.draw(hnp.arrays(jnp.float32, (8, 128)))) + x = x.astype(jnp.bfloat16) + else: + x = data.draw(hnp.arrays(from_dtype, (8, 128))) + x = jnp.asarray(x) + if from_dtype == jnp.dtype("bool"): + x = x.astype(jnp.int32) + def kernel(x_ref, y_ref): + x = x_ref[...] + if from_dtype == jnp.dtype("bool"): + x = x.astype(jnp.dtype("bool")) + y = x.astype(to_dtype) + if to_dtype == jnp.dtype("bool"): + y = y.astype(jnp.int32) + y_ref[...] = y + if (y_dtype := to_dtype) == jnp.dtype("bool"): + y_dtype = jnp.int32 + y = self.pallas_call( + kernel, out_shape=jax.ShapeDtypeStruct(x.shape, y_dtype))(x) + if to_dtype == jnp.dtype("bool"): + y = y.astype(jnp.dtype("bool")) + y_ref = x.astype(to_dtype) + if to_dtype == jnp.bfloat16: + y, y_ref = y.astype(np.float32), y_ref.astype(np.float32) + np.testing.assert_allclose(y, y_ref, atol=0., rtol=0.)
+ + @parameterized.parameters( + jnp.bfloat16, + jnp.float8_e5m2, + jnp.float8_e4m3fn, + ) + @jtu.skip_on_devices("gpu") + def test_scalar_downcast_float32(self, dtype): + + def kernel(x_ref, o_ref): + o_ref[0, 0] = x_ref[:][0, 0].astype(dtype) + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + result = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec((8, 128), lambda *_: (0, 0)), + ], + out_specs=pl.BlockSpec((1, 1), memory_space=smem_on_tpu()), + out_shape=jax.ShapeDtypeStruct([1, 1], dtype), + grid=(1,), + )(x) + + np.testing.assert_array_equal(result[0, 0], x[0, 0].astype(dtype)) + + @parameterized.product( + shape=((64,), (8, 8)), + dtype=(jnp.int32, jnp.int16, jnp.int8), + ) + def test_scalar_map(self, shape, dtype): + if pltpu is None: + self.skipTest("No TPU module available.") + if dtype != jnp.int32 and len(shape) < 2: + # TODO(b/299280718): Implement this. + self.skipTest( + "Loads and stores not implemented for 1D arrays of non-32bit types" + ) + def kernel(x_ref, y_ref): + for idx in np.ndindex(shape): + x = x_ref[idx].astype(jnp.int32) + y_ref[idx] = (x * x).astype(y_ref.dtype) + f = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.SMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + out_shape=jax.ShapeDtypeStruct(shape, dtype), + ) + x = jnp.arange(np.prod(shape), dtype=dtype).reshape(shape) + self.assertAllClose(f(x), x * x) + + @jtu.skip_on_devices("gpu") # TODO: not implemented + def test_extract_scalar(self): + if pltpu is None: + self.skipTest("No TPU module available.") + def kernel(x_ref, y_ref): + y_ref[0, 0] = x_ref[:][0, 0] + f = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((1, 1), jnp.float32), + out_specs=pl.BlockSpec(memory_space=pltpu.SMEM), + ) + x = np.arange(1024, dtype=jnp.float32).reshape(8, 128) + 10 + self.assertAllClose(f(x).item(), 10.0) + + @jtu.skip_on_devices("gpu") # TODO: not implemented + def test_concat_constant(self): + if pltpu is None: + self.skipTest("No TPU module available.") + def kernel(out): + result = [] + for i in range(16): + result.append(jnp.full((1, 128), i, jnp.float32)) + out[:] = jnp.stack(result).reshape(16, 128) + + def run(interpret=False): + return pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + interpret=interpret, + )() + expected = run(True) + if not self.INTERPRET: + actual = run(False) + self.assertAllClose(actual, expected) + + @parameterized.named_parameters( + (f"{dtype.__name__}_{value}", dtype, value) + for dtypes, values in ( + ((jnp.uint16, jnp.uint32, jnp.uint64), (0, 5)), + ((jnp.int16, jnp.int32, jnp.int64), (-3, 0, 5)), + ( + (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64), + (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf), + ), + ) + for dtype in dtypes + for value in values + ) + def test_sign(self, dtype, value): + if ( + not jax.config.x64_enabled + and dtype in (jnp.uint64, jnp.int64, jnp.float64) + ): + self.skipTest("64-bit types require x64_enabled") + + if ( + jtu.test_device_matches(["tpu"]) + and dtype in (jnp.uint16, jnp.int16, jnp.bfloat16, jnp.float16) + ): + self.skipTest("16-bit types are not supported on TPU") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), dtype), + ) + def kernel(x_ref, o_ref): + o_ref[...] 
= jnp.sign(x_ref[...]) + + x = jnp.full((8, 128,), value, dtype=dtype) + out = kernel(x) + expected = jnp.sign(x) + + # `.astype(jnp.float32)` is a workaround for dtype=bfloat16 and value=nan, + # see https://github.com/jax-ml/ml_dtypes/issues/206 + np.testing.assert_array_equal( + out.astype(jnp.float32), + expected.astype(jnp.float32), + ) + + @parameterized.parameters( + -3.2, -1.0, -0.999517, -0.4, 0., 0.72, 0.999517, 1.0, 2.4, + ) + def test_erf_inv(self, value): + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((8, 128), floatx), + ) + def kernel(x_ref, o_ref): + o_ref[...] = lax.erf_inv(x_ref[...]) + + x = jnp.full((8, 128), value, dtype=floatx) + out = kernel(x) + expected = lax.erf_inv(x) + np.testing.assert_array_equal(out, expected) + -class OpsInterpreterTest(OpsTest): +class OpsInterpretTest(OpsTest): INTERPRET = True def test_debug_print(self): @@ -320,7 +781,7 @@ class OpsExtraTest(PallasBaseTest): def setUp(self): super().setUp() if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - # TODO: most tests fail on TPU in non-interpreter mode + # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") ELEMENTWISE_OPS = [ @@ -362,6 +823,17 @@ def kernel(x_ref, o_ref): x = jnp.array([0.42, 2.4]).astype(dtype) np.testing.assert_allclose(kernel(x), fn(x), rtol=1e-6) + def test_abs_weak_type(self): + # see https://github.com/jax-ml/jax/issues/23191 + @functools.partial( + self.pallas_call, out_shape=jax.ShapeDtypeStruct((4, 4), floatx), + ) + def kernel(x_ref, o_ref): + o_ref[...] = jnp.abs(x_ref[...]) + + x = jnp.broadcast_to(-3.2, (4, 4)) # sets `weak_type` to `True` + np.testing.assert_allclose(kernel(x), jnp.abs(x), rtol=1e-6) + @parameterized.parameters( ("float32", "int32"), ("float64", "int32"), @@ -470,7 +942,7 @@ def kernel(x_ref, y_ref, o_ref): @parameterized.parameters("float16", "bfloat16") def test_true_divide_unsupported(self, dtype): if self.INTERPRET: - self.skipTest("No lowering in interpreter mode") + self.skipTest("No lowering in interpret mode") @functools.partial( self.pallas_call, @@ -540,7 +1012,7 @@ def kernel(o_ref): @parameterized.parameters("float16", "bfloat16", "float32") def test_approx_tanh(self, dtype): if self.INTERPRET: - self.skipTest("approx_tanh is not supported in interpreter mode") + self.skipTest("approx_tanh is not supported in interpret mode") if (dtype == "bfloat16" and not jtu.is_cuda_compute_capability_at_least("9.0")): self.skipTest("tanh.approx.bf16 requires a GPU with capability >= sm90") @@ -553,7 +1025,7 @@ def kernel(x_ref, o_ref): x = jnp.asarray([-1, 0.42, 0.24, 1]).astype(dtype) # We upcast to float32 because NumPy <2.0 does not handle custom dtypes - # properly. See https://github.com/google/jax/issues/11014. + # properly. See https://github.com/jax-ml/jax/issues/11014. 
np.testing.assert_allclose( kernel(x).astype(jnp.float32), jnp.tanh(x).astype(jnp.float32), @@ -564,7 +1036,7 @@ def kernel(x_ref, o_ref): def test_elementwise_inline_asm(self): if self.INTERPRET: self.skipTest( - "elementwise_inline_asm is not supported in interpreter mode" + "elementwise_inline_asm is not supported in interpret mode" ) @functools.partial( @@ -584,6 +1056,26 @@ def kernel(x_ref, o_ref): x = jnp.arange(256).astype(jnp.float16) np.testing.assert_allclose(kernel(x), jnp.tanh(x), atol=5e-3, rtol=5e-3) + def test_debug_barrier(self): + if self.INTERPRET: + self.skipTest("debug_barrier is not supported in interpret mode") + + @functools.partial( + self.pallas_call, + out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), + grid=1, + ) + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + plgpu.debug_barrier() + + x = jnp.array([4.2, 2.4]).astype(jnp.float32) + np.testing.assert_array_equal(kernel(x), x) + + @unittest.skipIf( + sys.platform == "win32", + "plgpu.TritonCompilerParams unavailable on Windows", + ) def test_debug_print(self): # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): @@ -592,7 +1084,7 @@ def test_debug_print(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), grid=1, - compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) ) def kernel(x_ref, o_ref): pl.debug_print("It works!") @@ -604,6 +1096,10 @@ def kernel(x_ref, o_ref): self.assertIn("It works!", output()) + @unittest.skipIf( + sys.platform == "win32", + "plgpu.TritonCompilerParams unavailable on Windows", + ) def test_debug_print_with_values(self): # TODO: this test flakes on gpu if jtu.test_device_matches(["gpu"]): @@ -612,7 +1108,7 @@ def test_debug_print_with_values(self): self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.float32), grid=1, - compiler_params=dict(triton=dict(num_warps=1, num_stages=1)) + compiler_params=plgpu.TritonCompilerParams(num_warps=1, num_stages=1) ) def kernel(x_ref, o_ref): pl.debug_print("x[0] =", x_ref[0]) @@ -675,20 +1171,20 @@ def f(x_ref, o_ref): def test_num_programs(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4,), jnp.int32), + out_shape=jax.ShapeDtypeStruct((4,), intx), grid=4, ) def kernel(o_ref): o_ref[pl.program_id(0)] = pl.num_programs(0) np.testing.assert_array_equal( - kernel(), np.asarray([4, 4, 4, 4], dtype=np.int32) + kernel(), jnp.array([4, 4, 4, 4], dtype=intx) ) def test_where_broadcasting(self): @functools.partial( self.pallas_call, - out_shape=jax.ShapeDtypeStruct((4, 2, 2), jnp.float32), + out_shape=jax.ShapeDtypeStruct((4, 2, 2), floatx), grid=1, ) def copyitem(x_ref, in_idx_ref, out_idx_ref, o_ref): @@ -755,11 +1251,12 @@ def dot(x_ref, y_ref, o_ref): def test_masked_load_store(self, size, block_size): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((size,), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((size,), floatx)), grid=pl.cdiv(size, block_size), ) def kernel(x_ref, o_ref): - idx = pl.program_id(0) * block_size + jnp.arange(block_size) + idx = pl.program_id(0) * block_size + jnp.arange( + block_size, dtype=jnp.int32) mask = idx < x_ref.shape[0] x = pl.load(x_ref, (idx,), mask=mask) pl.store(o_ref, (idx,), x + 1.0, mask=mask) @@ -773,7 +1270,7 @@ def test_masked_oob_load_store_slice(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((n,), floatx)), grid=1, ) def 
masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): @@ -790,11 +1287,7 @@ def masked_oob_load_store_slice(x_ref, mask_ref, start_idx_ref, o_ref): np.testing.assert_array_equal(out, o_new) def test_strided_load(self): - if self.INTERPRET: - # TODO(b/329733289): Remove this once the bug is fixed. - self.skipTest("Strided load not yet supported in interpreter mode") - - # Reproducer from https://github.com/google/jax/issues/20895. + # Reproducer from https://github.com/jax-ml/jax/issues/20895. @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((4,), jnp.float32), @@ -810,7 +1303,7 @@ def test_broadcasted_load_store(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((m, n), floatx)), grid=1, ) def load(x_ref, o_ref): @@ -828,7 +1321,7 @@ def load(x_ref, o_ref): ) def test_invalid_broadcasted_load(self, x_shape, mask_shape): if self.INTERPRET: - self.skipTest("No broadcasting checks in pl.load in interpreter mode") + self.skipTest("No broadcasting checks in pl.load in interpret mode") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((), jnp.float32) @@ -853,7 +1346,7 @@ def test_swap(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2, grid=1, input_output_aliases={0: 0, 1: 1}, ) @@ -873,7 +1366,7 @@ def test_masked_swap(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((m, n), jnp.float32),) * 2, + out_shape=(jax.ShapeDtypeStruct((m, n), floatx),) * 2, grid=1, input_output_aliases={0: 0, 1: 1}, ) @@ -894,8 +1387,8 @@ def test_masked_oob_swap_slice(self): @functools.partial( self.pallas_call, - out_shape=(jax.ShapeDtypeStruct((n,), jnp.float32), - jax.ShapeDtypeStruct((m,), jnp.float32)), + out_shape=(jax.ShapeDtypeStruct((n,), floatx), + jax.ShapeDtypeStruct((m,), floatx)), grid=1, input_output_aliases={0: 0, 1: 1}, ) @@ -964,7 +1457,7 @@ def test_array_atomic_add(self, axis): grid = m else: grid = n - out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), jnp.float32) + out_shape = jax.ShapeDtypeStruct((n if axis == 0 else m,), floatx) @functools.partial( self.pallas_call, @@ -996,10 +1489,13 @@ def reduce(x_ref, _, y_ref): (2, 1, 1), ) def test_atomic_cas(self, init_value, cmp, new_value): + if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU in 64-bit mode") + @functools.partial( self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), + jax.ShapeDtypeStruct((), intx), + jax.ShapeDtypeStruct((), intx)), input_output_aliases={0: 0}) def swap(_, lock_ref, out_ref): out_ref[()] = pl.atomic_cas(lock_ref, cmp, new_value) @@ -1012,12 +1508,15 @@ def swap(_, lock_ref, out_ref): @parameterized.parameters(1, 2, 3, 4, 8) def test_atomic_counter(self, num_threads): if self.INTERPRET: - self.skipTest("While loop not supported in interpreter mode.") + self.skipTest("While loop not supported in interpret mode.") + + if jax.config.x64_enabled and jtu.test_device_matches(["gpu"]): + self.skipTest("Not supported on GPU in 64-bit mode") @functools.partial( self.pallas_call, out_shape=( - jax.ShapeDtypeStruct((), jnp.int32), - jax.ShapeDtypeStruct((), jnp.int32)), + jax.ShapeDtypeStruct((), intx), + jax.ShapeDtypeStruct((), intx)), input_output_aliases={0: 0, 1: 1}, grid=(num_threads,)) def increment(_, __, lock_ref, counter_ref): 
@@ -1062,14 +1561,31 @@ def reduce(x_ref, y_ref): ("argmin", jnp.argmin), ] for axis in [0, 1, (1,), (0, 1)] - for dtype in ["float16", "float32", "int32", "uint32"] + for dtype in [ + "float16", + "float32", + "float64", + "int32", + "int64", + "uint32", + "uint64", + ] if isinstance(axis, int) or "arg" not in op_name ]) def test_array_reduce(self, op, dtype, axis): m, n = 32, 8 - out_dtype = dtype - if op in {jnp.argmin, jnp.argmax}: - out_dtype = jnp.int32 + + if not jax.config.x64_enabled and dtype in ("float64", "int64", "uint64"): + self.skipTest("64-bit types require x64_enabled") + + # Skip argmin/argmax on GPU in 64-bit mode because Pallas expects + # `index_type` to be i32 + if ( + jax.config.x64_enabled + and jtu.test_device_matches(["gpu"]) + and op in {jnp.argmin, jnp.argmax} + ): + self.skipTest("Not supported on GPU in 64-bit mode") def make_x(key): if jnp.issubdtype(dtype, jnp.integer): @@ -1079,9 +1595,10 @@ def make_x(key): else: return random.normal(key, (m, n), dtype=dtype) + # deduce `out_dtype` by executing the op on a single element + out_dtype = op(jnp.arange(1, dtype=dtype)).dtype out_shape = jax.ShapeDtypeStruct( - op(make_x(random.key(0)), axis=axis).shape, out_dtype - ) + op(make_x(random.key(0)), axis=axis).shape, out_dtype) if isinstance(axis, int): grid = tuple(a for i, a in enumerate((m, n)) if i != axis) else: @@ -1089,9 +1606,11 @@ def make_x(key): @functools.partial(self.pallas_call, out_shape=out_shape, grid=grid) def reduce(x_ref, y_ref): - x = pl.load(x_ref, (jnp.arange(m)[:, None], jnp.arange(n)[None])) + x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], + jnp.arange(n, dtype=jnp.int32)[None])) y = op(x, axis=axis) - pl.store(y_ref, tuple(jnp.arange(d) for d in y.shape), y) + pl.store(y_ref, + tuple(jnp.arange(d, dtype=jnp.int32) for d in y.shape), y) for i, key in enumerate(random.split(random.key(0), 20)): x = make_x(key) @@ -1130,7 +1649,7 @@ def reduce(x_ref, y_ref): np.testing.assert_allclose(y, y_ref, atol=1e-2, rtol=1e-2, err_msg=i) -class OpsExtraInterpreterTest(OpsTest): +class OpsExtraInterpretTest(OpsExtraTest): INTERPRET = True @@ -1187,59 +1706,9 @@ def body(x_ref): self.assertIn(expected, jaxpr.pretty_print(use_color=False)) -class PallasPrimitivesInterpreterTest(PallasPrimitivesTest): +class PallasPrimitivesInterpretTest(PallasPrimitivesTest): INTERPRET = True -class TpuOpsTest(PallasBaseTest): - - def setUp(self): - if not jtu.test_device_matches(["tpu"]): - self.skipTest("Test requires TPU device.") - - super().setUp() - - @parameterized.parameters([-3.2, -1.0, -0.4, 0., 0.72, 1.0, 2.4]) - def test_erf_inv(self, x): - @jax.jit - @functools.partial( - pl.pallas_call, - # TODO(ayx): add float64 support for `erf_inv` - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), - ) - def kernel(x_ref, o_ref): - o_ref[...] = lax.erf_inv(x_ref[...]) - - x = jnp.full((8, 128), x) - out = kernel(x) - expected = lax.erf_inv(x) - np.testing.assert_array_equal(out, expected) - - SIGN_PARAMS = [ - (jnp.int32, (-3, 0, 5)), - (jnp.uint32, (0, 5)), - (jnp.float32, (-3.2, -0., 0., 5.1, jnp.nan, jnp.inf, -jnp.inf)), - ] - - @parameterized.named_parameters( - (f"{dtype.__name__}_{value}", dtype, value) - for dtype, values in SIGN_PARAMS - for value in values - ) - def test_sign(self, dtype, value): - @jax.jit - @functools.partial( - pl.pallas_call, - out_shape=jax.ShapeDtypeStruct((8, 128), dtype), - ) - def kernel(x_ref, o_ref): - o_ref[...]
= jnp.sign(x_ref[...]) - - x = jnp.full((8, 128,), value, dtype=dtype) - out = kernel(x) - expected = jnp.sign(x) - np.testing.assert_array_equal(out, expected) - - if __name__ == "__main__": absltest.main() diff --git a/tests/pallas/pallas_error_handling_test.py b/tests/pallas/pallas_error_handling_test.py new file mode 100644 index 000000000000..34b0ff1492a4 --- /dev/null +++ b/tests/pallas/pallas_error_handling_test.py @@ -0,0 +1,142 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for Pallas error handling.""" +import functools +import traceback + +from absl.testing import absltest +import jax +from jax import numpy as jnp +from jax._src import config +from jax._src import test_util as jtu +from jax._src.pallas.mosaic import error_handling +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu + + +config.parse_flags_with_absl() + +LOCATION_TEST_STRING = ( + r'loc("/squeeze"' + r'(callsite("foo_fn"("third_party/foo.py":104:22) at ' + r'callsite("bar_fn"("third_party/bar.py":115:6) at ' + r'""("third_party/pallas_error_handling_test.py":181:2' + r")))))" +) + + +class PallasErrorHandlingTest(jtu.JaxTestCase): + + def setUp(self): + super().setUp() + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test only works on TPU.") + + def test_non_singular_stride(self): + input_arr = jax.random.uniform( + jax.random.key(0), (8, 128), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((8, 16), jnp.float32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ) + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) + def test_kernel(input_ref, output_ref): + x = input_ref[:, ::8] + output_ref[...] = x + + # Test that a Mosaic error is raised. This assert is a guard against + # underlying changes in Mosaic. + # If this is fixed in future Mosaic releases we will need to change + # the test example to force a different error. + with self.assertRaisesRegex( + error_handling.MosaicError, + "Not Implemented: Stride on last dim is not 1", + ): + test_kernel(input_arr) + + # Test that the python source is the final frame in the traceback. 
+ tb_string = "" + try: + test_kernel(input_arr) + except error_handling.MosaicError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") + + @jax.jit + def kernel_in_jitted_fn(x): + return test_kernel(x) + + with self.subTest("inside_jitted_fn"): + tb_string = "" + try: + kernel_in_jitted_fn(input_arr) + except error_handling.MosaicError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + self.assertEndsWith(tb_string, "x = input_ref[:, ::8]\n") + + def test_invalid_smem_vmem_verification_error(self): + input_arr = jax.random.uniform(jax.random.key(0), (2, 2), dtype=jnp.float32) + out_shape = jax.ShapeDtypeStruct((1, 1), jnp.float32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), + ) + + @functools.partial(pl.pallas_call, out_shape=out_shape, grid_spec=grid_spec) + def test_kernel(input_ref, output_ref): + output_ref[0, 0] = input_ref[0, 0] + + # Test that a verification error is raised. This assert is a guard against + # underlying changes in Pallas lowering. + # If this is fixed in future Pallas releases we will need to change + # the test example to force a different error. + with self.assertRaisesRegex( + error_handling.VerificationError, + "'memref.store' op failed to verify that type of 'value' matches " + "element type of 'memref'", + ): + test_kernel(input_arr) + + # Test that the python source is the final frame in the traceback. + tb_string = "" + try: + test_kernel(input_arr) + except error_handling.MosaicError as e: + tb_string = traceback.format_tb(e.__traceback__) + tb_string = "".join(tb_string) + self.assertEndsWith(tb_string, "output_ref[0, 0] = input_ref[0, 0]\n") + + def test_parse_location_string(self): + name, frames = error_handling.parse_location_string(LOCATION_TEST_STRING) + self.assertEqual(name, "/squeeze") + self.assertLen(frames, 3) + self.assertEqual(frames[0].func_name, "foo_fn") + self.assertEqual(frames[0].filename, "third_party/foo.py") + self.assertEqual(frames[0].lineno, 104) + self.assertEqual(frames[0].colno, 22) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/pallas_jumble_test.py b/tests/pallas/pallas_jumble_test.py new file mode 100644 index 000000000000..5ed15fe964dd --- /dev/null +++ b/tests/pallas/pallas_jumble_test.py @@ -0,0 +1,201 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sys + +os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5" + +from absl.testing import absltest +import jax +from jax import lax +from jax._src import config +from jax._src import core +from jax._src import dtypes +from jax._src import test_util as jtu +from jax._src.interpreters import batching +from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + + +# TODO(mvoz): Update signatures of pallas_call to correct inputs/outputs. +# pylint: disable=no-value-for-parameter + +config.parse_flags_with_absl() + + +intx = dtypes.canonicalize_dtype(jnp.int64) +floatx = dtypes.canonicalize_dtype(jnp.float64) + + +@jtu.with_config(jax_traceback_filtering="off") +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if jtu.test_device_matches(["cpu"]) and not self.INTERPRET: + self.skipTest("On CPU the test works only in interpret mode") + if jtu.test_device_matches( + ["cuda"] + ) and not jtu.is_cuda_compute_capability_at_least("8.0"): + self.skipTest("Only works on GPU with capability >= sm80") + if sys.platform == "win32" and not self.INTERPRET: + self.skipTest("Only works on non-Windows platforms") + + super().setUp() + _trace_kernel_to_jaxpr.cache_clear() + + def pallas_call(self, *args, **kwargs): + return pl.pallas_call(*args, **kwargs, interpret=self.INTERPRET) + + +@jtu.with_config(jax_dynamic_shapes=True, jax_numpy_dtype_promotion="standard") +class PallasCallRaggedVmapTest(PallasBaseTest): + + def test_vmap_jumble_over_sin_kernel(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct( + (8, col_grid_size * 128), dtype=jnp.float32 + ), + grid=(1, col_grid_size), + interpret=self.INTERPRET, + # See note - on zero filling counterfactuals + debug=True, + )(x) + + res = jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + res = res.data + total = len(ragged_shape) * row_count * col_grid_size * 128 + res_total = np.prod(res.shape) + self.assertEqual(res_total, total) + ragged_total = 0 + for dim in ragged_shape: + ragged_total += row_count * dim * 128 + # See note - on zero filling counterfactuals + self.assertEqual(np.count_nonzero(res == jnp.sin(1.0)), ragged_total) + + def test_vmap_jumble_over_sin_kernel_grid_remapping(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([128 * x for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] 
= jnp.sin(x_ref[...]) * pl.program_id(2) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex(ValueError, "Axis 2 is out of bounds for grid"): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + def test_vmap_jumble_ragged_boundary_unaligned_with_grid(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Only tested on TPU") + + self.skipTest("Checkify NYI") + + row_count = 8 + col_grid_size = 5 + ragged_shape = [3, 1, 4] + sizes = lax.convert_element_type( + jnp.array([(128 * x) - 1 for x in ragged_shape]), + core.bint(col_grid_size * 128), + ) + x = jax.vmap( + lambda n: jnp.ones((row_count, n)), out_axes=batching.jumble_axis + )(sizes) + + def kernel(x_ref, o_ref): + o_ref[...] = jnp.sin(x_ref[...]) + + def invoke_kernel(x): + return pl.pallas_call( + kernel, + in_specs=[pl.BlockSpec((8, 128), lambda j, k: (j, k))], + out_specs=pl.BlockSpec((8, 128), lambda j, k: (j, k)), + out_shape=jax.ShapeDtypeStruct((8, 640), dtype=jnp.float32), + grid=(1, 5), + interpret=False, + )(x) + + with self.assertRaisesRegex( + ValueError, + "Ragged input shape must be evenly divisble by the grid" # noqa: W605 + " size at the ragged dimension 2", + ): + jax.vmap( + invoke_kernel, + out_axes=batching.jumble_axis, + in_axes=batching.jumble_axis, + axis_size=3, + )(x) + + +class PallasCallNamedGridInterpretTest(PallasCallRaggedVmapTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 63779319b0b9..6df31b55f8e7 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -33,7 +33,6 @@ from jax._src import dtypes from jax._src import test_util as jtu from jax._src.lax.control_flow.for_loop import for_loop -from jax._src.lib import version as jaxlib_version from jax._src.pallas import core as pallas_core from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr from jax.experimental import pallas as pl @@ -371,17 +370,10 @@ def copy_kernel(x_ref, o_ref): test_context = contextlib.nullcontext() if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: - if jaxlib_version < (0, 4, 32): - # TODO(b/356116061): Remove the old rank condition - if rank < 2: - test_context = self.assertRaisesRegex( - ValueError, - "TPU lowering currently supports only blocks of rank >= 2") - else: - if rank < 1: - test_context = self.assertRaisesRegex( - ValueError, - "TPU lowering currently supports only blocks of rank >= 1") + if rank < 1: + test_context = self.assertRaisesRegex( + ValueError, + "TPU lowering currently supports only blocks of rank >= 1") if rank >= 1: bs0, as0 = block_shape[-1], shape[-1] @@ -492,7 +484,7 @@ def kernel(o_ref): self.assertAllClose(pids[0:4], np.array([0] * 4, dtype=np.int32)) def test_hoisted_consts(self): - # See https://github.com/google/jax/issues/21557. + # See https://github.com/jax-ml/jax/issues/21557. # to_store will be hoisted as a constant. Choose distinct shapes from in/outs. 
to_store = np.arange(128, dtype=np.float32).reshape((1, 128)) x = np.arange(16 * 128, dtype=np.float32).reshape((16, 128)) @@ -687,7 +679,147 @@ def f(x): self.assertEqual(trace_count, 1) -class PallasCallInterpreterTest(PallasCallTest): +class PallasCallInterpretTest(PallasCallTest): + INTERPRET = True + + +class PallasCallUnblockedIndexingTest(PallasBaseTest): + + def test_block_spec_unblocked(self): + def show_program_ids( + *, shape, block_shape, grid, indexing_mode: pl.IndexingMode + ): + def kernel(o1_ref): + assert o1_ref.shape == block_shape + o1_ref[...] = jnp.full(o1_ref.shape, pl.program_id(0)) + + return self.pallas_call( + kernel, + jax.ShapeDtypeStruct(shape, dtype=np.int32), + grid=grid, + out_specs=pl.BlockSpec( + block_shape, lambda i: (8 * i, 0), indexing_mode=indexing_mode + ), + )() + + # No padding + pids = show_program_ids( + shape=(16, 128), + block_shape=(8, 128), + grid=(2,), + indexing_mode=pl.Unblocked(), + ) + expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 8, dtype=np.int32) + self.assertAllClose(pids, expected_pids) + + if jtu.test_device_matches(["gpu"]) and not self.INTERPRET: + self.skipTest("TODO: padding not implemented on GPU yet") + + # Only high padding + pids = show_program_ids( + shape=(14, 128), + block_shape=(8, 128), + grid=(2,), + indexing_mode=pl.Unblocked(((0, 2), (0, 0))), + ) + expected_pids = np.array([[0] * 128] * 8 + [[1] * 128] * 6, dtype=np.int32) + self.assertAllClose(pids, expected_pids) + + # Both low and high padding + self.skipTest("TODO: low padding not supported yet") + pids = show_program_ids( + shape=(11, 128), + block_shape=(8, 128), + grid=(2,), + indexing_mode=pl.Unblocked(((3, 2), (0, 0))), + ) + expected_pids = np.array([[0] * 128] * 5 + [[1] * 128] * 6, dtype=np.int32) + self.assertAllClose(pids, expected_pids) + + @parameterized.parameters("int32", "float32") + def test_block_spec_unblocked_padding_is_nan(self, dtype_name): + if not self.INTERPRET: + self.skipTest("Only applicable for the interpret mode") + + dtype = np.dtype(dtype_name) + + def copy_kernel(x_ref, o_ref): + o_ref[...] = x_ref[...] + + res = self.pallas_call( + copy_kernel, + jax.ShapeDtypeStruct((6,), dtype=dtype), + grid=(1,), + in_specs=[ + pl.BlockSpec( + (6,), lambda i: 0, indexing_mode=pl.Unblocked(((1, 2),)) + ) + ], + )(np.full((3,), 42, dtype=dtype)) + expected_pad = {"int32": jnp.iinfo(np.int32).min, "float32": np.nan}[ + dtype_name + ] + self.assertAllClose( + res, + np.array( + [expected_pad, 42, 42, 42, expected_pad, expected_pad], dtype=dtype + ), + ) + + def test_unblocked_indexing(self): + shape = (16 * 8, 128) + result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) + + def kernel(x_ref, o_ref): + o_ref[...] = x_ref[pl.ds(0, 8)] + x_ref[pl.ds(8, 8)] + + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + y = self.pallas_call( + kernel, + grid=(15,), + in_specs=( + pl.BlockSpec( + (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked + ), + ), + out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), + out_shape=result_ty, + )(x) + ref = [] + for i in range(15): + block = x[i * 8 : i * 8 + 2 * 8] + ref.append(block[0:8] + block[8:16]) + ref = np.concatenate(ref, axis=0) + np.testing.assert_array_equal(y, ref) + + def test_unblocked_indexing_with_padding(self): + if jtu.test_device_matches(["gpu"]) and not self.INTERPRET: + self.skipTest("TODO: padding not implemented on GPU yet") + + shape = (8, 128) + result_ty = jax.ShapeDtypeStruct((8, 128), jnp.float32) + + def kernel(x_ref, y_ref): + y_ref[...] 
= x_ref[pl.ds(0, 8)] + + x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) + y = self.pallas_call( + kernel, + grid=(1,), + in_specs=( + pl.BlockSpec( + (2 * 8, 128), + lambda i: (0, 0), + indexing_mode=pl.Unblocked(((0, 8), (0, 0))), + ), + ), + out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), + out_shape=result_ty, + )(x) + np.testing.assert_array_equal(y, x) + + +class PallasCallUnblockedIndexingInterpretTest(PallasCallUnblockedIndexingTest): INTERPRET = True @@ -768,7 +900,7 @@ def my_index_map(): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - "Index map function my_index_map at .*/pallas_test.py:.* for " + "Index map function my_index_map at .*pallas_test.py:.* for " "x_ref must return 1 values to match .*" "Currently returning 2 values."): f(a) @@ -783,7 +915,7 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - "Index map function my_index_map at .*/pallas_test.py:.* for " + "Index map function my_index_map at .*pallas_test.py:.* for " "x_ref must return integer scalars. Output\\[0\\] has " "type .*float"): f(a) @@ -798,7 +930,7 @@ def my_index_map(i): in_specs=[pl.BlockSpec((4,), my_index_map)]) with self.assertRaisesRegex( ValueError, - "Index map function my_index_map at .*/pallas_test.py:.* for " + "Index map function my_index_map at .*pallas_test.py:.* for " "x_ref must return integer scalars. Output\\[0\\] has " "type .*int32\\[4\\]"): f(a) @@ -921,7 +1053,7 @@ def the_kernel(): return None self.assertEqual("", ns5.src_info) -class ApiErrorInterpreterTest(ApiErrorTest): +class ApiErrorInterpretTest(ApiErrorTest): INTERPRET = True @@ -957,7 +1089,7 @@ def f(x): self.assertEqual(mem_analysis.temp_size_in_bytes, 0) -class PallasCallInputOutputAliasingInterpreterTest(PallasBaseTest): +class PallasCallInputOutputAliasingInterpretTest(PallasBaseTest): INTERPRET = True @@ -966,7 +1098,7 @@ class PallasControlFlowTest(PallasBaseTest): def setUp(self): super().setUp() if self.INTERPRET: - self.skipTest("Control flow not supported in interpreter mode yet.") + self.skipTest("Control flow not supported in interpret mode yet.") def test_loop_with_float64_carry(self): if jtu.test_device_matches(["tpu"]) and not self.INTERPRET: @@ -1690,7 +1822,7 @@ def outer_body(carry): np.testing.assert_equal(sizes[0, 4], jnp.asarray(key_count - real_keys)) -class PallasControlFlowInterpreterTest(PallasControlFlowTest): +class PallasControlFlowInterpretTest(PallasControlFlowTest): INTERPRET = True AD_TEST_CASES = [ @@ -1713,7 +1845,7 @@ class PallasCallAutodifferentiationTest(PallasBaseTest): def setUp(self): super().setUp() if jtu.test_device_matches(["tpu"]): - # TODO: most tests fail on TPU in non-interpreter mode + # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") # TODO: improve tolerance setting self.tol = 1e-5 @@ -1819,11 +1951,11 @@ def softmax_kernel(x_ref, y_ref): # jtu.check_grads(mm, (x, y), modes=["fwd"], order=1) -class PallasCallAutodifferentiationInterpreterTest(PallasCallAutodifferentiationTest): +class PallasCallAutodifferentiationInterpretTest(PallasCallAutodifferentiationTest): INTERPRET = True -class PallasOutOfBoundsInterpreterTest(PallasBaseTest): +class PallasOutOfBoundsInterpretTest(PallasBaseTest): INTERPRET = True def test_interpret_mode_out_of_bounds_access(self): @@ -1901,7 +2033,7 @@ def _(): np.testing.assert_allclose(out, expected, atol=atol) -class PallasCheckifyInterpreterTest(PallasBaseTest): 
+class PallasCheckifyInterpretTest(PallasBaseTest): # TODO(b/346651778): Support non-interpret mode checkify. INTERPRET = True diff --git a/tests/pallas/pallas_vmap_test.py b/tests/pallas/pallas_vmap_test.py index af8299e31689..fefccfe7eb4f 100644 --- a/tests/pallas/pallas_vmap_test.py +++ b/tests/pallas/pallas_vmap_test.py @@ -21,7 +21,6 @@ from absl.testing import absltest import jax from jax import random -from jax._src.lib import xla_extension from jax._src import config from jax._src import test_util as jtu from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr @@ -63,7 +62,7 @@ class PallasCallVmapTest(PallasBaseTest): def setUp(self): super().setUp() if jtu.test_device_matches(["tpu"]): - # TODO: most tests fail on TPU in non-interpreter mode + # TODO: most tests fail on TPU in non-interpret mode self.skipTest("On TPU the test works only in interpret mode") def test_vmap_of_simple_kernel(self): @@ -208,11 +207,9 @@ def sin(x_ref, o_ref): np.testing.assert_allclose(out, out_ref, atol=1e-3, rtol=1e-3) @jtu.skip_on_flag("jax_skip_slow_tests", True) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") - - # Catches https://github.com/google/jax/issues/18361 + # Catches https://github.com/jax-ml/jax/issues/18361 @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), grid=(2,)) @@ -229,9 +226,8 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(out, out_ref) + @jtu.skip_on_devices("cpu") # Test is very slow on CPU def test_small_small_large_vmap(self): - if xla_extension.is_tsan() and jtu.test_device_matches(["cpu"]): - self.skipTest("Test is very slow under TSAN") @functools.partial( self.pallas_call, out_shape=jax.ShapeDtypeStruct((2,), jnp.int32), @@ -250,7 +246,7 @@ def add_one(x_ref, o_ref): np.testing.assert_allclose(out, out_ref) -class PallasCallVmapInterpreterTest(PallasCallVmapTest): +class PallasCallVmapInterpretTest(PallasCallVmapTest): INTERPRET = True def setUp(self): diff --git a/tests/pallas/tpu_ops_test.py b/tests/pallas/tpu_ops_test.py new file mode 100644 index 000000000000..1d57dc164294 --- /dev/null +++ b/tests/pallas/tpu_ops_test.py @@ -0,0 +1,217 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
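The bitcast test in the new tpu_ops_test.py below sizes its operands so the total number of bits is unchanged by the cast: the sublane dimension is scaled by 32 divided by the element bit width on each side. The short arithmetic sketch here mirrors that bookkeeping using numpy item sizes for plain integer dtypes; the real test uses pallas_utils.dtype_bitwidth (which also covers bfloat16, int4 and bool), and the helper name packed_shapes is illustrative only.

import numpy as np

def packed_shapes(from_dtype, to_dtype, m=1, n=256):
  # Rows are packed so that each logical row occupies 32 bits per column.
  bits_from = np.dtype(from_dtype).itemsize * 8
  bits_to = np.dtype(to_dtype).itemsize * 8
  in_shape = (m * (32 // bits_from), n)
  out_shape = (m * (32 // bits_to), n)
  return in_shape, out_shape, bits_from, bits_to

in_shape, out_shape, bits_from, bits_to = packed_shapes(np.int16, np.int32)
assert in_shape == (2, 256) and out_shape == (1, 256)
# The bitcast preserves the total bit count: 2 * 256 * 16 == 1 * 256 * 32.
assert in_shape[0] * in_shape[1] * bits_from == out_shape[0] * out_shape[1] * bits_to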
+"""Tests for TPU specific operations within pallas_call.""" + +import sys +import unittest + +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax import lax +from jax._src import test_util as jtu +from jax._src.pallas import utils as pallas_utils +from jax.experimental import pallas as pl +import jax.numpy as jnp +import numpy as np + +if sys.platform != "win32": + from jax.experimental.pallas import tpu as pltpu +else: + pltpu = None + +try: + import hypothesis as hp +except (ModuleNotFoundError, ImportError): + raise unittest.SkipTest("tests depend on hypothesis library") + +import hypothesis.strategies as hps + +jax.config.parse_flags_with_absl() +jtu.setup_hypothesis(max_examples=100) + +_JAX_DTYPES = ( + jnp.float32, + jnp.bfloat16, + jnp.int32, + jnp.int16, + jnp.int8, + jnp.bool_, +) + + +class PallasBaseTest(jtu.JaxTestCase): + INTERPRET = False + + def setUp(self): + if not jtu.test_device_matches(["tpu"]): + self.skipTest("Test only supported on TPU.") + + super().setUp() + + @classmethod + def pallas_call(cls, *args, **kwargs): + return pl.pallas_call(*args, interpret=cls.INTERPRET, **kwargs) + + +class OpsTest(PallasBaseTest): + + @parameterized.product( + from_dtype=_JAX_DTYPES, to_dtype=_JAX_DTYPES, is_ref_bitcast=[False, True] + ) + def test_bitcast(self, from_dtype, to_dtype, is_ref_bitcast): + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Run on TPUv4+ to have expected memory layout") + if from_dtype == to_dtype: + self.skipTest("No bitcast needed") + if from_dtype == jnp.bool_ or to_dtype == jnp.bool_: + self.skipTest("Bitcasting with bool is not supported") + + def kernel(x_ref, y_ref): + if is_ref_bitcast: + y_ref[...] = x_ref.bitcast(to_dtype)[...] + else: + y_ref[...] = pltpu.bitcast(x_ref[...], to_dtype) + + m, n = 1, 256 + in_packing = 32 // pallas_utils.dtype_bitwidth(from_dtype) + out_packing = 32 // pallas_utils.dtype_bitwidth(to_dtype) + in_shape = (m * in_packing, n) + out_shape = (m * out_packing, n) + inp = np.arange(np.prod(in_shape), dtype=from_dtype).reshape(in_shape) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype), + )(inp) + if not self.INTERPRET: + out_interpret = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(out_shape, to_dtype), + interpret=True, + )(inp) + self.assertAllClose(out, out_interpret) + + @parameterized.product(is_dynamic=(False, True)) + @hp.given( + axis=hps.integers(0, 3), + shift=hps.integers(0, 3), + stride=hps.one_of(hps.just(None), hps.integers(0, 2)), + # Stride dimension on the minor most is not supported. + stride_axis=hps.one_of(hps.just(None), hps.integers(0, 2)), + ) + @hp.example(3, 9, 1, 2) + @hp.example(3, 9, 2, 2) + @hp.example(0, 9, 0, 1) + @hp.example(0, 9, 1, 1) + def test_roll(self, is_dynamic, axis, shift, stride, stride_axis): + if (stride is None) != (stride_axis is None): + self.skipTest( + "Roll op requires both stride and stride_axis to be either specified" + " or not specified." + ) + if (not jtu.is_device_tpu(version=5)) and stride_axis == 2: + self.skipTest( + "Roll op with stride axis on 2nd minor requires at least TPU v5" + ) + shape = (4, 4, 32, 512) + + def kernel(s_ref, x_ref, y_ref): + amt = s_ref[0] if is_dynamic else shift + y_ref[...] 
= pltpu.roll( + x_ref[...], amt, axis, stride=stride, stride_axis=stride_axis + ) + + def roll(x, shift, axis, stride=None, stride_axis=None): + assert (stride is None) == (stride_axis is None) + if stride is None: + return np.roll(x, shift, axis) + outputs = [ + np.roll(xs, shift + i * stride, axis) + for i, xs in enumerate(np.split(x, x.shape[stride_axis], stride_axis)) + ] + return np.concatenate(outputs, stride_axis) + + inp = np.arange(np.prod(shape), dtype=jnp.int32).reshape(shape) + ref = roll(inp, shift, axis, stride, stride_axis) + dynamic_shift = jnp.array([abs(shift)], jnp.int32) + for interpret in [False, True]: + out = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct(shape, jnp.int32), + grid_spec=pltpu.PrefetchScalarGridSpec(num_scalar_prefetch=1), + interpret=interpret, + )(dynamic_shift, inp) + np.testing.assert_array_equal(out, ref, err_msg=f"{interpret=}") + + def test_interleave_vectors(self): + if not jtu.is_device_tpu_at_least(version=4): + self.skipTest("Expect TPUv4+") + + def kernel(x_ref, y_ref, out_ref): + x = pltpu.bitcast(x_ref[...].astype(jnp.float32), jnp.int32) + y = pltpu.bitcast(y_ref[...].astype(jnp.float32), jnp.int32) + shift = jax.lax.broadcast(16, x.shape) + out_ref[...] = pltpu.bitcast( + y | jax.lax.shift_right_logical(x, shift), jnp.bfloat16 + ) + + m, n = 16, 128 + inp = np.arange(m * n * 2, dtype=jnp.bfloat16).reshape(m, n * 2) + x, y = np.split(inp, 2, axis=1) + out = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((m * 2, n), jnp.bfloat16), + )(x, y) + np.testing.assert_array_equal(out, inp.reshape(m * 2, n)) + + def test_tpu_unsigned_int(self): + def body(x_ref, o_ref): + # Test cast from uint16 -> uint32 + ux = lax.convert_element_type(x_ref[...], jnp.uint32) + res = ux + 1 + # Test cast from uint32 -> float32 + o_ref[...] = res.astype(jnp.float32) + out = jax.ShapeDtypeStruct((8, 128), jnp.float32) + x = jnp.arange(8 * 128, dtype=jnp.uint16).reshape((8, 128)) + result = self.pallas_call(body, out_shape=out)(x) + np.testing.assert_array_equal(result, x.astype(jnp.float32) + 1.0) + + def test_tpu_signed_int_upcast(self): + if not jtu.is_device_tpu_at_least(version=5): + self.skipTest("TPUv5+ needed for integer matmuls") + + def body(x_ref, o_ref): + # Test cast from int4 -> int8 + ux = lax.convert_element_type(x_ref[...], jnp.int8) + o_ref[...] = jax.lax.dot(ux, ux, preferred_element_type=jnp.int32) + + out = jax.ShapeDtypeStruct((128, 128), jnp.int32) + x = jnp.arange(128 * 128, dtype=jnp.int4).reshape((128, 128)) + result = self.pallas_call(body, out_shape=out)(x) + np.testing.assert_array_equal( + result, + jax.lax.dot( + x.astype(jnp.int8), + x.astype(jnp.int8), + preferred_element_type=jnp.int32, + ), + ) + + +class OpsInterpretTest(OpsTest): + INTERPRET = True + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/pallas/tpu_pallas_async_test.py b/tests/pallas/tpu_pallas_async_test.py new file mode 100644 index 000000000000..4f9d591dbea4 --- /dev/null +++ b/tests/pallas/tpu_pallas_async_test.py @@ -0,0 +1,759 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test TPU-specific uses of Pallas async APIs.""" + +import functools +from typing import Any +from absl.testing import absltest +from absl.testing import parameterized +import jax +from jax._src import test_util as jtu +from jax.experimental import pallas as pl +from jax.experimental import shard_map +from jax.experimental.pallas import tpu as pltpu +import jax.numpy as jnp +import numpy as np + + +jax.config.parse_flags_with_absl() +P = jax.sharding.PartitionSpec +partial = functools.partial + +Future = Any + + +def make_async_copy(target_memory_space=None): + if target_memory_space is None: + target_memory_space = pltpu.ANY + @jax.named_call + def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: + + def copy_start_kernel(x_ref, aliased_x_ref, o_ref, sem): + del aliased_x_ref + pltpu.make_async_copy(x_ref, o_ref, sem).start() + + x, out, sem = pl.pallas_call( + copy_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + target_memory_space(x.shape, x.dtype), # out + pltpu.SemaphoreType.DMA(()), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + input_output_aliases={0: 0}, + )(x) + return x, (out, sem) + + @jax.named_call + def copy_done(x: jax.Array, future: Future) -> jax.Array: + out, sem = future + + def copy_done_kernel(x_ref, o_ref, sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref, o_ref, sem).wait() + + out = pl.pallas_call( + copy_done_kernel, + out_shape=target_memory_space(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=target_memory_space), + input_output_aliases={1: 0}, + )(x, out, sem) + return out + + return copy_start, copy_done + + +def make_async_slice(index: int): + + def async_slice_start_kernel(x_ref, aliased_x_ref, o_ref, sem): + del aliased_x_ref + pltpu.make_async_copy(x_ref.at[index], o_ref, sem).start() + + def async_slice_done_kernel(x_ref, o_ref, sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref.at[index], o_ref, sem).wait() + + @jax.named_call + def async_slice_start(x: jax.Array) -> tuple[jax.Array, Future]: + + x, out, sem = pl.pallas_call( + async_slice_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), # out + pltpu.SemaphoreType.DMA(()), + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + input_output_aliases={0: 0}, + )(x) + return x, (out, sem) + + @jax.named_call + def async_slice_done( + x: jax.Array, future: Future + ) -> tuple[jax.Array, Future]: + out, sem = future + out = pl.pallas_call( + async_slice_done_kernel, + out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + input_output_aliases={1: 0}, + )(x, out, sem) + return out + + return async_slice_start, async_slice_done + + +def 
make_async_dynamic_slice(index: jax.Array): + + def async_dslice_start_kernel(index_ref, x_ref, aliased_x_ref, o_ref, sem): + del aliased_x_ref + pltpu.make_async_copy(x_ref.at[index_ref[0]], o_ref, sem).start() + + def async_dslice_done_kernel(x_ref, o_ref, sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref.at[0], o_ref, sem).wait() + + @jax.named_call + def async_dslice_start(x: jax.Array) -> tuple[jax.Array, Future]: + + x, out, sem = pl.pallas_call( + async_dslice_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + jax.ShapeDtypeStruct(x.shape[1:], x.dtype), # out + pltpu.SemaphoreType.DMA(()), + ), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + ), + input_output_aliases={1: 0}, + )(index[None], x) + return x, (out, sem) + + @jax.named_call + def async_dslice_done( + x: jax.Array, future: Future + ) -> tuple[jax.Array, Future]: + out, sem = future + out = pl.pallas_call( + async_dslice_done_kernel, + out_shape=(jax.ShapeDtypeStruct(x.shape[1:], x.dtype)), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=(pl.BlockSpec(memory_space=pltpu.ANY)), + input_output_aliases={1: 0}, + )(x, out, sem) + return out + + return async_dslice_start, async_dslice_done + + +class PallasCallAsyncCopyTest(parameterized.TestCase): + # TODO(b/368123537): add more tests + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs only guaranteed to work ou TPU v4+') + + def test_basic_async_copy(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + y = copy_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_multiple_async_copy(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + x2, fut2 = copy_start(x) + y = copy_done(x, fut) + y2 = copy_done(x2, fut2) + return y, y2 + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y, y2 = f(x) + np.testing.assert_array_equal(y, x) + np.testing.assert_array_equal(y2, x) + + def test_async_slice(self): + @jax.jit + def f(x): + async_slice_start, async_slice_done = make_async_slice(2) + x, fut = async_slice_start(x) + y = async_slice_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x[2]) + + def test_async_dynamic_slice(self): + @jax.jit + def f(x, i): + async_slice_start, async_slice_done = make_async_dynamic_slice(i) + x, fut = async_slice_start(x) + y = async_slice_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (4, 8, 128), dtype=jnp.float32) + y = f(x, 2) + np.testing.assert_array_equal(y, x[2]) + + def test_multi_async_dynamic_slice(self): + @jax.jit + def f(x, i, j): + async_slice_start, async_slice_done = make_async_dynamic_slice(i) + async_slice_start2, async_slice_done2 = make_async_dynamic_slice(j) + x, fut = async_slice_start(x) + x2, fut2 = async_slice_start2(x) + y = async_slice_done(x, fut) + y2 = async_slice_done2(x2, fut2) + return y, y2 + + x = jax.random.normal(jax.random.key(0), (4, 8, 128), 
dtype=jnp.float32) + y, y2 = f(x, 2, 3) + np.testing.assert_array_equal(y, x[2]) + np.testing.assert_array_equal(y2, x[3]) + + def test_basic_async_copy_into_vmem(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy(pltpu.VMEM) + x, fut = copy_start(x) + y = copy_done(x, fut) + return y + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_multiple_async_copy_into_vmem(self): + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy(pltpu.VMEM) + x1, fut = copy_start(x) + x2, fut2 = copy_start(x) + y = copy_done(x1, fut) + y2 = copy_done(x2, fut2) + return y, y2 + + x = jax.random.normal(jax.random.key(0), (8, 128), dtype=jnp.float32) + y, y2 = f(x) + np.testing.assert_array_equal(y, x) + np.testing.assert_array_equal(y2, x) + + def test_copy_in_a_loop(self): + + @jax.jit + def f(x): + def body(_, carry): + x = carry + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + y = copy_done(x, fut) + return y + x = jax.lax.fori_loop(0, x.shape[0], body, x) + return x + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_staggered_copy_in_a_loop(self): + + @jax.jit + def f(x): + copy_start, copy_done = make_async_copy() + x, fut = copy_start(x) + def body(_, carry): + x, fut = carry + y = copy_done(x, fut) + y, fut = copy_start(y) + return y, fut + # We *must* use unroll > 2 here because of aliasing constraints. XLA will + # introduce copies of the active buffer with unroll=1. + y, fut = jax.lax.fori_loop(0, x.shape[0] - 1, body, (x, fut), unroll=2) + x = copy_done(y, fut) + return x + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_full_copy_in_a_loop(self): + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + def body(i, carry): + x, ys = carry + copy_start, copy_done = make_async_dynamic_slice(i) + x, fut = copy_start(x) + y = copy_done(x, fut) + ys = ys.at[i].set(y) + return x, ys + _, y = jax.lax.fori_loop(0, x.shape[0], body, (x, y)) + return y + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + def test_staggered_full_copy_in_a_loop(self): + + @jax.jit + def f(x): + y = jnp.zeros_like(x) + copy_start, _ = make_async_dynamic_slice(jnp.array(0)) + x, fut = copy_start(x) + def body(i, carry): + x, fut, ys = carry + _, copy_done = make_async_dynamic_slice(i) + y = copy_done(x, fut) + copy_start, _ = make_async_dynamic_slice(i + 1) + ys = ys.at[i].set(y) + x, fut = copy_start(x) + return x, fut, ys + # We can use unroll=1 here because we have the ys.at[i].set(y) in the + # middle + x, fut, ys = jax.lax.fori_loop(0, x.shape[0] - 1, body, (x, fut, y), + unroll=1) + _, copy_done = make_async_dynamic_slice(x.shape[0] - 1) + y = copy_done(x, fut) + ys = ys.at[x.shape[0] - 1].set(y) + return ys + + x = jax.random.normal(jax.random.key(0), (16, 8, 128), dtype=jnp.float32) + y = f(x) + np.testing.assert_array_equal(y, x) + + +def make_async_remote_copy(axis_name: str, direction: str = 'right', + target_memory_space=None): + if target_memory_space is None: + target_memory_space = pltpu.ANY + @jax.named_call + def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: + + def copy_start_kernel(x_ref, aliased_x_ref, o_ref, send_sem, recv_sem): + del aliased_x_ref + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + 
jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) + 1, axis_size + ) + if direction == 'right': + src_neighbor = left_neighbor + dst_neighbor = right_neighbor + else: + src_neighbor = right_neighbor + dst_neighbor = left_neighbor + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=src_neighbor, core_index=0) + pltpu.semaphore_wait(barrier_sem, 1) + pltpu.make_async_remote_copy( + x_ref, o_ref, send_sem, recv_sem, device_id=dst_neighbor, + ).start() + + x, out, send_sem, recv_sem = pl.pallas_call( + copy_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + target_memory_space(x.shape, x.dtype), # out + pltpu.SemaphoreType.DMA(()), # send_sem + pltpu.SemaphoreType.DMA(()), # recv_sem + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ), + input_output_aliases={0: 0}, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + )(x) + return x, (out, send_sem, recv_sem) + + @jax.named_call + def send_done(x: jax.Array, future: Future) -> jax.Array: + _, send_sem, _ = future + + def send_done_kernel(x_ref, send_sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref, x_ref, send_sem).wait() + + x = pl.pallas_call( + send_done_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(x, send_sem) + return x + + @jax.named_call + def recv_done(x: jax.Array, future: Future) -> jax.Array: + out, _, recv_sem = future + + def send_done_kernel(x_ref, o_ref, send_sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy(x_ref, o_ref, send_sem).wait() + + out = pl.pallas_call( + send_done_kernel, + out_shape=target_memory_space(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=target_memory_space), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=target_memory_space), + input_output_aliases={1: 0}, + )(x, out, recv_sem) + return out + + return copy_start, send_done, recv_done + + +def make_bidi_collective_permute(axis_name: str): + @jax.named_call + def copy_start(x: jax.Array) -> tuple[jax.Array, Future]: + + def copy_start_kernel(x_ref, aliased_x_ref, o_ref, left_sems, right_sems): + del aliased_x_ref + axis_size = jax.lax.psum(1, axis_name) + left_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) - 1 + axis_size, axis_size + ) + right_neighbor = jax.lax.rem( + jax.lax.axis_index(axis_name) + 1, axis_size + ) + barrier_sem = pltpu.get_barrier_semaphore() + pltpu.semaphore_signal(barrier_sem, device_id=left_neighbor, core_index=0) + pltpu.semaphore_signal( + barrier_sem, device_id=right_neighbor, core_index=0 + ) + pltpu.semaphore_wait(barrier_sem, 2) + assert x.shape[0] % 2 == 0, x.shape + pltpu.make_async_remote_copy( + x_ref.at[pl.ds(0, x.shape[0] // 2)], + o_ref.at[pl.ds(0, x.shape[0] // 2)], + right_sems[0], + right_sems[1], + device_id=right_neighbor, + ).start() + pltpu.make_async_remote_copy( + x_ref.at[pl.ds(x.shape[0] // 2, x.shape[0] // 2)], + o_ref.at[pl.ds(x.shape[0] // 2, x.shape[0] // 2)], + left_sems[0], + 
left_sems[1], + device_id=left_neighbor, + ).start() + + x, out, left_sems, right_sems = pl.pallas_call( + copy_start_kernel, + out_shape=( + jax.ShapeDtypeStruct(x.shape, x.dtype), # aliased x + pltpu.ANY(x.shape, x.dtype), # out + (pltpu.SemaphoreType.DMA(()),) * 2, # left_sems + (pltpu.SemaphoreType.DMA(()),) * 2, # right_sems + ), + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + ], + out_specs=( + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, + (pl.BlockSpec(memory_space=pltpu.SEMAPHORE),) * 2, + ), + input_output_aliases={0: 0}, + compiler_params=pltpu.TPUCompilerParams(collective_id=0), + )(x) + return x, (out, left_sems, right_sems) + + @jax.named_call + def send_done(x: jax.Array, future: Future) -> jax.Array: + _, (send_left_sem, _), (send_right_sem, _) = future + + def send_done_kernel(x_ref, send_left_sem, send_right_sem, aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy( + x_ref.at[x_ref.shape[0] // 2 :], + x_ref.at[x_ref.shape[0] // 2 :], + send_left_sem, + ).wait() + pltpu.make_async_copy( + x_ref.at[x_ref.shape[0] // 2 :], + x_ref.at[x_ref.shape[0] // 2 :], + send_right_sem, + ).wait() + + x = pl.pallas_call( + send_done_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(x, send_left_sem, send_right_sem) + return x + + @jax.named_call + def recv_done(x: jax.Array, future: Future) -> jax.Array: + out, (_, recv_left_sem), (_, recv_right_sem) = future + + def recv_done_kernel(o_ref, x_ref, recv_left_sem, recv_right_sem, + aliased_o_ref): + del aliased_o_ref + pltpu.make_async_copy( + x_ref.at[o_ref.shape[0] // 2 :], + o_ref.at[o_ref.shape[0] // 2 :], + recv_left_sem, + ).wait() + pltpu.make_async_copy( + x_ref.at[o_ref.shape[0] // 2 :], + o_ref.at[o_ref.shape[0] // 2 :], + recv_right_sem, + ).wait() + + out = pl.pallas_call( + recv_done_kernel, + out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), # out + in_specs=[ + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.ANY), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + pl.BlockSpec(memory_space=pltpu.SEMAPHORE), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.ANY), + input_output_aliases={0: 0}, + )(out, x, recv_left_sem, recv_right_sem) + return out + return copy_start, send_done, recv_done + + +class PallasCallRemoteAsyncCopyTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs only guaranteed to work ou TPU v4+') + if jax.device_count() < 2: + self.skipTest('Test only works with >2 devices') + + def test_basic_remote_copy(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + copy_start, send_done, recv_done = make_async_remote_copy('x') + x, fut = copy_start(x) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.roll(x, shift=1, axis=0) + np.testing.assert_array_equal(y, expected) + + def test_multi_remote_copy(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + 
shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + copy_start, send_done, recv_done = make_async_remote_copy( + 'x', direction='right' + ) + copy_start2, send_done2, recv_done2 = make_async_remote_copy( + 'x', direction='left' + ) + x, fut = copy_start(x) + x, fut2 = copy_start2(x) + x = send_done(x, fut) + x = send_done2(x, fut2) + y = recv_done(x, fut) + y2 = recv_done2(x, fut2) + return y, y2 + + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32 + ) + y, y2 = f(x) + y_expected = jnp.roll(x, shift=1, axis=0) + y2_expected = jnp.roll(x, shift=-1, axis=0) + np.testing.assert_array_equal(y, y_expected) + np.testing.assert_array_equal(y2, y2_expected) + + def test_basic_collective_permute_loop(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + copy_start, send_done, recv_done = make_async_remote_copy('x') + def body(_, x): + x, fut = copy_start(x) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + # Send all the way around except for one step + return jax.lax.fori_loop(0, jax.device_count() - 1, body, x) + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 8, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.roll(x, shift=-1, axis=0) + np.testing.assert_array_equal(y, expected) + + def test_staggered_collective_permute_loop(self): + + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + assert x.shape[0] == 1 + copy_start, send_done, recv_done = make_async_remote_copy('x') + x, fut = copy_start(x) + def body(_, carry): + x, fut = carry + x = send_done(x, fut) + y = recv_done(x, fut) + y, fut = copy_start(y) + return y, fut + # Send all the way around except for one step + x, fut = jax.lax.fori_loop(0, jax.device_count() - 2, body, (x, fut), + unroll=2) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + + n_devices = jax.device_count() + x = jax.random.normal( + jax.random.key(0), (n_devices, 8, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.roll(x, shift=-1, axis=0) + np.testing.assert_array_equal(y, expected) + + def test_bidi_collective_permute_loop(self): + mesh = jax.make_mesh((jax.device_count(),), ('x',)) + + @jax.jit + @partial( + shard_map.shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x'), + check_rep=False, + ) + def f(x): + assert x.shape[0] == 1 + x = x[0] + copy_start, send_done, recv_done = make_bidi_collective_permute('x') + def body(_, x): + x, fut = copy_start(x) + x = send_done(x, fut) + y = recv_done(x, fut) + return y + # Send all the way around except for one step + y = jax.lax.fori_loop(0, jax.device_count() - 1, body, x) + return y[None] + x = jax.random.normal( + jax.random.key(0), (jax.device_count(), 16, 128), dtype=jnp.float32 + ) + y = f(x) + expected = jnp.concatenate([ + jnp.roll(x[:, :8], axis=0, shift=-1), + jnp.roll(x[:, 8:], axis=0, shift=1), + ], axis=1) + np.testing.assert_array_equal(y, expected) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/pallas/tpu_pallas_pipeline_test.py b/tests/pallas/tpu_pallas_pipeline_test.py index e61f0dfa56b3..ca64275d3f09 100644 --- a/tests/pallas/tpu_pallas_pipeline_test.py +++ b/tests/pallas/tpu_pallas_pipeline_test.py @@ -139,6 +139,8 @@ def 
setUp(self): ('hbm', pltpu.TPUMemorySpace.ANY), ) def test_pipeline_matmul(self, memory_space): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -184,6 +186,8 @@ def matmul_kernel(x_ref, y_ref, z_ref): ('hbm', pltpu.TPUMemorySpace.ANY), ) def test_double_pipeline_matmul(self, memory_space): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') k1, k2 = jax.random.split(jax.random.key(0)) x = jax.random.uniform(k1, (512, 512)) y = jax.random.uniform(k2, (512, 512)) @@ -535,6 +539,8 @@ def reference(x, y): ) def test_pipeline_throughput_optimized_allgather_matmul( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') input_dtype = out_dtype num_devices = jax.local_device_count() @@ -1065,6 +1071,8 @@ def reference(x, y): ) def test_pipeline_throughput_optimized_matmul_reducescatter( self, memory_space, out_dtype, n_tiles, m_tiles, k_tiles): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') input_dtype = jnp.float32 num_devices = jax.device_count() @@ -1325,6 +1333,8 @@ def setUp(self): super().setUp() def test_can_partition_nondivisible_grid_with_dynamic_dimensions(self): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') def mul_pipeline(x_ref, y_ref): y_ref[...] = x_ref[...] * 2 @@ -1359,6 +1369,8 @@ def mul_kernel(iters_ref, x_ref, y_ref): np.testing.assert_allclose(func(jnp.array([5]), x), x * 2) def test_megacore_mul(self): + # TODO(b/358121809): Re-enable this test once the bug is fixed. + self.skipTest('Broken test.') x = jax.random.uniform(jax.random.key(0), (512, 512)) def matmul_pipeline(x_ref, y_ref): @@ -1396,6 +1408,8 @@ def matmul_kernel(x_ref, y_ref): (768, 1024, 768, 256, 512, 256), ) def test_megacore_matmul(self, m, k, n, bm, bk, bn): + # TODO(b/358121809): Re-enable this test once the bug is fixed. 
+ self.skipTest('Broken test.') k1, k2 = jax.random.split(jax.random.key(42)) x = jax.random.uniform(k1, (m, k)) y = jax.random.uniform(k2, (k, n)) diff --git a/tests/pallas/tpu_pallas_test.py b/tests/pallas/tpu_pallas_test.py index 481e301e2db9..9a81f3196ba2 100644 --- a/tests/pallas/tpu_pallas_test.py +++ b/tests/pallas/tpu_pallas_test.py @@ -16,6 +16,7 @@ import contextlib import functools +import gc import io import math import re @@ -30,6 +31,7 @@ from jax._src.interpreters import partial_eval as pe from jax._src.lib import xla_extension from jax._src.pallas.pallas_call import _trace_kernel_to_jaxpr +from jax._src.state import utils as state_utils from jax.experimental import mesh_utils from jax.experimental import mosaic from jax.experimental import pallas as pl @@ -413,7 +415,9 @@ def kernel(s, x): ), grid=8, ), - compiler_params=dict(mosaic=dict(allow_input_fusion=[False, True])), + compiler_params=pltpu.TPUCompilerParams( + allow_input_fusion=[False, True] + ), )(s, x) first = x[0, ...].reshape((1, 8, 8, -1))[:, s[0, ...]].reshape(x.shape[1:]) @@ -453,7 +457,7 @@ def f(x): self.assertEqual(mem_analysis.alias_size_in_bytes, expected_num_bytes) -class PallasCallScalarPrefetchInterpreterTest(PallasCallScalarPrefetchTest): +class PallasCallScalarPrefetchInterpretTest(PallasCallScalarPrefetchTest): INTERPRET: bool = True @@ -703,7 +707,7 @@ def dynamic_kernel(steps, x): np.testing.assert_array_equal(dynamic_kernel(np.int32(4), x), x[8:16]) -class PallasCallDynamicGridInterpreterTest(PallasCallDynamicGridTest): +class PallasCallDynamicGridInterpretTest(PallasCallDynamicGridTest): INTERPRET = True @@ -722,8 +726,8 @@ def kernel(x_ref, y_ref): x = jnp.ones((8, 128), dtype=jnp.float32) y = self.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) jax.block_until_ready(y) @@ -883,8 +887,12 @@ def kernel(y_ref): def body(dma_sems, sems): self.assertTupleEqual(dma_sems.shape, (4,)) self.assertTupleEqual(sems.shape, (3,)) - self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) - self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) + if self.INTERPRET: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer)) + self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer)) + else: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) + self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) pl.run_scoped( body, pltpu.SemaphoreType.DMA((4,)), pltpu.SemaphoreType.REGULAR((3,)) ) @@ -898,10 +906,13 @@ def test_can_allocate_scratch_semaphore_array(self): def kernel(y_ref, dma_sems, sems): self.assertTupleEqual(dma_sems.shape, (4,)) self.assertTupleEqual(sems.shape, (3,)) - self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) - self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) + if self.INTERPRET: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, jnp.integer)) + self.assertTrue(jnp.issubdtype(sems.dtype, jnp.integer)) + else: + self.assertTrue(jnp.issubdtype(dma_sems.dtype, pltpu.dma_semaphore)) + self.assertTrue(jnp.issubdtype(sems.dtype, pltpu.semaphore)) - # TODO(b/345534352): Add interpret support for REGULAR semaphore. 
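The hunks in tpu_pallas_test.py around here route the scoped-semaphore tests through self.pallas_call so they also run in interpret mode. Below is a minimal standalone sketch of that pattern, assuming a build where interpret mode supports semaphore signal/wait as these tests now expect; the kernel name, semaphore count and output shape are arbitrary choices for illustration.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl
from jax.experimental.pallas import tpu as pltpu

def kernel(o_ref):
  def body(sems):
    # Signal and then immediately wait on the first semaphore of the
    # scoped scratch array, then write the output block.
    pltpu.semaphore_signal(sems.at[0])
    pltpu.semaphore_wait(sems.at[0])
    o_ref[...] = jnp.ones_like(o_ref[...])
  pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((2,)))

out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,
)()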
jax.block_until_ready( self.pallas_call( kernel, @@ -961,7 +972,7 @@ def body(sems): pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((3,))) # TODO(b/345534352): Add interpret support for semaphore signal/wait. - jax.block_until_ready(pl.pallas_call( + jax.block_until_ready(self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )()) @@ -985,15 +996,13 @@ def body(sems): pltpu.semaphore_wait(sems.at[i, 2]) pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((4, 3))) - # TODO(b/345534352): Add interpret support for semaphore signal/wait. jax.block_until_ready( - pl.pallas_call( + self.pallas_call( kernel, in_specs=[], out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), grid=4, - debug=True, )() ) @@ -1011,9 +1020,8 @@ def body(sems): pl.run_scoped(body, pltpu.SemaphoreType.REGULAR((m, n))) - # TODO(b/345534352): Add interpret support for semaphore signal/wait. y = jax.block_until_ready( - pl.pallas_call( + self.pallas_call( kernel, out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), out_shape=jax.ShapeDtypeStruct((m, n), jnp.int32), @@ -1024,21 +1032,20 @@ def body(sems): ) def test_can_read_dma_semaphore(self): - def kernel(x_hbm_ref, y_hbm_ref, sem_val_ref, dma_sem): sem_val_ref[0, 0] = 123 pltpu.async_copy(x_hbm_ref, y_hbm_ref, dma_sem).wait() sem_val_ref[0, 0] = pltpu.semaphore_read(dma_sem) + x = jnp.arange(8 * 128, dtype=jnp.int32).reshape((8, 128)) - # TODO(b/345534352): Add interpret support for semaphore signal/wait. y, sem_val = jax.block_until_ready( - pl.pallas_call( + self.pallas_call( kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], out_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), ], scratch_shapes=[pltpu.SemaphoreType.DMA], @@ -1062,9 +1069,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x) @@ -1076,15 +1083,14 @@ def body(sem): sem).wait() pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) - # TODO(b/345534352): Add interpret support for nonscalar semaphores. with self.assertRaisesRegex(ValueError, 'Cannot signal'): x = jnp.arange(8 * 128.).reshape((8, 128)) - pl.pallas_call( + self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -1096,13 +1102,12 @@ def body(sem): pl.run_scoped(body, pltpu.SemaphoreType.DMA((1,))) x = jnp.arange(8 * 128.).reshape((8, 128)) - # TODO(b/345534352): Add interpret support for nonscalar semaphores. 
- y = pl.pallas_call( + y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x) @@ -1121,9 +1126,9 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((2, 8, 128), jnp.float32), grid=(2,), )(x) @@ -1142,7 +1147,7 @@ def body(x_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -1152,14 +1157,14 @@ def test_vmem_hbm_dma(self): def kernel(x_ref, y_hbm_ref): def body(y_ref, sem): y_ref[...] = x_ref[...] - pltpu.async_copy(y_hbm_ref, y_ref, sem).wait() + pltpu.async_copy(y_ref, y_hbm_ref, sem).wait() pl.run_scoped( body, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA ) x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1179,8 +1184,8 @@ def body(x_ref, y_ref, sem): x = jnp.arange(8 * 128.).reshape((8, 128)) y = self.pallas_call( kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY)], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + in_specs=[pl.BlockSpec(memory_space=pl.ANY)], + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_allclose(y, x) @@ -1197,7 +1202,7 @@ def body(x_ref, sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) @@ -1218,7 +1223,7 @@ def body(y_ref, sem): in_specs=[ pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), ], - out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + out_specs=pl.BlockSpec(memory_space=pl.ANY), out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), )(x) expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4]) @@ -1256,7 +1261,7 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), @@ -1279,7 +1284,7 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), @@ -1287,6 +1292,9 @@ def body(sem): np.testing.assert_allclose(y, x.reshape((16, 128))) def test_hbm_vmem_dma_multiple_indexing(self): + if self.INTERPRET: + self.skipTest('Multiple indexing not supported in interpret mode.') + def kernel(x_hbm_ref, y_ref): def body(sem): for i in range(3): @@ -1305,7 +1313,7 @@ def body(sem): y = self.pallas_call( kernel, in_specs=[ - 
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((3, 16, 128), jnp.float32), @@ -1313,6 +1321,9 @@ def body(sem): np.testing.assert_allclose(y, x.reshape((3, 16, 128))) def test_cannot_squeeze_lane_sublane(self): + if self.INTERPRET: + self.skipTest('Only works on Mosaic TPU.') + def kernel(x_hbm_ref, y_ref): def body(sem): dma1 = pltpu.async_copy( @@ -1329,17 +1340,13 @@ def body(sem): _ = self.pallas_call( kernel, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.VMEM), out_shape=jax.ShapeDtypeStruct((16, 128), jnp.float32), )(x) - @parameterized.named_parameters( - ('', False), - ('_interpret', True), - ) - def test_hoisted_scratch_space(self, interpret): + def test_hoisted_scratch_space(self): def kernel(x_ref, y_ref, scratch_ref): i = pl.program_id(0) @pl.when(i == 0) @@ -1352,7 +1359,7 @@ def _(): y_ref[...] = scratch_ref[...] x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, @@ -1363,7 +1370,6 @@ def _(): out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), grid=(3,), ), - interpret=interpret, out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), )(x) np.testing.assert_array_equal(y, x + 3) @@ -1398,14 +1404,13 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): pltpu.semaphore_wait(sem) pltpu.async_copy(x_bbm_ref, y_ref, dma_sem).wait() - # TODO(b/345534352): Add interpret support for semaphore signal/wait. x = jnp.arange(8 * 128.).reshape((8, 128)) - y = pl.pallas_call( + y = self.pallas_call( kernel, grid_spec=pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], scratch_shapes=[pltpu.SemaphoreType.REGULAR, pltpu.SemaphoreType.DMA], @@ -1418,6 +1423,9 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): def test_large_array_indexing(self): n = 6 dtype = jnp.bfloat16 + # This test sometimes OOMs on smaller chips. We garbage collect + # to increase the chance there is 6GB memory available. + gc.collect() x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0) def kernel(index, x, y, sem): @@ -1428,9 +1436,9 @@ def kernel(index, x, y, sem): num_scalar_prefetch=1, in_specs=[ pl.BlockSpec( - memory_space=pltpu.TPUMemorySpace.ANY)], + memory_space=pl.ANY)], out_specs=pl.BlockSpec( - memory_space=pltpu.TPUMemorySpace.ANY), + memory_space=pl.ANY), scratch_shapes=[pltpu.SemaphoreType.DMA], ), out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype), @@ -1441,28 +1449,41 @@ def kernel(index, x, y, sem): np.testing.assert_array_equal(y, i) del y + +class PallasCallDMAInterpretTest(PallasCallDMATest): + INTERPRET = True + def test_interpret_local_dma(self): + # We run this test in interpret mode to test semaphore counting. + # On a physical device the values update asynchronously so we cannot + # deterministically check the values. def test_kernel(x_ref, o_ref, + sem_out_ref, copy_sem, ): o_ref[...] 
= jnp.zeros_like(o_ref[...]) input_to_output_copy = pltpu.make_async_copy( src_ref=x_ref.at[0:8], dst_ref=o_ref.at[0:8], - sem=copy_sem, + sem=copy_sem.at[0], ) input_to_output_copy.start() + sem_out_ref[0, :] = jnp.ones_like( + sem_out_ref[0, :]) * pltpu.semaphore_read(copy_sem.at[0]) input_to_output_copy.wait() + sem_out_ref[1, :] = jnp.ones_like( + sem_out_ref[0, :]) * pltpu.semaphore_read(copy_sem.at[0]) - out_shape = (jax.ShapeDtypeStruct((9, 128), jnp.float32)) + out_shape = (jax.ShapeDtypeStruct((16, 128), jnp.int32), + jax.ShapeDtypeStruct((2, 1), jnp.int32)) grid_spec = pltpu.PrefetchScalarGridSpec( num_scalar_prefetch=0, in_specs=[ - pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), + pl.BlockSpec(memory_space=pl.ANY), ], scratch_shapes=( - [pltpu.SemaphoreType.DMA] + [pltpu.SemaphoreType.DMA(2,)] ) ) @@ -1470,13 +1491,61 @@ def test_kernel(x_ref, test_kernel, out_shape=out_shape, grid_spec=grid_spec, - interpret=True + interpret=True, ) - x = jax.random.normal(jax.random.key(0), shape=(16, 128)) - result = kernel(x) + x = jax.random.randint( + jax.random.key(0), shape=(16, 128), minval=0, maxval=128) + + result, semaphores = kernel(x) np.testing.assert_array_equal(result[0:8], x[0:8]) np.testing.assert_array_equal(result[8:], jnp.zeros_like(result[8:])) + # Make sure semaphores have the correct value before and after DMA wait. + result_sem_pre_wait = semaphores[0, 0] + np.testing.assert_array_equal(result_sem_pre_wait, result[0:8].size) + result_sem_post_wait = semaphores[1, 0] + np.testing.assert_array_equal(result_sem_post_wait, 0) + + def test_interpreter_semaphore_counting(self): + # We run this test in interpret mode because the kernel exits with + # non-zero values. In normal Pallas this would crash the kernel. + def test_kernel(o_ref, + sem_ref, + ): + o_ref[...] 
= jnp.zeros_like(o_ref) + pltpu.semaphore_signal(sem_ref.at[0], 1) + pltpu.semaphore_signal(sem_ref.at[1], 2) + pltpu.semaphore_signal(sem_ref.at[2], 3) + pltpu.semaphore_signal(sem_ref.at[3], 4) + o_ref[0, 0] = pltpu.semaphore_read(sem_ref.at[0]) + o_ref[1, 0] = pltpu.semaphore_read(sem_ref.at[1]) + o_ref[2, 0] = pltpu.semaphore_read(sem_ref.at[2]) + o_ref[3, 0] = pltpu.semaphore_read(sem_ref.at[3]) + pltpu.semaphore_wait(sem_ref.at[0], 4) + pltpu.semaphore_wait(sem_ref.at[1], 3) + pltpu.semaphore_wait(sem_ref.at[2], 2) + pltpu.semaphore_wait(sem_ref.at[3], 1) + o_ref[4, 0] = pltpu.semaphore_read(sem_ref.at[0]) + o_ref[5, 0] = pltpu.semaphore_read(sem_ref.at[1]) + o_ref[6, 0] = pltpu.semaphore_read(sem_ref.at[2]) + o_ref[7, 0] = pltpu.semaphore_read(sem_ref.at[3]) + + out_shape = jax.ShapeDtypeStruct((8, 1), jnp.int32) + grid_spec = pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=0, + scratch_shapes=( + [pltpu.SemaphoreType.DMA(4,)] + ) + ) + results = pl.pallas_call( + test_kernel, + out_shape=out_shape, + grid_spec=grid_spec, + interpret=True, + )() + expected = jnp.array([1, 2, 3, 4, -3, -1, 1, 3]).reshape(out_shape.shape) + np.testing.assert_array_equal(results, expected) + class PallasCallTest(PallasBaseTest): @@ -1527,12 +1596,12 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=x, - compiler_params=dict(mosaic=dict(vmem_limit_bytes=256)), + compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=256), )(x) self.pallas_call( kernel, out_shape=x, - compiler_params=dict(mosaic=dict(vmem_limit_bytes=int(2**18))), + compiler_params=pltpu.TPUCompilerParams(vmem_limit_bytes=int(2**18)), )(x) def test_allow_input_fusion(self): @@ -1549,7 +1618,7 @@ def f(x, y): in_specs=[pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0))], out_specs=pl.BlockSpec((1, 128, 128), lambda i: (i, 0, 0)), out_shape=x, - compiler_params=dict(mosaic=dict(allow_input_fusion=[True])), + compiler_params=pltpu.TPUCompilerParams(allow_input_fusion=[True]), )(z) x = jnp.arange(np.prod(shape), dtype=np.float32).reshape(shape) @@ -1577,131 +1646,12 @@ def kernel(x_ref, y_ref): self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct(shape, jnp.float32), - compiler_params=dict( - mosaic=dict(internal_scratch_in_bytes=requested_bytes) + compiler_params=pltpu.TPUCompilerParams( + internal_scratch_in_bytes=requested_bytes, ), )(x) -class PallasCallUnblockedIndexingTest(PallasBaseTest): - - def test_block_spec_unblocked(self): - def show_program_ids(*, shape, block_shape, grid, - indexing_mode: pl.IndexingMode): - def kernel(o1_ref): - assert o1_ref.shape == block_shape - o1_ref[...] 
= jnp.full(o1_ref.shape, pl.program_id(0)) - - return self.pallas_call(kernel, - jax.ShapeDtypeStruct(shape, dtype=np.int32), - grid=grid, - out_specs=pl.BlockSpec(block_shape, - lambda i: (8 * i, 0), - indexing_mode=indexing_mode))() - # No padding - pids = show_program_ids(shape=(16, 128), block_shape=(8, 128), - grid=(2,), - indexing_mode=pl.Unblocked()) - expected_pids = np.array( - [[0] * 128] * 8 + [[1] * 128] * 8, - dtype=np.int32) - self.assertAllClose(pids, expected_pids) - - # Only high padding - pids = show_program_ids(shape=(14, 128), block_shape=(8, 128), - grid=(2,), - indexing_mode=pl.Unblocked(((0, 2), (0, 0)))) - expected_pids = np.array( - [[0] * 128] * 8 + [[1] * 128] * 6, - dtype=np.int32) - self.assertAllClose(pids, expected_pids) - - # Both low and high padding - self.skipTest("TODO: TPU low padding not supported yet") - pids = show_program_ids(shape=(11, 128), block_shape=(8, 128), - grid=(2,), - indexing_mode=pl.Unblocked(((3, 2), (0, 0)))) - expected_pids = np.array( - [[0] * 128] * 5 + [[1] * 128] * 6, - dtype=np.int32) - self.assertAllClose(pids, expected_pids) - - @parameterized.parameters("int32", "float32") - def test_block_spec_unblocked_padding_is_nan(self, dtype_name): - if not self.INTERPRET: - self.skipTest("Only applicable for the interpret mode") - - dtype = np.dtype(dtype_name) - def copy_kernel(x_ref, o_ref): - o_ref[...] = x_ref[...] - res = self.pallas_call(copy_kernel, - jax.ShapeDtypeStruct((6,), dtype=dtype), - grid=(1,), - in_specs=[pl.BlockSpec((6,), lambda i: 0, - indexing_mode=pl.Unblocked(((1, 2),)))])( - np.full((3,), 42, dtype=dtype) - ) - expected_pad = {"int32": jnp.iinfo(np.int32).min, - "float32": np.nan}[dtype_name] - self.assertAllClose(res, np.array([expected_pad, 42, 42, 42, - expected_pad, expected_pad], dtype=dtype)) - - def test_unblocked_indexing(self): - shape = (16 * 8, 128) - result_ty = jax.ShapeDtypeStruct((15 * 8, 128), jnp.float32) - - def kernel(x_ref, o_ref): - o_ref[...] = x_ref[pl.ds(0, 8)] + x_ref[pl.ds(8, 8)] - - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - y = self.pallas_call( - kernel, - grid=(15,), - in_specs=( - pl.BlockSpec( - (2 * 8, 128), lambda i: (i * 8, 0), indexing_mode=pl.unblocked - ), - ), - out_specs=pl.BlockSpec((8, 128), lambda i: (i, 0)), - out_shape=result_ty, - )(x) - ref = [] - for i in range(15): - block = x[i * 8:i * 8 + 2 * 8] - ref.append(block[0:8] + block[8:16]) - ref = np.concatenate(ref, axis=0) - np.testing.assert_array_equal(y, ref) - - def test_unblocked_indexing_with_padding(self): - shape = (8, 128) - result_ty = jax.ShapeDtypeStruct((8, 128), jnp.float32) - - def kernel(x_ref, y_ref): - y_ref[...] 
= x_ref[pl.ds(0, 8)] - - x = np.arange(np.prod(shape), dtype=np.float32).reshape(shape) - y = self.pallas_call( - kernel, - grid=(1,), - in_specs=( - pl.BlockSpec( - (2 * 8, 128), - lambda i: (0, 0), - indexing_mode=pl.Unblocked(((0, 8), (0, 0))), - ), - ), - out_specs=pl.BlockSpec((8, 128), lambda i: (0, 0)), - out_shape=result_ty, - )(x) - np.testing.assert_array_equal(y, x) - - -class PallasCallUnblockedIndexingInterpreterTest( - PallasCallUnblockedIndexingTest -): - INTERPRET = True - - class PallasUXTest(PallasBaseTest): def test_mlir_location(self): @@ -1858,6 +1808,100 @@ def kernel(size_smem_ref, x_hbm_ref, _, o_hbm_ref, sem): np.testing.assert_array_equal(out, expected) +class PallasCallRefTransformTest(PallasBaseTest): + + @parameterized.product(slice_first=[True, False]) + def test_dma_bitcasted_ref(self, slice_first): + if not jtu.is_device_tpu_at_least(4): + self.skipTest('DMAs not supported on TPU generations <= 3') + + def kernel(x_hbm_ref, y_hbm_ref): + def body(sem): + ref = ( + x_hbm_ref.at[:8, :, :128].bitcast(jnp.int16) + if slice_first + else x_hbm_ref.bitcast(jnp.int16).at[:8, :, :128] + ) + pltpu.async_copy(ref, y_hbm_ref.at[...], sem).wait() + + pl.run_scoped(body, pltpu.SemaphoreType.DMA) + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 1, 256)) + y = self.pallas_call( + kernel, + in_specs=[ + pl.BlockSpec(memory_space=pl.ANY), + ], + out_specs=pl.BlockSpec(memory_space=pl.ANY), + out_shape=jax.ShapeDtypeStruct((8, 2, 128), jnp.int16), + )(x) + expected = ( + state_utils.bitcast(x[:8, :, :128], jnp.int16) + if slice_first + else state_utils.bitcast(x, jnp.int16)[:8, :, :128] + ) + np.testing.assert_array_equal(y, expected) + + @parameterized.product(slice_first=[True, False]) + def test_load_bitcasted_ref(self, slice_first: bool): + def kernel(x_ref, y_ref): + ref = ( + x_ref.at[:8, :128].bitcast(jnp.int16) + if slice_first + else x_ref.bitcast(jnp.int16).at[:16, :128] + ) + y_ref[...] = ref[...] + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256)) + y = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 128), jnp.int16), + )(x) + expected = ( + state_utils.bitcast(x[:8, :128], jnp.int16) + if slice_first + else state_utils.bitcast(x, jnp.int16)[:16, :128] + ) + np.testing.assert_array_equal(y, expected) + + @parameterized.product(slice_first=[True, False]) + def test_store_bitcasted_ref(self, slice_first): + def kernel(x_ref, y_ref): + ref = ( + y_ref.at[:8, :128].bitcast(jnp.bfloat16) + if slice_first + else y_ref.bitcast(jnp.bfloat16).at[:16, :128] + ) + ref[...] = x_ref[...] + + x = jnp.arange(16 * 128, dtype=jnp.bfloat16).reshape((16, 128)) + y = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((16, 256), jnp.int32), + )(x) + expected = state_utils.bitcast(x, jnp.int32) + np.testing.assert_array_equal(y[:8, :128], expected) + + def test_multiple_ref_transforms(self): + + def kernel(x_ref, y_ref): + ref = ( + x_ref.at[:8, :256] + .bitcast(jnp.int16) + .bitcast(jnp.float16) + .at[:, :128] + .bitcast(jnp.int32) + ) + y_ref[...] = ref[...] + + x = jnp.arange(4 * 8 * 128, dtype=jnp.int32).reshape((16, 256)) + y = self.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((8, 128), jnp.int32), + )(x) + np.testing.assert_array_equal(y, x[:8, :128]) + + class PallasCallPrintTest(PallasBaseTest): def test_debug_print(self): @@ -1992,27 +2036,37 @@ def inner_scope(scoped_ref): def test_vector_bool_load_store(self): def kernel(x_ref, o_ref): o_ref[...] = x_ref[...] 
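Shape bookkeeping for the new PallasCallRefTransformTest cases above, which is easy to lose in the diff: bitcasting a ref to a narrower or wider dtype rescales the second-to-last dimension by the bit-width ratio and leaves the minor dimension untouched. The helper below is a hypothetical illustration of that arithmetic only; the actual element layout is defined by state_utils.bitcast / Mosaic, not by this sketch.

def bitcast_shape(shape, src_bits, dst_bits):
  # Second-to-last dim scales by src_bits / dst_bits; the minor dim stays.
  *leading, second_minor, minor = shape
  assert (second_minor * src_bits) % dst_bits == 0
  return (*leading, second_minor * src_bits // dst_bits, minor)

# test_dma_bitcasted_ref: (8, 1, 128) int32 seen as int16 -> (8, 2, 128)
assert bitcast_shape((8, 1, 128), 32, 16) == (8, 2, 128)
# test_load_bitcasted_ref: (8, 128) int32 seen as int16 -> (16, 128)
assert bitcast_shape((8, 128), 32, 16) == (16, 128)
# test_store_bitcasted_ref: a (16, 128) bfloat16 block fills an
# (8, 128) int32 region of the output ref
assert bitcast_shape((16, 128), 16, 32) == (8, 128)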
- input = jnp.array([[False, True, True, False]]) - output_shape = jax.ShapeDtypeStruct((1, 4), jnp.bool_) - if self.INTERPRET: - result = self.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), - out_shape=output_shape, - )(input) - np.testing.assert_array_equal(result, input) - else: - # TODO(justinfu): Fix vector boolean ops so that they do not trigger - # a relayout error from changing bitwidths in Mosaic. - with self.assertRaisesRegex( - Exception, 'Boolean vector loads are not supported.'): - self.pallas_call( - kernel, - in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], - out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), - out_shape=output_shape, - )(input) + input = jax.random.bernoulli(jax.random.key(0), p=0.5, shape=(8, 128)) + output_shape = jax.ShapeDtypeStruct((8, 128), jnp.bool_) + result = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM)], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=output_shape, + )(input) + np.testing.assert_array_equal(result, input) + + def test_vector_bool_masking_with_indexing(self): + def kernel(mask_ref, true_ref, false_ref, o_ref): + o_ref[0, ...] = jnp.where( + mask_ref[0, ...], true_ref[0, ...], false_ref[0, ...]) + key = jax.random.key(0) + k1, k2, k3 = jax.random.split(key, 3) + values_1 = jax.random.normal(k1, (1, 256, 256), jnp.float32) + values_2 = jax.random.normal(k2, (1, 256, 256), jnp.float32) + mask = jax.random.bernoulli(k3, p=0.5, shape=(1, 256, 256)) + output_shape = jax.ShapeDtypeStruct((1, 256, 256), jnp.float32) + result = self.pallas_call( + kernel, + in_specs=[pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + pl.BlockSpec(memory_space=pltpu.VMEM), + ], + out_specs=pl.BlockSpec(memory_space=pltpu.VMEM), + out_shape=output_shape, + )(mask, values_1, values_2) + expected = jnp.where(mask, values_1, values_2) + np.testing.assert_array_equal(result, expected) def test_bool_dma_not_implemented(self): if not jtu.is_device_tpu_at_least(4): @@ -2170,6 +2224,43 @@ class PallasCallTPUCheckifyInterpretTest(PallasCallTPUCheckifyTest): INTERPRET: bool = True +class PrettyPrintingTest(PallasBaseTest): + + @parameterized.parameters( + ( + lambda i: (i, pl.ds(0, 8), pl.ds(0, 128)), + 'dma_start c[d,:,:] -> e[...] f', + ), + ( + lambda i: (0, pl.ds(i, 8), pl.ds(0, 128)), + 'dma_start c[0,d:d+8,:] -> e[...] f', + ), + ( + lambda i: (i, pl.ds(2, 4), pl.ds(0, 100)), + 'dma_start c[d,2:6,:100] -> e[...] f', + ), + ( + lambda i: (i, pl.ds(2, 6), pl.ds(4, 100)), + 'dma_start c[d,2:,4:104] -> e[...] 
f', + ), + ) + def test_dma_custom_pretty_print(self, indexer, expected): + def body(x_hbm_ref, i): + def inner(x_ref, sem): + pltpu.async_copy(x_hbm_ref.at[indexer(i)], x_ref, sem).wait() + + pl.run_scoped( + inner, pltpu.VMEM((8, 128), jnp.float32), pltpu.SemaphoreType.DMA + ) + return [] + + jaxpr, _, _, () = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(body), [state.shaped_array_ref((2, 8, 128), jnp.int32), + jax.core.ShapedArray((), jnp.int32)] + ) + self.assertIn(expected, jaxpr.pretty_print(use_color=False)) + + def only_passes_in_interpret(unless_generation: int | None = None): def decorator(f): def wrapper(self): @@ -2215,9 +2306,7 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.reshape(x, (1, 256, 8, 128))) - @only_passes_in_interpret() def test_lane_to_chunk_broadcast_fp32(self): - """b/348033362""" x = np.arange(256 * 128, dtype=jnp.float32).reshape(1, 256, 128) def kernel(x_ref, out_ref): @@ -2256,9 +2345,7 @@ def kernel(x_ref, out_ref): )(x) np.testing.assert_array_equal(out, np.broadcast_to(x, (256, 512))) - @only_passes_in_interpret(unless_generation=4) def test_bfloat16_to_uint32_bitcast(self): - """b/347771903""" x = np.arange(16 * 2 * 256, dtype=jnp.bfloat16).reshape(16, 2, 256) def kernel(x_ref, out_ref): @@ -2267,7 +2354,7 @@ def kernel(x_ref, out_ref): out = self.pallas_call( kernel, out_shape=jax.ShapeDtypeStruct((16, 1, 256), jnp.uint32) )(x) - # FIXME: Add correctness test for result. + np.testing.assert_array_equal(out, state_utils.bitcast(x, jnp.uint32)) @only_passes_in_interpret() def test_roll_partial(self): @@ -2314,9 +2401,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x[:, 7, :], (1, 8, 128))) - @only_passes_in_interpret() def test_sublane_adding_shape_cast_f32(self): - """b/352833257""" x = np.arange(8 * 128, dtype=jnp.float32).reshape(8, 128) def kernel(x_ref, out_ref): @@ -2342,9 +2427,7 @@ def kernel(x_ref, out_ref): np.testing.assert_array_equal(out, np.reshape(x, (8, 1, 128))) - @only_passes_in_interpret() def test_mixed_strides(self): - """b/352841329""" x = np.zeros((8, 128), dtype=jnp.float32) y = np.zeros((8, 2, 128), dtype=jnp.bfloat16) @@ -2358,9 +2441,7 @@ def kernel(x_ref, y_ref, out_ref): np.testing.assert_array_equal(out, np.zeros((8, 128), dtype=jnp.float32)) - @only_passes_in_interpret() def test_sum(self): - """b/356467588""" x = np.zeros((8, 2, 8, 128), dtype=jnp.float32) def kernel(x_ref, out_ref): @@ -2389,7 +2470,7 @@ def kernel(x_ref, out_ref): ) -class MiscellaneousInterpreterTest(MiscellaneousTest): +class MiscellaneousInterpretTest(MiscellaneousTest): INTERPRET: bool = True diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 7dc015c90bca..4d60c3017287 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -54,7 +54,7 @@ def tearDown(self): @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfile(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @partial( jax.jit, @@ -83,7 +83,7 @@ def f(x, y): @unittest.skip("Test failing in CI") def testPGLEProfilerGetFDOProfileLarge(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) its = 500 @partial( @@ -112,7 +112,7 @@ def f(x): self.assertEqual(fdo_profile.count(b'custom'), its) def testAutoPgle(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @partial( jax.jit, @@ -245,7 +245,7 @@ def check_if_cache_hit(event): self.assertFalse(pgle_profiler.is_fdo_consumed()) def 
testPassingFDOProfile(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @partial( jax.jit, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 516d1fec7ff9..c20084c3c8e2 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -57,6 +57,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -186,10 +187,10 @@ def f(x, y): shape = (8, 8) x = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) - with jtu.create_global_mesh((2,), ('x')) as mesh: + with jtu.create_mesh((2,), ('x')) as mesh: actual = f(x, x + 1) expected = x + (x + 1) - self.assertEqual(mesh, jtu.create_global_mesh((2,), ('x'))) + self.assertEqual(mesh, jtu.create_mesh((2,), ('x'))) self.assertAllClose(actual, expected, check_dtypes=False) _check_instance(self, actual) self.assertLen(actual.addressable_shards, 2) @@ -225,15 +226,15 @@ def f(x, y): check_dtypes=False) def testDifferentNestedMesh(self): - with jtu.create_global_mesh((2, 1), ("x", "y")) as m1: - with jtu.create_global_mesh((2, 2), ("a", "b")) as m2: + with jtu.create_mesh((2, 1), ("x", "y")) as m1: + with jtu.create_mesh((2, 2), ("a", "b")) as m2: self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m2) self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, m1) self.assertEqual(mesh_lib.thread_resources.env.physical_mesh, mesh_lib.EMPTY_ENV.physical_mesh) def testSameNestedMesh(self): - mesh = jtu.create_global_mesh((2, 1), ("a", "b")) + mesh = jtu.create_mesh((2, 1), ("a", "b")) thread_resources = mesh_lib.thread_resources with mesh as m1: with mesh as m2: @@ -257,7 +258,7 @@ def dec(): self.assertArraysEqual(out, x) def testMeshHashRace(self): - mesh = jtu.create_global_mesh((2, 1), ('a', 'testMeshHashRace')) + mesh = jtu.create_mesh((2, 1), ('a', 'testMeshHashRace')) self.assertFalse(hasattr(mesh, '_hash')) with concurrent.futures.ThreadPoolExecutor(max_workers=5) as pool: fs = [] @@ -310,7 +311,7 @@ def f(x, y): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def testBufferDonationWithNames(self): - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = NamedSharding(mesh, P('x')) @partial(pjit, out_shardings=s, donate_argnames='inp2') @@ -326,7 +327,7 @@ def f(inp1, inp2): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def testBufferDonationWithKwargs(self): - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = NamedSharding(mesh, P('x')) @partial(pjit, out_shardings=s, donate_argnames=('inp2', 'inp3')) @@ -345,7 +346,7 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def testBufferDonationWithPyTreeKwargs(self): - mesh = jtu.create_global_mesh((2,), ('x')) + mesh = jtu.create_mesh((2,), ('x')) s = NamedSharding(mesh, P('x')) @partial(pjit, out_shardings=s, donate_argnames='inp2') @@ -370,7 +371,7 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationWithOutputShardingInference(self): - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) rs = NamedSharding(mesh, P()) @@ -401,7 +402,9 @@ def f(inp1, inp2, inp3): @jtu.run_on_devices('tpu') def testBufferDonationWithOutputShardingInferenceAndTokens(self): - mesh = jtu.create_global_mesh((2,), 'x') + if config.use_shardy_partitioner.value: + self.skipTest('b/355263220: Shardy does not support 
callbacks yet.') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) def _callback(x): @@ -422,7 +425,7 @@ def f(x): @jtu.run_on_devices('tpu', 'cpu', 'gpu') def testBufferDonationNotDonated(self): - mesh = jtu.create_global_mesh((2,), 'x') + mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) @partial(pjit, donate_argnames=('x')) @@ -452,13 +455,19 @@ def f(x): check_dtypes=False) hlo = f.lower(np.ones(shape)).compiler_ir() - # Annotation from with_sharding_constraint - self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo)) - # Annotation from pjit - self.assertIn('sharding = "{replicated}"', str(hlo)) + if config.use_shardy_partitioner.value: + # Annotation from with_sharding_constraint + self.assertIn('<@mesh, [{"x"}, {"y"}]>', str(hlo)) + # Annotation from pjit + self.assertIn('sharding = #sdy.sharding<@mesh, [{}, {}]>}', str(hlo)) + else: + # Annotation from with_sharding_constraint + self.assertIn('sharding = "{devices=[2,1]<=[2]}"', str(hlo)) + # Annotation from pjit + self.assertIn('sharding = "{replicated}"', str(hlo)) def testShardingConstraintWithArray(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) @partial(pjit, in_shardings=s, out_shardings=s) @@ -483,8 +492,10 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintWithArrayOpSharding(self): + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support PositionalSharding") shape = (8, 8) - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P(None)) ops = pxla.to_gspmd_sharding( NamedSharding(mesh, P('x', 'y')), len(shape)) @@ -510,7 +521,7 @@ def f(x): self.assertIn("sharding={replicated}", hlo.as_hlo_text()) def testShardingConstraintPyTreeWithArray(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) @jax.jit def f(x): @@ -533,7 +544,7 @@ def f(x): def testShardingConstraintPyTreeWithUnconstrainedDimsWithJit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @jax.jit def f(x): x = with_sharding_constraint( @@ -554,8 +565,12 @@ def f(x): self.assertLen(actual[0]['a'].addressable_shards, 4) mlir_str = str(f.lower(x).compiler_ir()) - self.assertIn("unspecified_dims=[0]", mlir_str) - self.assertIn("unspecified_dims=[1]", mlir_str) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {"y"}, {}]>', mlir_str) + self.assertIn('<@mesh, [{"x"}, {?}, {}]>', mlir_str) + else: + self.assertIn("unspecified_dims=[0]", mlir_str) + self.assertIn("unspecified_dims=[1]", mlir_str) @jtu.with_mesh([('x', 2), ('y', 2)]) def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self): @@ -574,8 +589,12 @@ def f(x): x = [{'a': v, 'b': v * 2}, v * 3] mlir_str = str(f.lower(x).compiler_ir()) - self.assertIn("unspecified_dims=[0,1]", mlir_str) - self.assertIn("unspecified_dims=[0,2]", mlir_str) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {"y"}]>', mlir_str) + self.assertIn('<@mesh, [{?}, {"x"}, {?}]>', mlir_str) + else: + self.assertIn("unspecified_dims=[0,1]", mlir_str) + self.assertIn("unspecified_dims=[0,2]", mlir_str) def testCaching(self): def f(x): @@ -634,18 +653,16 @@ def testAutodiff(self, mesh, resources): @jtu.with_mesh([('x', 2), ('y', 1)]) def testAutodiffCache(self): - f = pjit( - lambda x: jnp.sin(x).sum(), in_shardings=P('x'), 
out_shardings=None - ) + f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None) x = jnp.arange(16, dtype=jnp.float32) - jax.grad(f)(x) # Warm up the cache. - before = pjit_lib._pjit_lower_cached.cache_info() - jax.grad(f)(x) - after = pjit_lib._pjit_lower_cached.cache_info() - # One hit for the forward pass, one hit for backward. - self.assertEqual(after.hits, before.hits + 2) - self.assertEqual(after.misses, before.misses) + jax.grad(f)(x) # Warm up the cache. + with jtu.count_pjit_cpp_cache_miss() as count: + jax.grad(f)(x) + if xla_extension_version >= 286: + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit + else: + self.assertEqual(count[0], 2) @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -848,6 +865,9 @@ def f_for_pjit(x): def testOutfeed(self): if xla_bridge.using_pjrt_c_api(): raise unittest.SkipTest('outfeed not implemented in PJRT C API') + if config.use_shardy_partitioner.value: + self.skipTest( + 'b/355263220: outfeed lowering not supported by Shardy') devices = np.array(jax.local_devices()) nr_devices = len(devices) @@ -1147,7 +1167,7 @@ def f(x, y): def test_local_sharded_key_array_sda(self): input_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) seeds = jnp.arange( math.prod(input_shape), dtype=np.uint32).reshape(input_shape) @@ -1164,7 +1184,7 @@ def make_keys(seeds): jax.random.key_data(out) # doesn't crash def test_with_sharding_constraint_is_compatible_error(self): - mesh = jtu.create_global_mesh((1, 1, 2), ('replica', 'data', 'mdl')) + mesh = jtu.create_mesh((1, 1, 2), ('replica', 'data', 'mdl')) with mesh: def f(x): @@ -1265,7 +1285,7 @@ def f(x): ) def test_with_sharding_constraint_vmap_spmd_axis_name_error(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(x): return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, P('x'))) @@ -1281,6 +1301,9 @@ class CustomPartitionerTest(jtu.JaxTestCase): def skip_if_custom_partitioning_not_supported(self): if jtu.is_cloud_tpu(): raise unittest.SkipTest("Custom partitioning is not supported on libtpu.") + if config.use_shardy_partitioner.value: + self.skipTest( + 'Custom partitioning is not supported with Shardy yet.') @jtu.skip_on_devices('cpu') # Collectives don't seem to work on CPU. @jtu.with_mesh([('x', 4), ('y', 2)]) @@ -1483,6 +1506,37 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape): pjit_f = pjit(jit_f, in_shardings=(P('x')), out_shardings=P('x')) self.assertArraysEqual(x, pjit_f(x)) + @jtu.with_mesh([('x', 4)]) + def test_custom_partitioner_with_scan(self): + self.skip_if_custom_partitioning_not_supported() + + # This is a reproducer from https://github.com/jax-ml/jax/issues/20864. 
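For the testAutodiffCache rewrite just above: the old assertions counted hits on the Python-level _pjit_lower_cached cache (one for the forward pass, one for the backward pass); the new version instead counts C++ pjit cache misses on the second jax.grad call and, on a new enough jaxlib (xla_extension_version >= 286), expects zero. A stripped-down sketch of that counting pattern, dropping the mesh and in_shardings of the real test and assuming only that the context manager yields an indexable counter, as the count[0] usage suggests:

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu

f = jax.jit(lambda x: jnp.sin(x).sum())
x = jnp.arange(16, dtype=jnp.float32)

jax.grad(f)(x)  # warm-up call populates the compilation caches
with jtu.count_pjit_cpp_cache_miss() as count:
  jax.grad(f)(x)  # second call should stay on the C++ fast path
print(count[0])  # 0 on recent jaxlib, per the hunk above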
+ + @custom_partitioning + def f(x): + return jnp.sum(x) + + def partition(mesh, arg_shapes, result_shape): + def lower_fn(xs): + def f(carry, x): + return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None + + carry, _ = jax.lax.scan(f, 0, xs) + return carry + + result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) + arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) + return mesh, lower_fn, result_shardings, arg_shardings + + f.def_partition( + partition, + infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()), + propagate_user_sharding=lambda _, user_shape: user_shape.sharding) + + pjit_f = pjit(f, in_shardings=P(None, 'x')) + xs = jnp.ones([32, 16]) + self.assertEqual(pjit_f(xs), xs.sum()) + def test_custom_partitioning_no_mesh_context(self): self.skip_if_custom_partitioning_not_supported() @@ -1515,43 +1569,12 @@ def infer_sharding_from_operands(mesh, arg_shapes, result_shape): partition=partition, ) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) x = np.asarray(np.random.randint(0, 20, (32,)), dtype=np.float32) s = NamedSharding(mesh, P('x')) - pjit_f = jax.jit(f, in_shardings=s, out_shardings=s) - self.assertArraysEqual(x, pjit_f(x)) - - @jtu.with_mesh([('x', 4)]) - def test_custom_partitioner_with_scan(self): - self.skip_if_custom_partitioning_not_supported() - - # This is a reproducer from https://github.com/google/jax/issues/20864. - - @custom_partitioning - def f(x): - return jnp.sum(x) - - def partition(mesh, arg_shapes, result_shape): - def lower_fn(xs): - def f(carry, x): - return carry + jax.lax.psum(jnp.sum(x), axis_name='x'), None - - carry, _ = jax.lax.scan(f, 0, xs) - return carry - - result_shardings = jax.tree.map(lambda x: x.sharding, result_shape) - arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes) - return mesh, lower_fn, result_shardings, arg_shardings - - f.def_partition( - partition, - infer_sharding_from_operands=lambda mesh, *_: NamedSharding(mesh, P()), - propagate_user_sharding=lambda _, user_shape: user_shape.sharding) - - pjit_f = pjit(f, in_shardings=P(None, 'x')) - xs = jnp.ones([32, 16]) - self.assertEqual(pjit_f(xs), xs.sum()) + jit_f = jax.jit(f, in_shardings=s, out_shardings=s) + self.assertArraysEqual(x, jit_f(x)) @jtu.pytest_mark_if_available('multiaccelerator') @@ -1565,7 +1588,9 @@ class AutoShardingPjitTest(jtu.JaxTestCase): ) def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, mesh_axis_names): - global_mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') + global_mesh = jtu.create_mesh(mesh_shape, mesh_axis_names) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1581,7 +1606,9 @@ def test_pjit_arr_auto_sharding_array(self, mesh_shape, global_input_shape, self.assertArraysEqual(out._value, input_data) def test_xla_arr_sharding_mismatch(self): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') + global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) global_input_shape = (6, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1608,7 +1635,9 @@ def test_xla_arr_sharding_mismatch(self): compiled(arr) def test_gda_auto_shardings_len(self): - global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + if 
config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') + global_mesh = jtu.create_mesh((2, 2), ('x', 'y')) global_input_shape = (4, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1628,7 +1657,9 @@ def test_gda_auto_shardings_len(self): ) def test_jit_arr_partial_auto_sharding_array( self, mesh_shape, mesh_axis_names, pspec): - mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') + mesh = jtu.create_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1649,7 +1680,7 @@ def test_jit_arr_partial_auto_sharding_array( self.assertArraysEqual(o._value, input_data) def test_jit_different_mesh_in_auto(self): - mesh1 = jtu.create_global_mesh((4,), ('x',)) + mesh1 = jtu.create_mesh((4,), ('x',)) dev = jax.devices() mesh2 = jax.sharding.Mesh([dev[0], dev[3], dev[2], dev[1]], 'x') f = jax.jit(lambda x, y: (x, y), @@ -1668,8 +1699,10 @@ def test_jit_auto_sharding_partial_tuple_input_shardings( self, mesh_shape, mesh_axis_names): if not jtu.test_device_matches(["tpu"]): self.skipTest('Parameters are tupled only on TPU if >2000 parameters') + if config.use_shardy_partitioner.value: + self.skipTest('Must register auto partitioner for Shardy') - mesh = jtu.create_global_mesh(mesh_shape, mesh_axis_names) + mesh = jtu.create_mesh(mesh_shape, mesh_axis_names) global_input_shape = (8, 4) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1698,7 +1731,7 @@ def test_jit_auto_sharding_partial_tuple_input_shardings( @unittest.skip('The error is not raised yet. 
Enable this back once we raise ' 'the error in pjit again.') def test_pjit_array_error(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) global_input_shape = (8, 2) input_data = np.arange( math.prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) @@ -1728,7 +1761,7 @@ class ArrayPjitTest(jtu.JaxTestCase): ) def test_pjit_array_single_output(self, out_axis_resources, shard_shape): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1754,7 +1787,7 @@ def test_pjit_array_single_output(self, out_axis_resources, shard_shape): def test_pjit_array_single_output_with_mesh_context_manager( self, out_axis_resources, shard_shape): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1775,7 +1808,7 @@ def test_pjit_array_single_output_with_mesh_context_manager( def test_numpy_array_input_assume_fully_replicated(self): input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_data = np.arange( math.prod(input_shape)).reshape(input_shape) @@ -1790,7 +1823,7 @@ def test_numpy_array_input_assume_fully_replicated(self): def test_numpy_array_input(self): input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_data = np.arange( math.prod(input_shape), dtype=np.float32).reshape(input_shape) with global_mesh: @@ -1819,7 +1852,7 @@ def _checks(out, input_data): self.assertArraysEqual(out._value, input_data) global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1839,7 +1872,12 @@ def _checks(out, input_data): ) def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape, s2_shape, s3_shape, s4_shape): - global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y')) + if config.use_shardy_partitioner.value: + self.skipTest( + 'TODO(b/355263220) Shardy conflict resolution is not complete. 
Issue ' + 'here is that for `a1 @ a1.T` GSPMD gives dim 0 sharded on `x` while ' + 'Shardy gives it fully replicated.') + global_mesh = jtu.create_mesh(mesh_shape, ('x', 'y')) global_input_shape = (8, 2) spec1 = P('x', 'y') @@ -1883,8 +1921,8 @@ def f(tree): self.assertArraysEqual(s.data, input_data) def test_sds_full_like(self): - # https://github.com/google/jax/issues/20390 - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + # https://github.com/jax-ml/jax/issues/20390 + mesh = jtu.create_mesh((2, 1), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) x = jax.ShapeDtypeStruct((4, 4), jnp.float32, sharding=s) y = jnp.zeros_like(x) @@ -1896,7 +1934,7 @@ def test_sds_full_like(self): def test_in_axis_resources_mismatch_error(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1912,7 +1950,7 @@ def test_in_axis_resources_mismatch_error(self): def test_in_axis_resources_same_as_array_sharding(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, _ = create_array(global_input_shape, global_mesh, mesh_axes) @@ -1930,11 +1968,11 @@ def f(): def test_array_device_assignment_mismatch_with_mesh(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) mesh_axes = P('x', 'y') input_array, _ = create_array( - global_input_shape, jtu.create_global_mesh((2, 2), ('x', 'y')), + global_input_shape, jtu.create_mesh((2, 2), ('x', 'y')), mesh_axes) with global_mesh: @@ -1944,7 +1982,7 @@ def test_array_device_assignment_mismatch_with_mesh(self): def test_array_lower_compile(self): global_input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) a1, input_data = create_array(global_input_shape, global_mesh, P('x', 'y')) a2, _ = create_array(global_input_shape, global_mesh, P('x')) @@ -1998,7 +2036,7 @@ def make_keys(seeds): def test_globally_sharded_key_array_8x4_multi_device_with_out_sharding(self): input_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -2015,7 +2053,7 @@ def make_keys(seeds): def test_globally_sharded_key_array_8x4_multi_device(self): input_shape = (8, 4) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -2032,8 +2070,8 @@ def make_keys(seeds): def test_array_device_assignment_mismatch_out_shardings(self): input_shape = (8, 2) - m1 = jtu.create_global_mesh((4, 2), ('x', 'y')) - m2 = jtu.create_global_mesh((2, 2), ('x', 'y')) + m1 = jtu.create_mesh((4, 2), ('x', 'y')) + m2 = jtu.create_mesh((2, 2), ('x', 'y')) spec = P('x', 'y') a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape) @@ -2047,8 +2085,8 @@ def test_array_device_assignment_mismatch_out_shardings(self): def test_array_device_assignment_mismatch_in_and_out_shardings(self): input_shape = (8, 2) - m1 = jtu.create_global_mesh((4, 2), ('x', 'y')) - m2 = jtu.create_global_mesh((2, 2), ('x', 'y')) + m1 = jtu.create_mesh((4, 2), ('x', 'y')) + m2 = jtu.create_mesh((2, 2), 
('x', 'y')) spec = P('x', 'y') a1 = jnp.arange(math.prod(input_shape)).reshape(input_shape) @@ -2064,7 +2102,7 @@ def test_array_device_assignment_mismatch_in_and_out_shardings(self): def test_mixed_inputs(self): input_shape = (8, 2) - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') a1, input_data = create_array(input_shape, global_mesh, spec) @@ -2079,7 +2117,7 @@ def test_mixed_inputs(self): f(input_data, a1) def test_pjit_array_same_sharding_aot(self): - global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + global_mesh = jtu.create_mesh((4, 2), ('x', 'y')) input_shape = (8, 2) a1, _ = create_array(input_shape, global_mesh, P(None,)) with global_mesh: @@ -2225,7 +2263,7 @@ def test_array_enabled_non_empty_mesh_with_pspec(self): def test_pjit_uncommitted_array_reshard(self): arr = jnp.array([[1, 2, 3]]) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) with mesh: out = pjit(lambda x: x)(arr) self.assertArraysEqual(out, arr) @@ -2233,7 +2271,7 @@ def test_pjit_uncommitted_array_reshard(self): def test_pjit_uncommitted_array_in_axis_resources_reshard(self): arr = jnp.arange(16).reshape(8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) with mesh: out = pjit(lambda x: x, in_shardings=P('x', 'y'))(arr) self.assertArraysEqual(out, arr) @@ -2245,7 +2283,7 @@ def test_pjit_uncommitted_array_in_axis_resources_reshard(self): def test_pjit_uncommitted_array_and_committed_array(self): shape = (8, 2) uarr = jnp.arange(math.prod(shape), dtype=np.float32).reshape(shape) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) carr, inp_data = create_array(shape, mesh, P('x', 'y')) with mesh: out1, out2 = pjit(lambda x, y: (x, y))(uarr, carr) @@ -2258,7 +2296,7 @@ def test_pjit_uncommitted_array_and_committed_array(self): self.assertEqual(mul_out.shape, (8, 8)) self.assertLen(mul_out.addressable_shards, 8) - with jtu.create_global_mesh((2, 2), ('x', 'y')): + with jtu.create_mesh((2, 2), ('x', 'y')): with self.assertRaisesRegex( ValueError, "Received incompatible devices for pjitted computation"): @@ -2266,7 +2304,7 @@ def test_pjit_uncommitted_array_and_committed_array(self): def test_pjit_uncommitted_array_multi_devices(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp = np.arange(math.prod(shape), dtype=np.int32).reshape(shape) arr = array.ArrayImpl( core.ShapedArray(shape, np.int32), NamedSharding(mesh, P(None)), @@ -2302,7 +2340,7 @@ def test_pjit_committed_array_different_devices_variadic_args(self): pjit(lambda *x: x)(a, b) def test_pjit_pytree_inp_device_assignment_mismatch(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) a = jax.device_put(np.array([1, 2, 3]), jax.devices()[0]) b = jax.device_put(np.array([4, 5, 6]), jax.devices()[1]) c = jax.device_put(np.arange(16).reshape(8, 2), @@ -2311,7 +2349,7 @@ def test_pjit_pytree_inp_device_assignment_mismatch(self): msg = ("Received incompatible devices for pjitted computation. 
Got " r"argument {} of.* with shape int.*\[3\] and device ids " r"\[0\].*and argument {} of.* with shape int.*\[8,2\] and " - r"device ids \[0, 1, 2, 3\].*") + r"device ids.*") with self.assertRaisesRegex( ValueError, msg.format(r'tuple_inp\[0\]', r'tuple_inp\[1\]\[0\]')): @@ -2325,7 +2363,7 @@ def test_pjit_pytree_inp_device_assignment_mismatch(self): def test_same_out_sharding_id(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) arr, inp_data = create_array(shape, mesh, P('x', 'y')) f = pjit(lambda x: x) @@ -2347,7 +2385,7 @@ def test_same_out_sharding_id(self): def test_out_sharding_indices_id_cache_hit(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) arr, _ = create_array(shape, mesh, P('x', 'y')) f = pjit(lambda x: x) @@ -2383,7 +2421,7 @@ def f(tree): @jax.enable_custom_prng() def test_device_put_sharding_prng(self): - mesh = jtu.create_global_mesh((8,), ('x',)) + mesh = jtu.create_mesh((8,), ('x',)) s = NamedSharding(mesh, P('x')) x = jax.random.split(jax.random.PRNGKey(0), len(jax.devices())) @@ -2401,6 +2439,10 @@ def test_device_put_sharding_prng(self): self.assertTrue(jax.dtypes.issubdtype(a.dtype, jax.dtypes.prng_key)) self.assertEqual(a.sharding, out_p.sharding) + if config.use_shardy_partitioner.value: + # OpSharding is not supported in shardy. + return + op = xc.OpSharding() op.type = xc.OpSharding.Type.OTHER op.tile_assignment_dimensions = [8] @@ -2411,7 +2453,7 @@ def test_device_put_sharding_prng(self): self.assertEqual(b.sharding, gs) def test_device_put_on_different_sharding(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) x = jnp.arange(8).reshape(4, 2) s1 = NamedSharding(mesh, P('x')) @@ -2423,7 +2465,7 @@ def test_device_put_on_different_sharding(self): self.assertEqual(b.sharding, s2) def test_with_sharding_constraint_jit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(jax.jit, static_argnums=(0, 1)) def sharded_zeros(shape, pspec): @@ -2437,7 +2479,7 @@ def sharded_zeros(shape, pspec): out_s._to_xla_hlo_sharding(out.ndim))) def test_with_sharding_constraint_pjit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(pjit, static_argnums=(0, 1)) def sharded_zeros(shape, pspec): @@ -2451,7 +2493,7 @@ def sharded_zeros(shape, pspec): out_s._to_xla_hlo_sharding(out.ndim))) def test_jit_with_sharding_constraint_committed_inp_error(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) @@ -2465,7 +2507,7 @@ def sharded_inp(inp): ValueError, "Received incompatible devices for jitted computation. Got argument " r"inp of.*sharded_inp with shape bfloat16\[8,2\] and device ids \[0\].*" - r"sharding_constraint inside jit with device ids \[0, 1, 2, 3\].*"): + r"sharding_constraint inside jit with device ids.*"): sharded_inp(committed_inp) @pjit @@ -2479,13 +2521,13 @@ def f(x, y, z): ValueError, "Received incompatible devices for pjitted computation. 
Got argument " r"inp1 of.*my_nested_pjit with shape bfloat16\[8,2\] and device ids \[0\].*" - r"pjit inside pjit with device ids \[0, 1, 2, 3\].*"): + r"pjit inside pjit with device ids.*"): my_nested_pjit(committed_inp, committed_inp, committed_inp) @jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument") def test_jit_device_with_sharding_constraint_error(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(jax.jit, static_argnums=(0, 1), device=jax.devices()[0]) def sharded_zeros(shape, pspec): @@ -2496,11 +2538,11 @@ def sharded_zeros(shape, pspec): ValueError, "Received incompatible devices for jitted computation. Got explicit " r"output sharding with device ids \[0\].*sharding_constraint inside " - r"jit with device ids \[0, 1, 2, 3\].*"): + r"jit with device ids.*"): sharded_zeros((4096, 3072), P('x', 'y')) def test_concurrent_pjit(self): - global_mesh = jtu.create_global_mesh((1,), ('x',)) + global_mesh = jtu.create_mesh((1,), ('x',)) sharding = NamedSharding(global_mesh, P('x',)) n = 10 with global_mesh: @@ -2525,7 +2567,7 @@ def _invoke_with_mesh_twice(arg_tuple): def test_trivial_computation(self): shape = (8, 2) - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp_data = np.arange(math.prod(shape)).reshape(shape) arr = jax.device_put(inp_data, s) @@ -2533,7 +2575,7 @@ def test_trivial_computation(self): self.assertArraysEqual(out, inp_data) def test_trivial_computation_with_sharded_const(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) const = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) with mesh: @@ -2542,17 +2584,17 @@ def test_trivial_computation_with_sharded_const(self): self.assertArraysEqual(out, np.arange(16).reshape(8, 2)) def test_trivial_computation_with_sharded_const_using_transposed_mesh(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) const = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y'))) - mesh2 = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh2 = jtu.create_mesh((1, 2), ('x', 'y')) with mesh2: out = pjit(lambda: const)() self.assertIsInstance(out, array.ArrayImpl) self.assertArraysEqual(out, np.arange(16).reshape(8, 2)) def test_trivial_computation_with_replicated_literal(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) with mesh: out = pjit(lambda: 1)() self.assertEqual(out.sharding, NamedSharding(mesh, P())) @@ -2561,7 +2603,7 @@ def test_trivial_computation_with_replicated_literal(self): def test_multi_device_pjit_mul(self): shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) inp_data = np.arange(math.prod(shape)).reshape(shape) arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y'))) arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y'))) @@ -2575,7 +2617,7 @@ def test_multi_device_pjit_mul(self): def test_single_device_pjit_cpp_dispatch(self): shape = (8, 2) - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) inp_data = np.arange(math.prod(shape)).reshape(shape) f = pjit(lambda x: x @ x.T, in_shardings=None, out_shardings=None) @@ -2601,7 +2643,7 @@ def test_single_device_add_single_compile(self): def 
test_global_array_to_host_local_array_already_host_local(self): inp_shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) pspec = P('x', 'y') arr, _ = create_array(inp_shape, mesh, pspec) @@ -2623,7 +2665,7 @@ def f(c, x): self.assertAllClose(exe(x), x + 1, check_dtypes=False) def test_vmap_of_jvp_pjit_no_axis_resources(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) pjit_inp1 = jax.device_put( jnp.arange(8.), jax.sharding.NamedSharding(mesh, P('x'))) pjit_inp2 = jax.device_put( @@ -2649,7 +2691,7 @@ def g_(x, n): self.assertArraysEqual(pjit_out2, jit_out2) def test_vmap_of_jvp_pjit_no_axis_resources_2d(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) f_inp = jnp.arange(8.).reshape(2, 2, 2) # g_inp is sharded with P(None, 'x') because f_inp is sharded with P('x') @@ -2886,7 +2928,7 @@ def f(x, y, z, a, b, c): # pylint: disable=unused-argument self.assertLen(compiled._executable.in_avals, 1) def test_pjit_relayout_multi_slice(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @jax.jit def mul(x): @@ -2908,7 +2950,7 @@ def _check(out, expected_device, expected_out): self.assertLen(out.sharding.device_set, 1) self.assertArraysEqual(out, expected_out @ expected_out.T) - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) with jtu.ignore_warning(category=DeprecationWarning, message="backend and device argument"): @@ -2943,7 +2985,7 @@ def _check(out, expected_device, expected_out): _check(out3, jax.devices()[1], y) def test_pjit_with_device_arg_input_from_another_pjit(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) inp = np.arange(8).reshape(4, 2) y = jax.device_put(inp, jax.sharding.NamedSharding(mesh, P('x', 'y'))) @@ -3023,7 +3065,7 @@ def test_pjit_device_backend_both_error(self): pjit(lambda x: x, device=jax.devices()[0], backend='cpu') def test_pjit_mesh_with_device_or_backend_error(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) with mesh: with self.assertRaisesRegex( ValueError, @@ -3074,7 +3116,7 @@ def test_pmap_sharding_input_to_pjit_single_device(self): self.assertLen(out.devices(), 1) def test_pmap_sharding_input_to_pjit_multi_device(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) @@ -3093,7 +3135,7 @@ def test_pmap_sharding_input_to_pjit_multi_device(self): out2.sharding._to_xla_hlo_sharding(inp2.ndim))) def test_pmap_sharding_input_pjit_in_axis_resources(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) pmap_out = jax.pmap(lambda x: x)(jnp.arange(jax.device_count())) self.assertIsInstance(pmap_out.sharding, jax.sharding.PmapSharding) @@ -3163,7 +3205,7 @@ def g(z): f(inp) # doesn't crash def test_pjit_sin_nested(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) @pjit def f(x): @@ -3176,7 +3218,7 @@ def f(x): self.assertLen(out.devices(), 8) def test_jit_with_mesh_context_manager(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) with self.assertRaisesRegex( RuntimeError, "jax.jit only 
supports `Sharding`s being passed to " @@ -3237,7 +3279,7 @@ def f(x): self.assertEqual(count[0], 1) def test_pjit_no_global_cache_hit_axis_resources(self): - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) s = NamedSharding(mesh, P('x')) inp = jnp.arange(8.0) @@ -3268,7 +3310,7 @@ def test_pjit_no_global_cache_hit_axis_resources(self): self.assertEqual(count[0], 1) def test_with_sharding_constraint_spmd_axis_name(self): - mesh = jtu.create_global_mesh((2, 2, 2), ('replica', 'data', 'mdl')) + mesh = jtu.create_mesh((2, 2, 2), ('replica', 'data', 'mdl')) shape = (8, 4, 2, 2) x = jnp.arange(math.prod(shape)).reshape(shape) @@ -3291,7 +3333,7 @@ def apply_with_scan(x): self.assertListEqual(ns2, [2, 2, 1, 1]) def test_device_put_sharding_nondivisible_sharding_error(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P('x')) x = jnp.ones((1,)) @@ -3406,7 +3448,9 @@ def g(x): jtu.check_grads(g, (arr,), order=2) def test_pjit_out_sharding_preserved(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3448,11 +3492,11 @@ def mul(x): cache_info4 = pxla._cached_compilation.cache_info() self.assertIsInstance(out4.sharding, PositionalSharding) - self.assertEqual(cache_info4.hits, cache_info3.hits) - self.assertEqual(cache_info4.misses, cache_info3.misses + 1) + self.assertEqual(cache_info4.hits, cache_info3.hits + 1) + self.assertEqual(cache_info4.misses, cache_info3.misses) def test_cache_hit_pjit_lower_with_cpp_cache_miss(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) np_arr = np.arange(8, dtype=np.float32).reshape(8, 1) arr = jax.device_put(np_arr, ns) @@ -3478,13 +3522,15 @@ def mul(x): self.assertEqual(cache_info2.misses, cache_info1.misses) def test_list_in_pspec(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) with mesh: out = with_sharding_constraint(jnp.arange(8), P(['x'])) self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) def test_sharding_preserved_trivial(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3501,7 +3547,7 @@ def identity(x): self.assertIsInstance(out2.sharding, PositionalSharding) def test_sharding_preserved_aot(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3518,12 +3564,12 @@ def test_sharding_preserved_aot(self): self.assertIsInstance(out2.sharding, NamedSharding) def test_sharding_on_output_with_vmap(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) arr = jax.device_put( np.arange(16).reshape(8, 2), NamedSharding(mesh, P(None, 'x'))) - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: vf = jax.vmap(pjit(lambda x: x * 2, in_shardings=ns)) out 
= vf(arr) self.assertIsInstance(out.sharding, NamedSharding) @@ -3536,7 +3582,9 @@ def test_sharding_on_output_with_vmap(self): self.assertEqual(count[0], 1) def test_jit_mul_sum_sharding_preserved(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) ps = PositionalSharding(jax.devices()[:2]).reshape(2, 1) @@ -3560,8 +3608,8 @@ def test_jit_mul_sum_sharding_preserved(self): self.assertIsInstance(out3.sharding, PositionalSharding) self.assertEqual(count[0], 1) - self.assertEqual(cache_info2.hits, cache_info1.hits) - self.assertEqual(cache_info2.misses, cache_info1.misses + 1) + self.assertEqual(cache_info2.hits, cache_info1.hits + 1) + self.assertEqual(cache_info2.misses, cache_info1.misses) self.assertEqual(pl_cache_info2.hits, pl_cache_info1.hits) self.assertEqual(pl_cache_info2.misses, pl_cache_info1.misses + 1) @@ -3594,7 +3642,7 @@ def test_single_device_sharding_preserved(self): self.assertEqual(out4.devices(), {jax.devices()[1]}) def test_none_out_sharding(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) x = jnp.arange(8) with mesh: out = pjit(lambda x: x * 2, out_shardings=None)(x) @@ -3609,7 +3657,9 @@ def test_none_out_sharding(self): self.assertEqual(out2.sharding.spec, P()) def test_sharding_preserved_apply_primitive(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + if config.use_shardy_partitioner.value: + raise unittest.SkipTest("Shardy doesn't support PositionalSharding") + mesh = jtu.create_mesh((2, 1), ('x', 'y')) ns = NamedSharding(mesh, P('x')) arr = jax.device_put(np.arange(8).reshape(8, 1), ns) @@ -3632,14 +3682,14 @@ def test_sharding_preserved_apply_primitive(self): self.assertEqual(out4.devices(), {jax.devices()[1]}) def test_same_named_sharding_pspec_on_eager_ops(self): - mesh = jtu.create_global_mesh((1, 8, 1), ('x', 'y', 'z')) + mesh = jtu.create_mesh((1, 8, 1), ('x', 'y', 'z')) sharding = jax.sharding.NamedSharding(mesh, P('x', 'y', 'z')) x = jax.device_put(jnp.arange(32).reshape(1, -1, 1), sharding) y = x + 1 self.assertEqual(x.sharding, y.sharding) def test_different_named_sharding_object_replicated(self): - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) sharding = jax.sharding.NamedSharding(mesh, P('x')) x = jax.device_put(np.arange(16).reshape(8, 2), sharding) y = jnp.sum(x) @@ -3653,7 +3703,7 @@ def test_vmap_pjit_single_device(self): self.assertIsInstance(out.sharding, SingleDeviceSharding) def test_to_gspmd_sharding_cache_with_and_without_device(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = jnp.arange(4) def identity(x): @@ -3692,7 +3742,7 @@ def top(x): self.assertEqual(count[0], 1) def test_wsc_eager(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) inp = jax.device_put(np_inp, NamedSharding(mesh, P())) out = with_sharding_constraint(inp, NamedSharding(mesh, P('x'))) @@ -3702,14 +3752,14 @@ def test_wsc_eager(self): self.assertArraysEqual(s.data, np_inp[s.index]) def test_wsc_eager_no_resharding(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8) inp = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) out = with_sharding_constraint(inp, NamedSharding(mesh, P('x'))) 
self.assertEqual(id(out), id(inp)) def test_wsc_eager_different_order_devices(self): - mesh1 = jtu.create_global_mesh((2,), ('x',)) + mesh1 = jtu.create_mesh((2,), ('x',)) mesh2 = jax.sharding.Mesh([jax.devices()[1], jax.devices()[0]], 'x') inp = jax.device_put(np.arange(8), NamedSharding(mesh1, P())) with self.assertRaisesRegex( @@ -3751,7 +3801,7 @@ def test_shape_dtype_struct_as_const_error(self): jax.jit(lambda x: (x, const))(jnp.arange(8)) def test_jit_out_shardings_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -3760,7 +3810,7 @@ def test_jit_out_shardings_none(self): self.assertEqual(out.sharding, s) def test_jit_in_shardings_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -3773,8 +3823,22 @@ def test_jit_in_shardings_none(self): self.assertArraysEqual(out2, np_inp * 2) self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0])) + def test_device_put_in_jit_default_mem_kind_no_op(self): + mesh = jtu.create_mesh((2,), 'x') + np_inp = np.arange(8) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + + @jax.jit + def f(x): + y = x * 2 + return jax.device_put(y, NamedSharding(mesh, P())) + + lowered_text = f.lower(arr).as_text() + self.assertNotIn('@Sharding', lowered_text) + self.assertNotIn('@annotate_device_placement', lowered_text) + def test_jit_both_shardings_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -3788,7 +3852,7 @@ def test_jit_both_shardings_none(self): self.assertEqual(out2.sharding, SingleDeviceSharding(jax.devices()[0])) def test_jit_lower_shape_dtype_struct_sharding_none(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) lower_inp1 = jax.ShapeDtypeStruct((8, 2), np.int32, sharding=s) @@ -3824,21 +3888,8 @@ def f(inp): ' manager.*SingleDeviceSharding'): jax.jit(jax.vmap(f, spmd_axis_name='x'))(arr) - @jtu.skip_on_devices("tpu", "gpu") - def test_device_put_memory_kind_not_tpu_gpu(self): - @jax.jit - def f(x): - y = x * 2 - return jax.device_put(y, sharding_impls.TransferToMemoryKind('unpinned_host')) - - with self.assertRaisesRegex( - NotImplementedError, - 'Passing memory_kind to device_put via Shardings is not supported on' - ' platform.*'): - f(jnp.arange(8)) - def test_no_output_multiple_devices(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) @pjit def f(): @@ -3848,11 +3899,14 @@ def f(): f() # doesn't crash def test_lowering_cache_hit_different_devices(self): + if config.use_shardy_partitioner.value: + self.skipTest('b/358322664: different axis names results in ' + 'a cache miss with Shardy.') if jax.device_count() < 4: self.skipTest('Requires >=4 devices') mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x') - mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x') + mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'y') @jax.jit def f(x): @@ -3863,10 +3917,10 @@ def g(a): out_a = f(a) # lowering cached # same num_devices but different devices. 
- b = jax.device_put(out_a, NamedSharding(mesh2, P('x'))) + b = jax.device_put(out_a, NamedSharding(mesh2, P('y'))) f(b) # lowering cache *hit* - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: g(np.arange(8)) self.assertEqual(count[0], 1) @@ -3889,7 +3943,7 @@ def g(a): b = jax.device_put(out_a, NamedSharding(mesh2, P())) f(b) # lowering cache *miss* - with jtu.count_jit_and_pmap_compiles() as count: + with jtu.count_jit_and_pmap_lowerings() as count: g(np.arange(8)) self.assertEqual(count[0], 2) @@ -3926,7 +3980,7 @@ def test_mpmd_device_put_fast_path(self): def test_prng_sharding_propagation(self): input_shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -3945,11 +3999,14 @@ def make_keys(seeds): self.assertEqual(base_array.sharding, NamedSharding(mesh, P('y', 'x', None))) lowered_text = make_keys.lower(seeds).as_text() - self.assertIn('unspecified_dims=[0,1]', lowered_text) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text) + else: + self.assertIn('unspecified_dims=[0,1]', lowered_text) def test_prng_sharding_propagation_with_nested_jit(self): input_shape = (8, 2) - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) spec = P('x', 'y') seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -3971,11 +4028,14 @@ def f(): self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', None))) lowered_text = make_keys.lower(seeds).as_text() - self.assertIn('unspecified_dims=[0,1]', lowered_text) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {}]>', lowered_text) + else: + self.assertIn('unspecified_dims=[0,1]', lowered_text) def test_partial_sharded_prng_key_inp(self): input_shape = (8, 2, 2) - mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) spec = P('x', 'y', None) seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32) @@ -3995,11 +4055,14 @@ def make_keys(seeds): self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x'))) lowered_text = make_keys.lower(seeds).as_text() - self.assertIn('unspecified_dims=[0,1,2]', lowered_text) + if config.use_shardy_partitioner.value: + self.assertIn('<@mesh, [{?}, {?}, {?}, {}]>', lowered_text) + else: + self.assertIn('unspecified_dims=[0,1,2]', lowered_text) def test_jit_partially_specified_shardings(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) s2 = NamedSharding(mesh, P('x')) @@ -4019,7 +4082,7 @@ def f(x, y, z, a, b): self.assertArraysEqual(out5, np_inp.T) def test_input_shardings_aot(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) @@ -4035,7 +4098,7 @@ def test_parameter_tupled_jit(self): if not jtu.test_device_matches(["tpu"]): self.skipTest('Parameters are tupled only on TPU if >2000 parameters') - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x')) @jax.jit @@ -4048,7 +4111,9 @@ def f(*args): f(inps) # doesn't crash def 
test_spmd_preserves_input_sharding_vmap_grad(self): - # https://github.com/google/jax/issues/20710 + if config.use_shardy_partitioner.value: + self.skipTest("Shardy doesn't support PositionalSharding") + # https://github.com/jax-ml/jax/issues/20710 n_devices = jax.device_count() sharding = PositionalSharding(jax.devices()) @@ -4084,7 +4149,7 @@ def test_jit_token_input(self): self.assertIsInstance(out2, core.Token) def test_uneven_sharding_wsc(self): - mesh = jtu.create_global_mesh( + mesh = jtu.create_mesh( (2, 1, 1, 1, 1), ('data', 'expert', 'fsdp', 'seq', 'model') ) @@ -4131,7 +4196,7 @@ def get_wsc_eqn_sharding(jaxpr): for s in core.subjaxprs(jaxpr): return get_wsc_eqn_sharding(s) - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) inp = jnp.ones((10, 10)) def a_function(x): @@ -4211,6 +4276,9 @@ def f(x): self.assertArraysEqual(out2, np.arange(8) * 2) def test_device_put_efficient_reshard_single_host(self): + if config.use_shardy_partitioner.value: + self.skipTest( + '_different_device_order_reshard is creating a GSPMDSharding') if jax.device_count() < 4: self.skipTest('Requires >= 4 devices') @@ -4235,6 +4303,9 @@ def test_device_put_efficient_reshard_single_host(self): ("8_384", (8, 384)), ) def test_device_put_efficient_reshard_complex_mesh(self, shape): + if config.use_shardy_partitioner.value: + self.skipTest( + '_different_device_order_reshard is creating a GSPMDSharding') if jax.device_count() < 8: self.skipTest('Requires >= 8 devices') @@ -4268,7 +4339,7 @@ def test_device_put_efficient_reshard_complex_mesh(self, shape): self.assertEqual(out2.sharding, s1) def test_convert_element_type_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4279,7 +4350,9 @@ def test_convert_element_type_sharding(self): self.assertEqual(out.sharding, s) def test_jnp_array_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + mesh = jax.make_mesh((2, 2), ('x', 'y'), devices=jax.devices()[:4]) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4288,7 +4361,9 @@ def test_jnp_array_sharding(self): self.assertEqual(out.sharding, s) def test_jnp_array_inside_jit_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + mesh = jax.make_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = np.arange(16).reshape(8, 2) @@ -4328,7 +4403,7 @@ def test_jnp_array_sharded_array_no_op(self): self.assertEqual(out.unsafe_buffer_pointer(), arr.unsafe_buffer_pointer()) def test_wsc_named_sharding_nullary(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) s = NamedSharding(mesh, P()) @jax.jit @@ -4340,7 +4415,7 @@ def f(): @jtu.run_on_devices('tpu', 'gpu') def test_aot_device_mismatch(self): - mesh = jtu.create_global_mesh((1,), 'x') + mesh = jtu.create_mesh((1,), 'x') np_inp = np.arange(8) arr = jax.device_put(np_inp, NamedSharding(mesh, P())) @@ -4356,11 +4431,153 @@ def f(x): "Compiled object called with input sharding.*does not match"): compiled(cpu_arr) + def test_different_devices_wsc_abstract_mesh_cache_hit(self): + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + + mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'x') + mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'x') + + 
@jax.jit + def f(x): + x = with_sharding_constraint( + x, NamedSharding(mesh_lib.AbstractMesh(mesh1.shape_tuple), P('x'))) + return jnp.sin(x) + + with ( + jtu.count_jit_tracing_cache_miss() as tracing_count, + jtu.count_jit_and_pmap_lowerings() as lowering_count, + jtu.count_jit_compilation_cache_miss() as compilation_count, + ): + a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P())) + out_a = f(a) # tracing and lowering cached + + # same num_devices but different devices. + b = jax.device_put(out_a, NamedSharding(mesh2, P())) + f(b) # tracing and lowering cache *hit* + self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin` + self.assertEqual(lowering_count[0], 1) + self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. + + def test_wsc_abstract_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + + def f(x): + x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x'))) + return x * 2 + + out = jax.jit(f)(arr) + self.assertArraysEqual(out, np_inp * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + out_eager = f(arr) + self.assertArraysEqual(out_eager, np_inp * 2) + self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x'))) + + def test_wsc_sds_abstract_mesh(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P()) + abstract_mesh = mesh_lib.AbstractMesh(mesh.shape_tuple) + + @jax.jit + def f(x): + x = with_sharding_constraint(x, NamedSharding(abstract_mesh, P('x'))) + return x * 2 + + sds = jax.ShapeDtypeStruct((8, 2), np.float32, sharding=s) + f.eval_shape(sds) # doesn't crash + + def test_wsc_vmap_abstract_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + s = NamedSharding(mesh, P('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, s) + + def f(x): + x = with_sharding_constraint(x, NamedSharding(mesh.abstract_mesh, P('x'))) + return x * 2 + + out = jax.jit(jax.vmap(f))(arr) # doesn't crash + self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'x'))) + + out2 = jax.jit(jax.vmap(f, spmd_axis_name='y'))(arr) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('y', 'x'))) + + def test_wsc_abstract_mesh_errors(self): + mesh = jtu.create_mesh((2,), ('x',)) + np_inp = np.arange(8) + abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + s_abs = NamedSharding(abstract_mesh, P('x')) + + with self.assertRaisesRegex( + ValueError, ".*requires the input passed should be a `jax.Array`.*"): + with_sharding_constraint(np_inp, s_abs) + + with self.assertRaisesRegex( + TypeError, "The sharding on the input must be a `NamedSharding`"): + with_sharding_constraint(jnp.arange(8), s_abs) + + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + abs_mesh2 = mesh_lib.AbstractMesh( + jtu.create_mesh((2,), 'y').shape_tuple) + with self.assertRaisesRegex( + ValueError, + 'Mesh shape of the input.*does not' + ' match the mesh shape of the target sharding.*'): + with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) + + @unittest.skipIf(xla_extension_version < 286, + "Requires xla_extension_version >= 286") + def test_global_jit_cpp_cache_hit_out_shardings(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P('x')) + + def f(x): + return x * 2 + + with jtu.count_pjit_cpp_cache_miss() as count: + jax.jit(f, out_shardings=s)(np.arange(8)) + jax.jit(f, 
out_shardings=s)(np.arange(8)) + self.assertEqual(count[0], 1) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)") +class ShardingInTypesTest(jtu.JaxTestCase): + + @config.sharding_in_types(True) + def test_basic_mul(self): + mesh = jtu.create_mesh((4, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + s = NamedSharding(mesh, P('x', 'y')) + arr = jax.device_put(np_inp, s) + + @jax.jit + def f(x): + self.assertEqual(x.sharding.spec, s.spec) + x = x * 2 + self.assertEqual(x.sharding.spec, s.spec) + x = x * x + self.assertEqual(x.sharding.spec, s.spec) + return x + + out = f(arr) + self.assertEqual(out.sharding, s) + self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2)) + + lowered_text = f.lower(arr).as_text() + if config.use_shardy_partitioner.value: + self.assertIn('sdy.sharding_constraint', lowered_text) + else: + self.assertEqual(lowered_text.count('@Sharding'), 2) + + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase): @@ -4397,7 +4614,7 @@ def testUndefinedResourcesArgs(self, mesh, resources): spec = P(resources,) with self.assertRaisesRegex( ValueError, - r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"): + r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."): pjit(lambda x: x, in_shardings=spec, out_shardings=None)(x) @check_1d_2d_mesh(set_mesh=False) @@ -4407,7 +4624,7 @@ def testUndefinedResourcesOuts(self, mesh, resources): spec = P(resources,) with self.assertRaisesRegex( ValueError, - r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"): + r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."): pjit(lambda x: x, in_shardings=None, out_shardings=spec)(x) @check_1d_2d_mesh(set_mesh=False) @@ -4417,7 +4634,7 @@ def testUndefinedResourcesConstraint(self, mesh, resources): spec = P(resources,) with self.assertRaisesRegex( ValueError, - r"Resource axis: x of.*" + spec_regex(spec) + " is undefined"): + r"Resource axis: x of.*" + spec_regex(spec) + r" is not found in mesh: \(.*\)."): pjit( lambda x: with_sharding_constraint(x, spec), in_shardings=None, @@ -4568,7 +4785,7 @@ def h(x): ) def test_pjit_with_deleted_input_at_first_call(self, committed): shape = (8,) - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) inp_data = np.arange(math.prod(shape)).reshape(shape) if committed: s = NamedSharding(mesh, P('x',)) @@ -4586,7 +4803,7 @@ def test_pjit_with_deleted_input_at_first_call(self, committed): ) def test_pjit_with_deleted_input_at_subsequent_call(self, committed): shape = (8,) - mesh = jtu.create_global_mesh((1,), ('x',)) + mesh = jtu.create_mesh((1,), ('x',)) inp_data = np.arange(math.prod(shape)).reshape(shape) if committed: s = NamedSharding(mesh, P('x',)) @@ -4621,7 +4838,7 @@ def f(x, y): g(x, y2) def test_dce_no_array(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) arr = jax.device_put(np.arange(8.), NamedSharding(mesh, P('x'))) @jax.jit @@ -4966,7 +5183,7 @@ def test_hlo_sharding_manual_replicated(self): def test_op_sharding_cache_on_mesh_pspec_sharding(self): ndim = 2 - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) mps1 = NamedSharding(mesh, P('x', 'y')) op1 = mps1._to_xla_hlo_sharding(ndim) cache_info1 = sharding_impls.named_sharding_to_xla_hlo_sharding.cache_info() @@ -4981,7 +5198,7 @@ def test_op_sharding_cache_on_mesh_pspec_sharding(self): self.assertEqual(cache_info2.currsize, cache_info1.currsize) 
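# --- Illustrative sketch (not part of the patch): the recurring change across
# these test files is the migration from jtu.create_global_mesh to
# jtu.create_mesh / jax.make_mesh together with NamedSharding-based assertions.
# The snippet below shows that usage pattern; it assumes a runtime with at
# least four devices and reuses only calls that already appear in the hunks above.
import numpy as np
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# Build a 2x2 device mesh with the helper these tests are being migrated to.
mesh = jax.make_mesh((2, 2), ('x', 'y'))
# Shard an 8x2 array across ('x', 'y') and run it through jit.
arr = jax.device_put(np.arange(16).reshape(8, 2), NamedSharding(mesh, P('x', 'y')))
out = jax.jit(lambda x: x * 2)(arr)
# As in test_jit_out_shardings_none above, the output keeps the input sharding.
assert out.sharding == NamedSharding(mesh, P('x', 'y'))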
def test_get_partition_spec(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y', None)) self.assertEqual(s._parsed_pspec.get_partition_spec(), P('x', 'y', None)) @@ -4991,11 +5208,6 @@ def test_get_partition_spec(self): self.assertEqual(recovered_parsed_pspec[0].get_partition_spec(), P('x', 'y')) - out_of_sync_parsed_pspec = sharding_impls.ParsedPartitionSpec( - P('x', 'y'), ('x', 'y'), sharding_impls.SpecSync.OUT_OF_SYNC) - self.assertEqual(out_of_sync_parsed_pspec.get_partition_spec(), - P('x', 'y')) - def test_mesh_with_list_devices(self): mesh = jax.sharding.Mesh(jax.devices(), ('x',)) self.assertIsInstance(mesh.devices, np.ndarray) @@ -5005,6 +5217,24 @@ def test_mesh_with_string_axis_names(self): mesh = jax.sharding.Mesh(jax.devices(), 'dp') self.assertTupleEqual(mesh.axis_names, ('dp',)) + def test_sharded_in_place_assignment(self): + mesh = jtu.create_mesh((8,), ('data',)) + + idx = [0, 2, 5, 7, 8, 10, 13, 15] + n = 16 + def _init(): + w = jnp.zeros((n, n)) + idx1 = jnp.array(idx) + w = w.at[idx1, jnp.arange(n//2)].set(1) + return w + + w = jax.jit(_init, out_shardings=NamedSharding(mesh, P(None, 'data')))() + + w_gt = np.zeros((n, n)) + for j, i in enumerate(idx): + w_gt[i, j] = 1 + + self.assertArraysEqual(w, w_gt) @jtu.with_config(jax_use_shardy_partitioner=True) class SdyIntegrationTest(jtu.JaxTestCase): @@ -5015,7 +5245,7 @@ def setUp(self): raise unittest.SkipTest('Shardy is not available.') def test_lowering_input_output_sharding(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = jax.sharding.NamedSharding(mesh, P('x', 'y')) arr = jax.device_put(np_inp, s) @@ -5027,7 +5257,7 @@ def f(x): self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text()) def test_lowering_with_sharding_constraint(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) arr = np.arange(16).reshape(4, 2, 2) @jax.jit @@ -5039,7 +5269,7 @@ def f(x): self.assertIn('<@mesh, [{"x"}, {}, {"y"}]>', lowered_str) def test_lowering_with_sharding_constraint_unconstrained(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) arr = np.arange(16).reshape(4, 2, 2) @jax.jit @@ -5053,11 +5283,11 @@ def f(x): # TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline. 
@jtu.skip_on_devices('cpu') def test_compile_with_inferred_out_sharding(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jax.device_put(np.arange(8 * 4).reshape(8, 4), - jax.sharding.NamedSharding(mesh, P('x', 'y'))) + NamedSharding(mesh, P('x', 'y'))) y = jax.device_put(np.arange(4 * 16).reshape(4, 16), - jax.sharding.NamedSharding(mesh, P('y'))) + NamedSharding(mesh, P('y'))) @jax.jit def f(x, y): @@ -5065,7 +5295,18 @@ def f(x, y): out = f(x, y) self.assertArraysEqual(out, x @ y) - self.assertEqual(out.sharding, jax.sharding.NamedSharding(mesh, P('x'))) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x'))) + + def test_fully_automatic_sharding(self): + mesh = jtu.create_mesh((8,), ('x',)) + x = jax.ShapeDtypeStruct((128, 128), jnp.float32) + + @jax.jit + def f(x, y): + return x @ y + + lowered_str = jax.jit(f, in_shardings=[AUTO(mesh), AUTO(mesh)]).lower(x, x).as_text() + self.assertIn('sdy.mesh @mesh = <["x"=8]>', lowered_str) if __name__ == '__main__': diff --git a/tests/pmap_test.py b/tests/pmap_test.py index e576fae91a83..9a8d0b91272b 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -122,7 +122,7 @@ def pmap(self): def testDeviceBufferToArray(self): sda = self.pmap(lambda x: x)(jnp.ones((jax.device_count(), 2))) - # Changed in https://github.com/google/jax/pull/10584 not to access + # Changed in https://github.com/jax-ml/jax/pull/10584 not to access # sda.device_buffers, which isn't supported, and instead ensure fast slices # of the arrays returned by pmap are set up correctly. # buf = sda.device_buffers[-1] @@ -336,7 +336,7 @@ def test_jit_lower_compile_with_compiler_options_invalid(self): compiler_options={"xla_embed_ir_in_executable": "invalid_value"})) def test_pmap_replicated_copy(self): - # https://github.com/google/jax/issues/17690 + # https://github.com/jax-ml/jax/issues/17690 inp = jnp.arange(jax.device_count()) x = jax.pmap(lambda x: x, in_axes=0, out_axes=None)(inp) out = jnp.copy(x) @@ -605,7 +605,7 @@ def f(x): self.assertAllClose(y, ref) def testNestedPmapAxisSwap(self): - # Regression test for https://github.com/google/jax/issues/5757 + # Regression test for https://github.com/jax-ml/jax/issues/5757 if jax.device_count() < 8: raise SkipTest("test requires at least 8 devices") f = jax.pmap(jax.pmap(lambda x: x, in_axes=1, out_axes=0), in_axes=0, @@ -1180,7 +1180,7 @@ def testPShuffleWithBadPerm(self): "`perm` does not represent a permutation: \\[1.*\\]", g) def testPpermuteWithZipObject(self): - # https://github.com/google/jax/issues/1703 + # https://github.com/jax-ml/jax/issues/1703 num_devices = jax.device_count() perm = [num_devices - 1] + list(range(num_devices - 1)) f = self.pmap(lambda x: lax.ppermute(x, "i", zip(perm, range(num_devices))), "i") @@ -1285,7 +1285,7 @@ def testPmapConstant(self): device_count = jax.device_count() f = self.pmap(lambda x: 3) x = jnp.arange(device_count) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): fix this expected = np.repeat(3, device_count) @@ -1306,7 +1306,7 @@ def testPmapConstantDevices(self): shuffle(devices) f = self.pmap(lambda x: 3, devices=devices) x = jnp.arange(len(devices)) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants expected = 
np.repeat(3, len(devices)) @@ -1342,7 +1342,7 @@ def testNestedPmapConstant(self): f = self.pmap(self.pmap(lambda x: 3)) shape = (2, jax.device_count() // 2, 3) x = jnp.arange(math.prod(shape)).reshape(shape) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants expected = 3 * np.ones(shape[:2]) @@ -1368,7 +1368,7 @@ def testNestedPmapConstantDevices(self): f = self.pmap(self.pmap(lambda x: 3), devices=devices) shape = (2, len(devices) // 2, 3) x = jnp.arange(math.prod(shape)).reshape(shape) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 ans = f(x) # self.assertEqual(count[0], 0) # TODO(mattjj): don't compile for constants expected = 3 * np.ones(shape[:2]) @@ -1501,7 +1501,7 @@ def s(keys): self.assertEqual(ans.shape, (13, N_DEVICES)) def testVmapOfPmap3(self): - # https://github.com/google/jax/issues/3399 + # https://github.com/jax-ml/jax/issues/3399 device_count = jax.device_count() if device_count < 2: raise SkipTest("test requires at least two devices") @@ -1661,7 +1661,7 @@ def g(z): @ignore_jit_of_pmap_warning() def testIssue1065(self): - # from https://github.com/google/jax/issues/1065 + # from https://github.com/jax-ml/jax/issues/1065 device_count = jax.device_count() def multi_step_pmap(state, count): @@ -1697,7 +1697,7 @@ def testArrayGetItem(self): # replica. @unittest.skip("need eager multi-replica support") def testPostProcessMap(self): - # test came from https://github.com/google/jax/issues/1369 + # test came from https://github.com/jax-ml/jax/issues/1369 nrep = jax.device_count() def pmvm(a, b): @@ -1730,7 +1730,7 @@ def f(args_list): @jax.default_matmul_precision("float32") def testPostProcessMap2(self): - # code from https://github.com/google/jax/issues/2787 + # code from https://github.com/jax-ml/jax/issues/2787 def vv(x, y): """Vector-vector multiply""" return jnp.dot(x, y) @@ -1758,7 +1758,7 @@ def distributed_matrix_vector(x, y): ('_new', new_checkpoint), ]) def testAxisIndexRemat(self, remat): - # https://github.com/google/jax/issues/2716 + # https://github.com/jax-ml/jax/issues/2716 n = len(jax.devices()) def f(key): @@ -1769,7 +1769,7 @@ def f(key): self.pmap(remat(f), axis_name='i')(keys) def testPmapMapVmapCombinations(self): - # https://github.com/google/jax/issues/2822 + # https://github.com/jax-ml/jax/issues/2822 def vv(x, y): """Vector-vector multiply""" return jnp.dot(x, y) @@ -1802,7 +1802,7 @@ def matrix_vector(x, y, parallel=True): self.assertAllClose(result1, result4, check_dtypes=False, atol=1e-3, rtol=1e-3) def testPmapAxisNameError(self): - # https://github.com/google/jax/issues/3120 + # https://github.com/jax-ml/jax/issues/3120 a = np.arange(4)[np.newaxis,:] def test(x): return jax.lax.psum(x, axis_name='batch') @@ -1811,7 +1811,7 @@ def test(x): self.pmap(test)(a) def testPsumOnBooleanDtype(self): - # https://github.com/google/jax/issues/3123 + # https://github.com/jax-ml/jax/issues/3123 n = jax.device_count() if n > 1: x = jnp.array([True, False]) @@ -1889,7 +1889,7 @@ def foo(x): return x + x self.assertIn("mhlo.num_partitions = 1", hlo) def testPsumZeroCotangents(self): - # https://github.com/google/jax/issues/3651 + # https://github.com/jax-ml/jax/issues/3651 def loss(params, meta_params): (net, mpo) = params return meta_params * mpo * net @@ -1914,7 +1914,7 @@ def outer(params): 
@ignore_jit_of_pmap_warning() def test_issue_1062(self): - # code from https://github.com/google/jax/issues/1062 @shoyer + # code from https://github.com/jax-ml/jax/issues/1062 @shoyer # this tests, among other things, whether ShardedDeviceTuple constants work device_count = jax.device_count() @@ -1938,7 +1938,7 @@ def test_replicate_backend(self): # TODO(skye): fix backend caching so we always have multiple CPUs available if jax.device_count("cpu") < 4: self.skipTest("test requires 4 CPU device") - # https://github.com/google/jax/issues/4223 + # https://github.com/jax-ml/jax/issues/4223 def fn(indices): return jnp.equal(indices, jnp.arange(3)).astype(jnp.float32) mapped_fn = self.pmap(fn, axis_name='i', backend='cpu') @@ -1982,7 +1982,7 @@ def testArgAllReduce(self, shape, dtype, axis, collective, bulk_op): for dtype in [np.float32, np.int32] ) def testPmapDtype(self, dtype): - # Regression test for https://github.com/google/jax/issues/6022 + # Regression test for https://github.com/jax-ml/jax/issues/6022 @partial(self.pmap, axis_name='i') def func(_): return jax.lax.psum(dtype(0), axis_name='i') @@ -1991,7 +1991,7 @@ def func(_): self.assertEqual(out_dtype, dtype) def test_num_replicas_with_switch(self): - # https://github.com/google/jax/issues/7411 + # https://github.com/jax-ml/jax/issues/7411 def identity(x): return x @@ -2039,7 +2039,7 @@ def f(x): _, f_bwd = jax.vjp(f, x) _ = f_bwd(x) - with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841 + with jtu.count_jit_and_pmap_lowerings() as count: # noqa: F841 _, f_bwd2 = jax.vjp(f, x) _ = f_bwd(x) _ = f_bwd2(x) @@ -2154,7 +2154,7 @@ def test_axis_name_shadowing_with_vmap(self): @jtu.run_on_devices("cpu") def test_pmap_stack_size(self): - # Regression test for https://github.com/google/jax/issues/20428 + # Regression test for https://github.com/jax-ml/jax/issues/20428 # pmap isn't particularly important here, but it guarantees that the CPU # client runs the computation on a threadpool rather than inline. 
if jax.device_count() < 2: @@ -2164,7 +2164,7 @@ def test_pmap_stack_size(self): y.block_until_ready() # doesn't crash def test_pmap_of_prng_key(self): - # Regression test for https://github.com/google/jax/issues/20392 + # Regression test for https://github.com/jax-ml/jax/issues/20392 keys = jax.random.split(jax.random.key(0), jax.device_count()) result1 = jax.pmap(jax.random.bits)(keys) with jtu.ignore_warning( @@ -3015,7 +3015,7 @@ def testShardArgs(self, shape, spec, make_arg): x = np.arange(math.prod(shape)).reshape(shape) arg = make_arg(x) sharding = jax.sharding.PmapSharding(jax.devices()[:nshards], spec) - results = pxla.shard_args([sharding], [arg]) + results = pxla.shard_args([sharding], [None], [arg]) self.assertEqual(len(results), 1) if isinstance(results[0], array.ArrayImpl): bufs = results[0]._arrays @@ -3209,8 +3209,12 @@ def setUp(self): self.jit_disabled = config.disable_jit.value config.update('jax_disable_jit', True) config.update('jax_eager_pmap', True) + self.warning_ctx = jtu.ignore_warning( + message="Some donated buffers were not usable", category=UserWarning) + self.warning_ctx.__enter__() def tearDown(self): + self.warning_ctx.__exit__(None, None, None) config.update('jax_eager_pmap', self.eager_pmap_enabled) config.update('jax_disable_jit', self.jit_disabled) super().tearDown() diff --git a/tests/python_callback_test.py b/tests/python_callback_test.py index 76854beaea3d..d9887cf7b482 100644 --- a/tests/python_callback_test.py +++ b/tests/python_callback_test.py @@ -1245,7 +1245,7 @@ def f(shard_ids, x): np.testing.assert_array_equal(shard[0] + 1, shard[1]) def test_batching_with_side_effects(self): - # https://github.com/google/jax/issues/20628#issuecomment-2050800195 + # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050800195 x_lst = [] def append_x(x): nonlocal x_lst @@ -1261,7 +1261,7 @@ def f(x): self.assertAllClose(x_lst, [0., 1., 2., 0., 2., 4.], check_dtypes=False) def test_batching_with_side_effects_while_loop(self): - # https://github.com/google/jax/issues/20628#issuecomment-2050921219 + # https://github.com/jax-ml/jax/issues/20628#issuecomment-2050921219 x_lst = [] def append_x(x): nonlocal x_lst diff --git a/tests/pytorch_interoperability_test.py b/tests/pytorch_interoperability_test.py index 8d00b5eedaf4..e41c4329b95b 100644 --- a/tests/pytorch_interoperability_test.py +++ b/tests/pytorch_interoperability_test.py @@ -108,14 +108,18 @@ def testJaxArrayToTorch(self, shape, dtype): else: self.assertAllClose(np, y.cpu().numpy()) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTorchToJaxInt64(self): - # See https://github.com/google/jax/issues/11895 + # See https://github.com/jax-ml/jax/issues/11895 x = jax.dlpack.from_dlpack( torch.utils.dlpack.to_dlpack(torch.ones((2, 3), dtype=torch.int64))) dtype_expected = jnp.int64 if config.enable_x64.value else jnp.int32 self.assertEqual(x.dtype, dtype_expected) @jtu.sample_product(shape=all_shapes, dtype=torch_dtypes) + @jtu.ignore_warning(message="Calling from_dlpack with a DLPack tensor", + category=DeprecationWarning) def testTorchToJax(self, shape, dtype): if not config.enable_x64.value and dtype in [ jnp.int64, diff --git a/tests/random_lax_test.py b/tests/random_lax_test.py index f69687ddc6cd..63510b7295d6 100644 --- a/tests/random_lax_test.py +++ b/tests/random_lax_test.py @@ -178,7 +178,7 @@ def testNormal(self, dtype): def testNormalBfloat16(self): # Passing bfloat16 as dtype string. 
- # https://github.com/google/jax/issues/6813 + # https://github.com/jax-ml/jax/issues/6813 res_bfloat16_str = random.normal(self.make_key(0), dtype='bfloat16') res_bfloat16 = random.normal(self.make_key(0), dtype=jnp.bfloat16) self.assertAllClose(res_bfloat16, res_bfloat16_str) @@ -391,7 +391,7 @@ def testBeta(self, a, b, dtype): @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testBetaSmallParameters(self, dtype=np.float32): - # Regression test for beta version of https://github.com/google/jax/issues/9896 + # Regression test for beta version of https://github.com/jax-ml/jax/issues/9896 key = self.make_key(0) a, b = 0.0001, 0.0002 samples = random.beta(key, a, b, shape=(100,), dtype=dtype) @@ -441,7 +441,7 @@ def testDirichlet(self, alpha, dtype): @jtu.skip_on_devices("tpu") # lower accuracy leads to failures. def testDirichletSmallAlpha(self, dtype=np.float32): - # Regression test for https://github.com/google/jax/issues/9896 + # Regression test for https://github.com/jax-ml/jax/issues/9896 key = self.make_key(0) alpha = 0.00001 * jnp.ones(3) samples = random.dirichlet(key, alpha, shape=(100,), dtype=dtype) @@ -530,7 +530,7 @@ def testGammaGrad(self, log_space, alpha): rtol=rtol) def testGammaGradType(self): - # Regression test for https://github.com/google/jax/issues/2130 + # Regression test for https://github.com/jax-ml/jax/issues/2130 key = self.make_key(0) a = jnp.array(1., dtype=jnp.float32) b = jnp.array(3., dtype=jnp.float32) @@ -663,7 +663,7 @@ def testGeneralizedNormal(self, p, shape, dtype): ) def testGeneralizedNormalKS(self, p, shape, dtype): self.skipTest( # test is also sometimes slow, with (300, ...)-shape draws - "sensitive to random key - https://github.com/google/jax/issues/18941") + "sensitive to random key - https://github.com/jax-ml/jax/issues/18941") key = lambda: self.make_key(2) rand = lambda key, p: random.generalized_normal(key, p, (300, *shape), dtype) crand = jax.jit(rand) @@ -700,7 +700,7 @@ def testBall(self, d, p, shape, dtype): @jtu.skip_on_devices("tpu") # TPU precision causes issues. def testBallKS(self, d, p, shape, dtype): self.skipTest( - "sensitive to random key - https://github.com/google/jax/issues/18932") + "sensitive to random key - https://github.com/jax-ml/jax/issues/18932") key = lambda: self.make_key(123) rand = lambda key, p: random.ball(key, d, p, (100, *shape), dtype) crand = jax.jit(rand) @@ -800,7 +800,7 @@ def testMultivariateNormalShapes(self, dim, mean_batch_size, cov_batch_size, assert samples.shape == shape + (dim,) def testMultivariateNormalCovariance(self): - # test code based on https://github.com/google/jax/issues/1869 + # test code based on https://github.com/jax-ml/jax/issues/1869 N = 100000 mean = jnp.zeros(4) cov = jnp.array([[ 0.19, 0.00, -0.13, 0.00], @@ -827,7 +827,7 @@ def testMultivariateNormalCovariance(self): @jtu.sample_product(method=['cholesky', 'eigh', 'svd']) @jtu.skip_on_devices('gpu', 'tpu') # Some NaNs on accelerators. 
def testMultivariateNormalSingularCovariance(self, method): - # Singular covariance matrix https://github.com/google/jax/discussions/13293 + # Singular covariance matrix https://github.com/jax-ml/jax/discussions/13293 mu = jnp.zeros((2,)) sigma = jnp.ones((2, 2)) key = self.make_key(0) @@ -889,7 +889,7 @@ def testDtypeErrorMessage(self): def testRandomBroadcast(self): """Issue 4033""" - # test for broadcast issue in https://github.com/google/jax/issues/4033 + # test for broadcast issue in https://github.com/jax-ml/jax/issues/4033 key = lambda: self.make_key(0) shape = (10, 2) with jax.numpy_rank_promotion('allow'): @@ -1071,7 +1071,7 @@ def test_randint_out_of_range(self): self.assertGreater((r == 255).sum(), 0) def test_large_prng(self): - # https://github.com/google/jax/issues/11010 + # https://github.com/jax-ml/jax/issues/11010 def f(): return random.uniform( self.make_key(3), (308000000, 128), dtype=jnp.bfloat16) @@ -1086,7 +1086,7 @@ def f(): logits_shape_base=[(3, 4), (3, 1), (1, 4)], axis=[-3, -2, -1, 0, 1, 2]) def test_categorical_shape_argument(self, shape, logits_shape_base, axis): - # https://github.com/google/jax/issues/13124 + # https://github.com/jax-ml/jax/issues/13124 logits_shape = list(logits_shape_base) logits_shape.insert(axis % (len(logits_shape_base) + 1), 10) assert logits_shape[axis] == 10 diff --git a/tests/random_test.py b/tests/random_test.py index 80e8ea76f82c..da182dbccae9 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -436,7 +436,7 @@ def test_threefry_split_fold_in_symmetry(self, make_key): @skipIf(not config.threefry_partitionable.value, 'enable after upgrade') @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_threefry_split_vmapped_fold_in_symmetry(self, make_key): - # See https://github.com/google/jax/issues/7708 + # See https://github.com/jax-ml/jax/issues/7708 with jax.default_prng_impl('threefry2x32'): key = make_key(72) f1, f2, f3 = vmap(lambda k, _: random.fold_in(k, lax.axis_index('batch')), @@ -450,7 +450,7 @@ def test_threefry_split_vmapped_fold_in_symmetry(self, make_key): @skipIf(config.threefry_partitionable.value, 'changed random bit values') def test_loggamma_nan_corner_case(self): - # regression test for https://github.com/google/jax/issues/17922 + # regression test for https://github.com/jax-ml/jax/issues/17922 # This particular key previously led to NaN output. 
# If the underlying implementation ever changes, this test will no longer # exercise this corner case, so we compare to a particular output value @@ -545,7 +545,7 @@ def test_isinstance(self, make_key): @parameterized.parameters([{'make_key': ctor} for ctor in KEY_CTORS]) def test_key_output_vjp(self, make_key): - # See https://github.com/google/jax/issues/14856 + # See https://github.com/jax-ml/jax/issues/14856 def f(seed): return make_key(seed) jax.vjp(f, 1) # doesn't crash @@ -578,7 +578,7 @@ class ThreefryPrngTest(jtu.JaxTestCase): partial(random.PRNGKey, impl='threefry2x32'), partial(random.key, impl='threefry2x32')]]) def test_seed_no_implicit_transfers(self, make_key): - # See https://github.com/google/jax/issues/15613 + # See https://github.com/jax-ml/jax/issues/15613 with jax.transfer_guard('disallow'): make_key(jax.device_put(42)) # doesn't crash @@ -922,14 +922,14 @@ def test_select(self): self.assertEqual(ys.shape, (3, 2)) def test_select_scalar_cond(self): - # regression test for https://github.com/google/jax/issues/16422 + # regression test for https://github.com/jax-ml/jax/issues/16422 ks = self.make_keys(3) ys = lax.select(True, ks, ks) self.assertIsInstance(ys, prng_internal.PRNGKeyArray) self.assertEqual(ys.shape, (3,)) def test_vmap_of_cond(self): - # See https://github.com/google/jax/issues/15869 + # See https://github.com/jax-ml/jax/issues/15869 def f(x): keys = self.make_keys(*x.shape) return lax.select(x, keys, keys) @@ -957,7 +957,7 @@ def test_device_put_replicated(self): def test_make_array_from_callback(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_global_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',)) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) def callback(index): i = jnp.arange(len(devices))[index[0]] @@ -969,7 +969,7 @@ def callback(index): def test_make_array_from_single_device_arrays(self): devices = jax.devices() shape = (len(devices),) - mesh = jtu.create_global_mesh((len(devices),), ('x',)) + mesh = jtu.create_mesh((len(devices),), ('x',)) sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('x')) keys = random.split(random.key(0), len(devices)) arrays = [jax.device_put(keys[i:i + 1], device) for i, device in enumerate(devices)] @@ -1119,8 +1119,14 @@ class A: pass with self.assertRaisesRegex(TypeError, 'unrecognized type .* PRNG'): jax.random.key(42, impl=A()) + @jtu.sample_product(name=[name for name, _ in PRNG_IMPLS]) + def test_key_spec_repr(self, name): + key = jax.random.key(42, impl=name) + spec = jax.random.key_impl(key) + self.assertEqual(repr(spec), f"PRNGSpec({name!r})") + def test_keyarray_custom_vjp(self): - # Regression test for https://github.com/google/jax/issues/18442 + # Regression test for https://github.com/jax-ml/jax/issues/18442 @jax.custom_vjp def f(_, state): return state diff --git a/tests/scipy_fft_test.py b/tests/scipy_fft_test.py index 6c549f5ed10c..a6fdd1b79f58 100644 --- a/tests/scipy_fft_test.py +++ b/tests/scipy_fft_test.py @@ -13,9 +13,12 @@ # limitations under the License. 
import itertools +import numpy as np + from absl.testing import absltest import jax +from jax._src import config from jax._src import test_util as jtu import jax.scipy.fft as jsp_fft import scipy.fft as osp_fft @@ -117,5 +120,15 @@ def testiDctn(self, shape, dtype, s, axes, norm): tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + def testIdctNormalizationPrecision(self): + # reported in https://github.com/jax-ml/jax/issues/23895 + if not config.enable_x64.value: + raise self.skipTest("requires jax_enable_x64=true") + x = np.ones(3, dtype="float64") + n = 10 + expected = osp_fft.idct(x, n=n, type=2) + actual = jsp_fft.idct(x, n=n, type=2) + self.assertArraysAllClose(actual, expected, atol=1e-14) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/scipy_ndimage_test.py b/tests/scipy_ndimage_test.py index 701b7c570937..dd34a99a73b8 100644 --- a/tests/scipy_ndimage_test.py +++ b/tests/scipy_ndimage_test.py @@ -129,7 +129,7 @@ def args_maker(): self._CheckAgainstNumpy(osp_op, lsp_op, args_maker) def testContinuousGradients(self): - # regression test for https://github.com/google/jax/issues/3024 + # regression test for https://github.com/jax-ml/jax/issues/3024 def loss(delta): x = np.arange(100.0) diff --git a/tests/scipy_optimize_test.py b/tests/scipy_optimize_test.py index 70a00e14c468..ffa576850538 100644 --- a/tests/scipy_optimize_test.py +++ b/tests/scipy_optimize_test.py @@ -117,6 +117,7 @@ def zakharov_fn(x): jax_res = jax.scipy.optimize.minimize(fun=eval_func, x0=x0, method='BFGS') self.assertLess(jax_res.fun, 1e-6) + @jtu.ignore_warning(category=RuntimeWarning, message='divide by zero') def test_minimize_bad_initial_values(self): # This test runs deliberately "bad" initial values to test that handling # of failed line search, etc. is the same across implementations diff --git a/tests/scipy_spatial_test.py b/tests/scipy_spatial_test.py index c02653dd171b..540136b33870 100644 --- a/tests/scipy_spatial_test.py +++ b/tests/scipy_spatial_test.py @@ -132,6 +132,20 @@ def testRotationAsQuatCanonical(self, shape, dtype): self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + @jtu.sample_product( + dtype=float_dtypes, + shape=[(4,), (num_samples, 4)], + ) + def testRotationAsQuatScalarFirst(self, shape, dtype): + if scipy_version < (1, 14, 0): + self.skipTest("Scipy 1.14.0 added the `scalar_first` arg.") + rng = jtu.rand_default(self.rng()) + args_maker = lambda: (rng(shape, dtype),) + jnp_fn = lambda q: jsp_Rotation.from_quat(q).as_quat(scalar_first=True) + np_fn = lambda q: osp_Rotation.from_quat(q).as_quat(scalar_first=True).astype(dtype) + self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker, check_dtypes=True, tol=1e-4) + self._CompileAndCheck(jnp_fn, args_maker, atol=1e-4) + @jtu.sample_product( dtype=float_dtypes, shape=[(num_samples, 4)], diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index ad988bba62d3..f02ed0fc04bb 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -314,7 +314,7 @@ def args_maker(): rtol={np.float32: 2e-3, np.float64: 1e-4}) def testBetaLogPdfZero(self): - # Regression test for https://github.com/google/jax/issues/7645 + # Regression test for https://github.com/jax-ml/jax/issues/7645 a = b = 1. 
x = np.array([0., 1.]) self.assertAllClose( @@ -539,7 +539,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): - # Regression test for https://github.com/google/jax/issues/7256 + # Regression test for https://github.com/jax-ml/jax/issues/7256 self.assertAllClose( osp_stats.gamma.pdf(0.0, 1.0), lsp_stats.gamma.pdf(0.0, 1.0), atol=1E-6) @@ -710,7 +710,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker) def testLogisticLogpdfOverflow(self): - # Regression test for https://github.com/google/jax/issues/10219 + # Regression test for https://github.com/jax-ml/jax/issues/10219 self.assertAllClose( np.array([-100, -100], np.float32), lsp_stats.logistic.logpdf(np.array([-100, 100], np.float32)), @@ -855,7 +855,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker) def testNormSfNearZero(self): - # Regression test for https://github.com/google/jax/issues/17199 + # Regression test for https://github.com/jax-ml/jax/issues/17199 value = np.array(10, np.float32) self.assertAllClose(osp_stats.norm.sf(value).astype('float32'), lsp_stats.norm.sf(value), @@ -1208,7 +1208,7 @@ def args_maker(): self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) def testBinomPmfOutOfRange(self): - # Regression test for https://github.com/google/jax/issues/19150 + # Regression test for https://github.com/jax-ml/jax/issues/19150 self.assertEqual(lsp_stats.binom.pmf(k=6.5, n=5, p=0.8), 0.0) def testBinomLogPmfZerokZeron(self): @@ -1568,6 +1568,14 @@ def evaluate_kde(kde, x): contains_nans=[True, False], keepdims=[True, False] ) + @jtu.ignore_warning( + category=RuntimeWarning, + message="One or more sample arguments is too small; all returned values will be NaN" + ) + @jtu.ignore_warning( + category=RuntimeWarning, + message="All axis-slices of one or more sample arguments are too small", + ) def testMode(self, shape, dtype, axis, contains_nans, keepdims): if scipy_version < (1, 9, 0) and keepdims != True: self.skipTest("scipy < 1.9.0 only support keepdims == True") diff --git a/tests/shape_poly_test.py b/tests/shape_poly_test.py index 404c9ce0c4fd..77e5273d172a 100644 --- a/tests/shape_poly_test.py +++ b/tests/shape_poly_test.py @@ -49,6 +49,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import control_flow as lax_control_flow from jax._src.lib import xla_client +from jax._src.lib import version as jaxlib_version import numpy as np config.parse_flags_with_absl() @@ -994,6 +995,16 @@ def test_constraints_ge_override(self): self.assertEqual(_bounds(a), (10, np.inf)) self.assertEqual(_bounds(b), (1, 10)) + def test_constraint_eq_0(self): + a, b, c, d = shape_poly.symbolic_shape( + "a, b, c, d", + constraints=("b == a", "c == a + b", "d == 5")) + # Check that we have already applied the normalizaton rules + self.assertEqual(a._to_var(), "a") + self.assertEqual(b._to_var(), "a") + self.assertEqual(c._to_single_term(), (0, 2, a._to_term())) + self.assertIs(d, 5) + def test_constraints_eq_1(self): # Some constaints override other a, b, c = shape_poly.symbolic_shape("a, b, c", @@ -1072,6 +1083,20 @@ def test_constraints_eq_7(self): self.assertEqual(128 * (t1_ceil // 128), t1_ceil) self.assertEqual(128 * b1 * (t1_ceil // 128), b1 * t1_ceil) + def test_constraints_eq_bug_23456(self): + b, = jax.export.symbolic_shape('b', constraints=['b==5']) + jax.eval_shape(lambda k: jnp.tile(k, 3), jax.ShapeDtypeStruct((b,), jnp.float32)) + + def test_constraints_eq_bug_23437(self): + def f1(x, y): + return x + y + + x = jnp.ones((4,), 
dtype=jnp.int32) + y = jnp.ones((4,), dtype=jnp.int32) + args_specs = jax.export.symbolic_args_specs((x, y), ("a*2", "b*2"), constraints=("a==b",)) + exp = jax.export.export(jax.jit(f1))(*args_specs) + self.assertEqual(exp.in_avals[0], exp.in_avals[1]) + def test_constraints_eq_threefry(self): # Test equalities that arise out of the threefree lowering # x : i32[a] # a may be even or odd @@ -1105,12 +1130,9 @@ def test_constraints_a_minus_4d_eq(self): assumptions1 = ["m1 >= 0", "m1 <= 3", "a1 == 4*d1 + m1"] scope1 = shape_poly.SymbolicScope(assumptions1) a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1) - # TODO: The incompleteness is due to the way we combine external constraints self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1 self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3 - # TODO: The incompleteness is due to the way we combine external constraints - self.assertEqual(_bounds(a1), - _expect(best=(5, np.inf), current=(-np.inf, np.inf))) # a >= 4d + m >= 5 + self.assertEqual(_bounds(a1), (5, np.inf)) # a >= 4d + m >= 5 def test_constraints_error_msg(self): a, b = shape_poly.symbolic_shape("a, b", @@ -1641,8 +1663,7 @@ def f(x): # x: i32[a, b] _ = export.export(jax.jit(f))( jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32)) - - def test_constraints_compile_time_check(self): + def test_constraints_ge_compile_time_check(self): def f(x): # x: i32[a] a = x.shape[0] assert _bounds(a) == (2, 4) @@ -1668,9 +1689,45 @@ def f(x): # x: i32[a] with self.assertRaisesRegex( ValueError, - re.escape("Expected '- a + 4' to be greater or equal to 0, but found -1")): + re.escape("Expected '4 - a' to be greater or equal to 0, but found -1")): exp.call(np.arange(5, dtype=np.int32)) + def test_constraints_eq_0_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["max(a, b) == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + with self.assertRaisesRegex( + ValueError, + re.escape("Expected 'max(a, b) - b' to be equal to 0, but found 1")): + exp.call(np.ones((3, 2), dtype=np.int32)) + + def test_constraints_eq_1_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["a == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + exp.call(np.ones((3, 3), dtype=np.int32)) + + def test_constraints_eq_2_compile_time_check(self): + def f(x): # x: i32[a, b] + return x + + x_spec = jax.ShapeDtypeStruct( + export.symbolic_shape("a, b", + constraints=["max(a, b) == 4", "a == b"]), np.int32) + exp = export.export(jax.jit(f))(x_spec) + with self.assertRaisesRegex( + ValueError, + re.escape("Expected 'max(a, b) - 4' to be equal to 0, but found -1")): + exp.call(np.ones((3, 3), dtype=np.int32)) + def test_caching_with_scopes(self): f_tracing_count = 0 expected_a_bounds = (1, np.inf) @@ -2731,6 +2788,62 @@ def test_vmap_error(self): ((2, 3, 8, 4), "b1, b2, ...", True), ] ], + [ + PolyHarness( + "lu_pivots_to_permutation", + f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}", + lax.linalg.lu_pivots_to_permutation, + arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)], + polymorphic_shapes=[poly], + symbolic_constraints=constraints, + ) + for shape, poly, permutation_size, constraints in [ + ((4,), None, 8, ()), + ((2, 3, 4), "b1, b2, ...", 8, ()), + ((4,), "b", 8, ["b <= 8"]), + ((2, 3, 4), "b1, b2, b3", 8, 
["b3 <= 8"]), + ] + ], + [ + # Tracing errors are only thrown when the trailing dimension of pivots + # is static. Otherwise, the error is thrown at runtime. + PolyHarness( + "lu_pivots_to_permutation_error", + f"shape={jtu.format_shape_dtype_string(shape, np.int32)}_poly={poly}_{permutation_size=}", + lax.linalg.lu_pivots_to_permutation, + arg_descriptors=[RandArg(shape, np.int32), StaticArg(permutation_size)], + polymorphic_shapes=[poly], + symbolic_constraints=constraints, + expect_error=(ValueError, "Output permutation size"), + ) + for shape, poly, permutation_size, constraints in [ + ((4,), None, 3, ()), + ((2, 3, 4), "b1, b2, ...", 3, ()), + ((4,), "b", 8, ["b >= 9"]), + ((2, 3, 4), "b1, b2, b3", 8, ["b3 >= 9"]), + ] + ], + [ + PolyHarness( # pylint: disable=g-complex-comprehension + "lu", f"shape={jtu.format_shape_dtype_string(shape, dtype)}_poly={poly}", + lax.linalg.lu, + arg_descriptors=[RandArg(shape, dtype)], + polymorphic_shapes=[poly], + # TODO(b/360788062): Remove once the forward compatibility window is + # closed. + override_jax_config_flags={ + "jax_export_ignore_forward_compatibility": True}) + for dtype in {np.float32, np.float64, np.complex64, np.complex128} & jtu.supported_dtypes() + for shape, poly in [ + ((5, 4), "m, n"), + ((2, 0, 4), "b, ..."), + ((2, 4, 0), "b, ..."), + ((2, 3, 4, 4), "b1, b2, ..."), + ((2, 3, 4, 5), "b1, b2, ..."), + ((2, 3, 8, 4), "b1, b2, ..."), + ((2, 3, 4, 5), "b1, b2, m, n"), + ] + ], [ # The random primitive tests, with threefry (both partitionable and # non-partitionable), and unsafe_rbg. @@ -2860,7 +2973,7 @@ def test_vmap_error(self): (2, x.shape[0]), (1, 1), "VALID"), arg_descriptors=[RandArg((3, 8), _f32)], polymorphic_shapes=["b, ..."]), - # https://github.com/google/jax/issues/11804 + # https://github.com/jax-ml/jax/issues/11804 # Use the reshape trick to simulate a polymorphic dimension of 16*b. # (See test "conv_general_dilated.1d_1" above for more details.) PolyHarness("reduce_window", "add_monoid_strides_window_size=static", @@ -3354,15 +3467,28 @@ def test_harness(self, harness: PolyHarness): custom_call_harnesses = { "householder_product:gpu", "vmap_geqrf:gpu", # used for linalg.qr - "vmap_lu:gpu", - # custom_linear_solve works as long as lu works. - "vmap_custom_linear_solve:gpu", "vmap_qr:gpu", "qr:gpu", "vmap_svd:gpu", } - if f"{harness.group_name}:{jtu.device_under_test()}" in custom_call_harnesses: + name_device_key = f"{harness.group_name}:{jtu.device_under_test()}" + if name_device_key in custom_call_harnesses: raise unittest.SkipTest("native serialization with shape polymorphism not implemented for custom calls; b/261671778") + # This list keeps track of the minimum jaxlib version that supports shape + # polymorphism for some new primitives as we add them. This check is + # required so that we can still run the test suite with older versions of + # jaxlib. + version_gated = { + # TODO(danfm): remove these checks when jaxlib 0.4.32 is released. 
+ "lu_pivots_to_permutation:gpu": (0, 4, 32), + "lu_pivots_to_permutation_error:gpu": (0, 4, 32), + "lu:gpu": (0, 4, 32), + "vmap_lu:gpu": (0, 4, 32), + "vmap_custom_linear_solve:gpu": (0, 4, 32), + } + if version_gated.get(name_device_key, jaxlib_version) > jaxlib_version: + raise unittest.SkipTest(f"shape polymorphism not supported by jaxlib version {jaxlib_version}") + if harness.group_name == "schur" and not jtu.test_device_matches(["cpu"]): raise unittest.SkipTest("schur decomposition is only implemented on CPU.") @@ -3377,6 +3503,11 @@ def test_harness(self, harness: PolyHarness): if 0 < shape[-1] <= 32: harness.check_result = False + if harness.group_name == "vmap_eigh": + raise unittest.SkipTest( + "Should not compare eigendecompositions for equality directly" + "because eigenvalues are sorted.") + if harness.group_name == "vmap_tan": # Tan (b/274462307) require support for custom call stablehlo.tan. raise unittest.SkipTest( @@ -3413,6 +3544,12 @@ def test_harness(self, harness: PolyHarness): if "cholesky" in harness.group_name and jtu.test_device_matches(["tpu"]): harness.tol = 5e-5 + # TODO(b/360788062): Clean up after the compatibility period. + if harness.group_name in [ + "lu", "vmap_lu", "custom_linear_solve", "vmap_custom_linear_solve" + ] and jtu.test_device_matches(["gpu"]): + config_flags = {**config_flags, "jax_export_ignore_forward_compatibility": True} + with jtu.global_config_context(**config_flags): harness.run_test(self) diff --git a/tests/shard_alike_test.py b/tests/shard_alike_test.py index 11305e937a08..10267ff5eb98 100644 --- a/tests/shard_alike_test.py +++ b/tests/shard_alike_test.py @@ -39,7 +39,7 @@ class ShardAlikeDownstreamTest(jtu.JaxTestCase): def test_full_like(self): x = jnp.arange(16, dtype='float32').reshape(8, 2) - mesh = jtu.create_global_mesh((8,), ("i",)) + mesh = jtu.create_mesh((8,), ("i",)) x = jax.device_put(x, NamedSharding(mesh, P('i', None))) y = jnp.full_like(x, 1) self.assertEqual(x.sharding, y.sharding) @@ -51,7 +51,7 @@ def setUp(self): super().setUp() def test_basic(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -68,7 +68,7 @@ def f(x): self.assertArraysEqual(out, np_inp * np_inp * 4) def test_output_sharded_alike_input(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -83,7 +83,7 @@ def f(x): self.assertArraysEqual(out, np_inp * 2) def test_arange_shard_alike_jit(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -98,7 +98,7 @@ def f(x): self.assertArraysEqual(out, np_inp) def test_different_shapes(self): - mesh = jtu.create_global_mesh((2, 1), ('x', 'y')) + mesh = jtu.create_mesh((2, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x',)) inp = jax.device_put(np_inp, s) @@ -113,7 +113,7 @@ def f(x): f(inp) def test_double_shard_alike(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -131,7 +131,7 @@ def f(x): self.assertEqual(out2.sharding, NamedSharding(mesh, 
P('x'))) def test_shard_like_eager(self): - mesh = jtu.create_global_mesh((4, 1), ('x', 'y')) + mesh = jtu.create_mesh((4, 1), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -145,7 +145,7 @@ def f(x): self.assertArraysEqual(out, np_inp) def test_shard_map(self): - mesh = jtu.create_global_mesh((4, 2), ('x', 'y')) + mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(np_inp, s) @@ -167,7 +167,7 @@ def f(x): self.assertEqual(out2.sharding, s) def test_grad(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) inp = jax.device_put(np_inp, s) @@ -188,7 +188,7 @@ def f(x): jax.grad(jax.jit(f))(inp) # doesn't crash def test_shard_input_as_output(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) @@ -218,7 +218,7 @@ def g(x): self.assertEqual(out4.sharding, s) def test_shard_alike_inputs(self): - mesh = jtu.create_global_mesh((2,), ('x',)) + mesh = jtu.create_mesh((2,), ('x',)) np_inp = np.arange(8.) s = NamedSharding(mesh, P('x')) rep_s = NamedSharding(mesh, P()) @@ -237,7 +237,7 @@ def f(x, y): self.assertEqual(out2.sharding, s) def test_vmap_one_mapped(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(2) s = NamedSharding(mesh, P('y')) inp = jax.device_put(np_inp, s) @@ -256,7 +256,7 @@ def _shard_slice_like_arg(s): self.assertArraysEqual(out, np.tile(np_inp, [8, 1])) def test_vmap_both_mapped(self): - mesh = jtu.create_global_mesh((2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2) s = NamedSharding(mesh, P('x', 'y')) inp1 = jax.device_put(np_inp, s) diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index bb1763fbcd3e..0c1155ddf1ab 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -35,9 +35,11 @@ from jax.sharding import PartitionSpec as P from jax._src import config from jax._src import core +from jax._src import prng from jax._src import test_util as jtu from jax._src.util import safe_zip, safe_map, partition_list, merge_lists from jax._src.ad_checkpoint import saved_residuals +from jax._src.mesh import AbstractMesh from jax._src.interpreters import mlir from jax._src.interpreters import partial_eval as pe from jax._src import linear_util as lu @@ -54,9 +56,7 @@ # Helper for some tests. 
def create_inputs(a_sharding, b_sharding): - x, y, z = 2, 2, 2 # pylint: disable=invalid-name - devices = np.array(jax.devices()[:x * y * z]).reshape((x, y, z)) - mesh = Mesh(devices, axis_names=('x', 'y', 'z')) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) b, e, f = 8, 8, 8 # pylint: disable=invalid-name m1 = jax.device_put( jnp.arange(b * e).reshape((b, e)), @@ -72,8 +72,6 @@ def create_inputs(a_sharding, b_sharding): def setUpModule(): _exit_stack.enter_context(jtu.set_host_platform_device_count(8)) - if len(jax.devices()) < 8: - raise unittest.SkipTest("tests require 8 devices") def tearDownModule(): _exit_stack.close() @@ -91,7 +89,7 @@ def identity(x): @jax.jit def fwd(a): c = shard_map( - lambda x: x, + identity, mesh, in_specs=(P('z', ('x', 'y')),), out_specs=P('z', ('x', 'y')))(a) @@ -217,8 +215,7 @@ def fwd(a): self.assertAllClose(np.squeeze(c.addressable_data(2 * i + 1), -1), sums) def test_collective_permute(self): - devices = np.array(jax.devices()[:8]) # Take up to 8 devices - mesh = Mesh(devices, axis_names=('x')) + mesh = jtu.create_mesh((8,), 'x') a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None))) @@ -236,10 +233,7 @@ def fwd(a): self.assertAllClose(c[1, :], a[0, :]) def test_collective_permute_with_multiple_axis_names(self): - mesh = Mesh( - np.array(jax.devices()[:8]).reshape((2, 2, 2)), - axis_names=('x', 'y', 'z'), - ) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) a = jax.device_put( jnp.arange(8 * 8).reshape((4, 16)), jax.sharding.NamedSharding(mesh, P('x', ('y', 'z'))), @@ -282,11 +276,7 @@ def fwd(a): ), ) def test_all_to_all(self, axis_name, mesh_axes): - devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) - mesh = Mesh( - devices.reshape(tuple(mesh_axes.values())), - axis_names=tuple(mesh_axes.keys()), - ) + mesh = jtu.create_mesh(tuple(mesh_axes.values()), tuple(mesh_axes.keys())) a = jax.device_put( jnp.arange(8 * 8).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P(axis_name, None)), @@ -308,12 +298,7 @@ def fwd(a): assert (c == jnp.reshape(a.T, (1, 64))).all() def test_all_to_all_with_axis_index_groups(self): - mesh_axes = dict(x=4) - devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) - mesh = Mesh( - devices.reshape(tuple(mesh_axes.values())), - axis_names=tuple(mesh_axes.keys()), - ) + mesh = jtu.create_mesh((4,), ('x',)) a = jax.device_put( jnp.arange(4 * 4).reshape((4, 4)), jax.sharding.NamedSharding(mesh, P('x', None)), @@ -346,12 +331,7 @@ def fwd(a): self.assertAllClose(block, c.addressable_data(2 * i + j)) def test_all_to_all_grad(self): - mesh_axes = dict(x=4) - devices = np.array(jax.devices()[: np.prod(tuple(mesh_axes.values()))]) - mesh = Mesh( - devices.reshape(tuple(mesh_axes.values())), - axis_names=tuple(mesh_axes.keys()), - ) + mesh = jtu.create_mesh((4,), 'x') a = jax.device_put( jnp.arange(8 * 8, dtype=jnp.float32).reshape((8, 8)), jax.sharding.NamedSharding(mesh, P('x', None)), @@ -380,7 +360,7 @@ def loss_and_grad(x): self.assertAllClose(grad, 2 * np.ones_like(a)) def test_eager_repr(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = None @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=P('x', 'y')) @@ -394,7 +374,7 @@ def f(x): self.assertIn('at mesh coordinates', s) def test_jvp_basic(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: 
jnp.sin(jnp.cos(x)), mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) args = np.arange(4 * 4.).reshape(4, 4), @@ -402,7 +382,7 @@ def test_jvp_basic(self): jtu.check_grads(jax.jit(g), args, 2, ['fwd']) def test_linearize_basic(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, in_specs=(P('x', 'y'),), out_specs=P('x', 'y')) x = np.arange(4 * 4.).reshape(4, 4) @@ -416,7 +396,7 @@ def test_linearize_basic(self): self.assertAllClose(y_dot, y_dot_, check_dtypes=False) def test_linearize_basic_repres(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jax.lax.sin(jax.lax.cos(x)), mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -430,7 +410,7 @@ def test_linearize_basic_repres(self): self.assertAllClose(y_dot, y_dot_, check_dtypes=False) def test_linearize_basic_repres_jit(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) g = shard_map(lambda x: jnp.sin(jnp.cos(x)), mesh, in_specs=(P('x',),), out_specs=P('x',)) x = np.arange(4.) @@ -444,7 +424,7 @@ def test_linearize_basic_repres_jit(self): self.assertAllClose(y_dot, y_dot_, check_dtypes=False) def test_replication_checker_eager(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) def f(x): @@ -462,7 +442,7 @@ def g2(x): _ = g2(x) # doesn't crash def test_replication_checker_jit(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = np.arange(8 * 8.).reshape(8, 8) def f(x): @@ -492,7 +472,7 @@ def g(x): jtu.check_grads(g, (x,), modes=['fwd'], order=2) def test_eager_control_flow(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(2 * 2.).reshape(2, 2) def f(x): @@ -508,12 +488,12 @@ def g(x): self.assertAllClose(y, -x, check_dtypes=False) def test_outer_jit_detects_shard_map_mesh(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) f = shard_map(lambda x: x.reshape(1, *x.shape), mesh, P(), P('x')) _ = jax.jit(f)(jnp.array(2.0)) # doesn't crash def test_vmap_basic(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): @@ -523,7 +503,7 @@ def g(x): self.assertAllClose(y, 2 * x, check_dtypes=False) def test_vmap_basic_axis_name(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): @@ -533,7 +513,7 @@ def g(x): self.assertAllClose(y, 2 * x, check_dtypes=False) def test_vmap_basic_axis_name_reuse_mesh_name(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) x = jnp.arange(8 * 8.).reshape(8, 8) def g(x): @@ -543,7 +523,7 @@ def g(x): self.assertAllClose(y, 2 * x, check_dtypes=False) def test_tree_prefix_error(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(shard_map, mesh=mesh, in_specs=([P('x', 'y')],), out_specs=P('x', 'y')) def f(x): @@ -554,7 +534,7 @@ 
def f(x): f([x, x]) def test_rank_errors(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def foo(): return {'hi': [3.]} @@ -575,7 +555,7 @@ def foo(): shard_map(foo, mesh=mesh, in_specs=P(None), out_specs=())(3.) def test_reverse_mode_ad(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @jax.jit @partial(shard_map, mesh=mesh, @@ -589,7 +569,7 @@ def f(x, y): def test_post_process(self): # JVPTrace.post_process_shard_map and JaxprTrace.post_process_shard_map - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(x): @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) @@ -606,7 +586,7 @@ def g(y): @jtu.run_on_devices('gpu', 'tpu') def test_axis_index(self): - mesh = Mesh(np.array(jax.devices()[:4]), ('x',)) + mesh = jtu.create_mesh((4,), 'x') @jax.jit @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('x')) @@ -714,7 +694,7 @@ def f3(): jax.jit(f3)() def test_vmap_spmd_axis_name(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): @@ -729,7 +709,7 @@ def f(x): self.assertEqual(e.params['out_names'], ({0: ('y',), 1: ('x',)},)) def test_vmap_spmd_axis_name_pair(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def f(x): @@ -743,9 +723,135 @@ def f(x): self.assertIn('out_names', e.params) self.assertEqual(e.params['out_names'], ({0: ('x', 'y',)},)) + def test_nested_vmap_with_capture_spmd_axis_name(self): + self.skipTest('https://github.com/jax-ml/jax/issues/23476') + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + def to_map_with_capture(x, y): + + # We capture x from `to_map_with_capture`'s parameters. + def with_capture(y_slice): + # Inside of all the maps, we have 'mapped everything away'--we are just + # adding two scalars, but one by fully mapping across each of the two + # dimensions, the other by mapping across one and capturing the + # resulting scalar. + self.assertEqual(x.shape, ()) + self.assertEqual(y_slice.shape, ()) + return x + y_slice + + # This vmap i will refer to as 'inner vmap'. + vmap_with_capture = jax.vmap(with_capture) + shmap_vmap_capture = shard_map( + vmap_with_capture, mesh=mesh, in_specs=P('y'), out_specs=P('y') + ) + return shmap_vmap_capture(y) + + # And this one is the outer vmap. + mapped = jax.vmap(to_map_with_capture, spmd_axis_name='x') + x = jnp.arange(2).reshape(2) + y = jnp.arange(2 * 2).reshape(2, 2) + # Inner vmap inside of shard-map will be over an axis of size 1. Outer vmap + # is over an axis of size 2. This is a problem at the moment. 
+ jax.make_jaxpr(mapped)(x, y).jaxpr + + def test_shard_map_abstract_mesh(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def f(x): + return shard_map(lambda x: x, mesh=mesh.abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(x) + + out1 = jax.jit(f)(arr) + self.assertArraysEqual(out1, np_inp) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + + out_eager = f(arr) + self.assertArraysEqual(out_eager, np_inp) + self.assertEqual(out_eager.sharding, NamedSharding(mesh, P('x'))) + + out1, out2 = shard_map(lambda x, y: (x, y), mesh=mesh.abstract_mesh, + in_specs=P('x'), out_specs=P('x'))(np_inp, arr) + self.assertArraysEqual(out1, np_inp) + self.assertEqual(out1.sharding, NamedSharding(mesh, P('x'))) + self.assertArraysEqual(out2, np_inp) + self.assertEqual(out2.sharding, NamedSharding(mesh, P('x'))) + + def test_different_devices_shmap_abstract_mesh_cache_hit(self): + if jax.device_count() < 4: + self.skipTest('Requires >=4 devices') + + mesh1 = jax.sharding.Mesh(jax.devices()[:2], 'i') + mesh2 = jax.sharding.Mesh(jax.devices()[2:4], 'i') + abstract_mesh = AbstractMesh(mesh1.shape_tuple) + + @jax.jit + def f(x): + x = shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('i'), + out_specs=P('i'))(x) + return jnp.sin(x) + + with ( + jtu.count_jit_tracing_cache_miss() as tracing_count, + jtu.count_jit_and_pmap_lowerings() as lowering_count, + jtu.count_jit_compilation_cache_miss() as compilation_count, + ): + a = jax.device_put(np.arange(8.), NamedSharding(mesh1, P())) + out_a = f(a) # tracing and lowering cached + + # same num_devices but different devices. + b = jax.device_put(out_a, NamedSharding(mesh2, P())) + f(b) # tracing and lowering cache *hit* + + self.assertEqual(tracing_count[0], 2) # 1 miss for `f` and 1 miss for `sin` + self.assertEqual(lowering_count[0], 1) + self.assertEqual(compilation_count[0], 2) # 2 misses since devices differ. 
+ + def test_shmap_abstract_mesh_errors(self): + mesh = jtu.create_mesh((2,), ('x',)) + np_inp = np.arange(8) + abstract_mesh = jax.sharding.AbstractMesh(mesh.shape_tuple) + + with self.assertRaisesRegex( + ValueError, + "Please pass `jax.Array`s with a `NamedSharding` as input to" + " `shard_map` when passing `AbstractMesh` to the mesh argument"): + shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(jnp.arange(8)) + + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x'))) + mesh2 = jtu.create_mesh((2,), 'y') + abs_mesh2 = AbstractMesh(mesh2.shape_tuple) + with self.assertRaisesRegex( + ValueError, + 'Mesh shape of the input.*does not match the mesh shape passed to' + ' shard_map'): + shard_map(lambda x: x, mesh=abs_mesh2, in_specs=P('y'), + out_specs=P('y'))(arr) + + with self.assertRaisesRegex( + ValueError, + 'Please pass `jax.Array`s with a `NamedSharding` as input to' + ' `shard_map` when passing `AbstractMesh` to the mesh argument.'): + shard_map(lambda x: x, mesh=abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(np_inp) + + arr_mesh2 = jax.device_put(np_inp, NamedSharding(mesh2, P('y'))) + with self.assertRaisesRegex( + ValueError, + 'Mesh shape of the input.*does not match the mesh shape passed to' + ' shard_map'): + shard_map(lambda x, y: (x, y), mesh=abstract_mesh, in_specs=P('x'), + out_specs=P('x'))(arr, arr_mesh2) + @parameterized.parameters([True, False]) @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_debug_print_jit(self, jit): + if config.use_shardy_partitioner.value: + self.skipTest( + 'TODO(b/364547005): debug prints not supported by Shardy yet' + ) mesh = Mesh(jax.devices(), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) @@ -800,8 +906,8 @@ def f(_): @jax.legacy_prng_key('allow') def test_prngkeyarray_eager(self): - # https://github.com/google/jax/issues/15398 - mesh = jtu.create_global_mesh((4,), ('x',)) + # https://github.com/jax-ml/jax/issues/15398 + mesh = jtu.create_mesh((4,), ('x',)) sharding = jax.sharding.NamedSharding(mesh, P('x')) rng = jax.random.PRNGKey(0) @@ -817,7 +923,7 @@ def f(key): _ = g(sharded_rng) # don't crash! def test_functools_partial_rank_error(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial def f(x): @@ -829,7 +935,7 @@ def f(x): g(x) def test_in_specs_none_error(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) def f(x): return x @@ -843,7 +949,7 @@ def f(x): return x shard_map(f, mesh, in_specs=P(), out_specs=P())(3.) # doesn't crash def test_scan_rep_rule(self): - mesh = jtu.create_global_mesh((2, 2,), ('x', 'y')) + mesh = jtu.create_mesh((2, 2,), ('x', 'y')) def f(x, y, z): x, y, z = x.sum(), y.sum(), z.sum() @@ -896,7 +1002,7 @@ def foo_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return foo(x), 3. * x_dot - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. * jnp.arange(4.)).sum()) @@ -915,7 +1021,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(foo, mesh, in_specs=(P('x'),), out_specs=P('x')) y, x_bar = jax.value_and_grad(lambda x: g(x).sum())(jnp.arange(4.)) self.assertAllClose(y, (2. 
* jnp.arange(4.)).sum()) @@ -929,7 +1035,7 @@ def foo(): if jit: foo = jax.jit(foo) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) ans = shard_map(foo, mesh, in_specs=(), out_specs=P('x'))() expected = jnp.arange(4.) self.assertAllClose(ans, expected, check_dtypes=False) @@ -945,7 +1051,7 @@ def foo(): if jit: foo = jax.jit(foo) - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) ans1, ans2, ans3 = shard_map(foo, mesh, in_specs=(), out_specs=P('i', 'j'))() expected1 = jnp.arange(4.)[:, None] + jnp.zeros((4, 2)) @@ -956,7 +1062,7 @@ def foo(): self.assertAllClose(ans3, expected3, check_dtypes=False) def test_axis_index_eager(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P()) def foo(): @@ -967,8 +1073,8 @@ def foo(): self.assertEqual(out, 1.) def test_jaxpr_shardings_with_no_outputs(self): - # https://github.com/google/jax/issues/15385 - mesh = jtu.create_global_mesh((4,), ('i',)) + # https://github.com/jax-ml/jax/issues/15385 + mesh = jtu.create_mesh((4,), ('i',)) @jax.jit @partial(shard_map, mesh=mesh, in_specs=(), out_specs=P('i')) @@ -984,7 +1090,7 @@ def g(a_block): g(np.arange(32)) # don't crash def test_device_put(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -1007,9 +1113,9 @@ def g(x): @jtu.run_on_devices('cpu', 'gpu', 'tpu') def test_key_array_with_replicated_last_tile_dim(self): - # See https://github.com/google/jax/issues/16137 + # See https://github.com/jax-ml/jax/issues/16137 - mesh = jtu.create_global_mesh((2, 4), ('i', 'j')) + mesh = jtu.create_mesh((2, 4), ('i', 'j')) def f(rng): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i'), @@ -1049,7 +1155,7 @@ def assert_dce_result(self, jaxpr: core.Jaxpr, used_outputs: list[bool], jtu.check_grads(f, inputs_dce, order=2, modes=['rev']) def test_returned_out_sharding(self): - mesh = jtu.create_global_mesh((1, 2), ('x', 'y')) + mesh = jtu.create_mesh((1, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) inp = jax.device_put(jnp.zeros((2, 2)), s) out = shard_map(lambda x: x, mesh, P('x', 'y'), P('x', 'y'))(inp) @@ -1057,7 +1163,7 @@ def test_returned_out_sharding(self): self.assertArraysEqual(out, inp) def test_dce(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) def f(x, y, z): @partial(shard_map, mesh=mesh, in_specs=(P('i', 'j'), P(None, 'i')), @@ -1108,7 +1214,7 @@ def g(y, z): check_diff=False) def test_post_process_partial_eval_with_scalar_res(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) g = jax.grad(lambda x: shard_map(lambda: jnp.sin(x), mesh=mesh, in_specs=P(), out_specs=P())())(2.0) self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False) @@ -1127,19 +1233,24 @@ def foo(x): return x hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo')) - self.assertIn("call @shmap_body", hlo_str) - self.assertIn("call @shmap_body_0", hlo_str) - self.assertIn("%arg0: tensor<1xf32>", hlo_str) - self.assertIn("\"[None]\"", hlo_str) - self.assertIn("%arg1: tensor<1xf32>", hlo_str) - self.assertIn("\"[('i',)]\"", hlo_str) - self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str) + if config.use_shardy_partitioner.value: + self.assertEqual(2, 
hlo_str.count('sdy.manual_computation')) + else: + self.assertIn('call @shmap_body', hlo_str) + self.assertIn('call @shmap_body_0', hlo_str) + self.assertIn('%arg0: tensor<1xf32>', hlo_str) + self.assertIn('"[None]"', hlo_str) + self.assertIn('%arg1: tensor<1xf32>', hlo_str) + self.assertIn('"[(\'i\',)]"', hlo_str) + self.assertIn( + '-> (tensor<1xf32> {jax.result_info = "[(\'i\',)]"})', hlo_str + ) def test_rewrite_process_call(self): def f(x): return core.call_p.bind(lu.wrap_init(lambda x: [2. * x]), x)[0] * x - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(f, mesh, in_specs=(P('x'),), out_specs=P('x')) x = jnp.arange(4.) y = jax.jit(g)(x) # eager requires shmap to have ShardMapTrace.process_call @@ -1148,7 +1259,7 @@ def f(x): def test_rewrite_post_process_call(self): # We shouldn't hit post_process_call here because of RewriteTrace's dynamic # behavior (i.e. no data dependence). - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @jax.jit @partial(shard_map, mesh=mesh, in_specs=(P('x'),), out_specs=P('x')) @@ -1170,7 +1281,7 @@ def foo_jvp(primals, tangents): (x,), (x_dot,) = primals, tangents return foo(x), 2. * x_dot - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1198,7 +1309,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1226,7 +1337,7 @@ def foo_bwd(_, y_bar): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1243,7 +1354,7 @@ def foo_bwd(_, y_bar): def test_same_pspec_eager_shard_map(self): # This behavior is not guaranteed by JAX and this test can be changed if # the behavior changes. 
- mesh = jtu.create_global_mesh((1, 4, 1), ('data', 'seq', 'model')) + mesh = jtu.create_mesh((1, 4, 1), ('data', 'seq', 'model')) def f(x): return x * x + 2 @@ -1271,7 +1382,7 @@ def foo_bwd(y, _): foo.defvjp(foo_fwd, foo_bwd) - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x, y: foo(x, y) * y, mesh, in_specs=(P(), P('x')), out_specs=P('x')) if jit: @@ -1306,7 +1417,7 @@ def foo_scan(x): y, _ = jax.lax.scan(lambda x, _: (foo(x), None), x, None, length=1) return y - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) g = shard_map(lambda x: foo_scan(x) * x, mesh, in_specs=(P('x'),), out_specs=P('x')) if jit: @@ -1321,7 +1432,7 @@ def foo_scan(x): self.assertAllClose(x_bar, 2 * 2 * x, check_dtypes=True) def test_transpose_identity(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P()) def f(x): @@ -1346,7 +1457,7 @@ def g(x): self.assertLen(e2.params['jaxpr'].eqns, 1) def test_fanout_specs_transpose_to_psum(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P(), out_specs=P('x')) def f(x): @@ -1359,7 +1470,7 @@ def f(x): self.assertEqual(e2.params['axes'], ('x',)) def test_fanin_psum_transposes_to_fanout(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P()) def f(x): @@ -1371,7 +1482,7 @@ def f(x): self.assertEqual(str(e1.primitive), 'pbroadcast') def test_psum_with_implicit_fanout_self_transposes(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): @@ -1384,7 +1495,7 @@ def f(x): self.assertEqual(str(e2.primitive), 'pbroadcast') def test_rewrite_binops(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=(P(), P('x')), out_specs=P('x')) def f(x, y): @@ -1397,7 +1508,7 @@ def f(x, y): self.assertEqual(e.params['axes'], ('x',)) def test_rewrite_scan(self): - mesh = jtu.create_global_mesh((4,), ('x',)) + mesh = jtu.create_mesh((4,), ('x',)) @partial(shard_map, mesh=mesh, in_specs=P('x'), out_specs=P('x')) def f(x): @@ -1413,12 +1524,14 @@ def f(x): self.assertEqual(e2.primitive.name, 'pbroadcast') def test_check_rep_false_grads(self): + if jtu.is_device_tpu(5, 'e'): + self.skipTest('TODO(b/307508823): Test currently fails on TPU v5e') + # This test is redundant with the systematic tests below, but it serves as a # direct regression test for a bug. 
- mesh = jtu.create_global_mesh((4,), ('heads',)) + mesh = jtu.create_mesh((4,), ('heads',)) def f(q, k, v): - def body(q, k, v): return q * k[None, :] + v[None, :] @@ -1433,7 +1546,11 @@ def body(q, k, v): k = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) v = jax.device_put(jnp.arange(8.), jax.sharding.NamedSharding(mesh, kv_spec)) - jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=1e-2) + if jtu.device_under_test() == 'tpu': + rtol = 2e-2 + else: + rtol = 1e-2 + jtu.check_grads(f, (q, k, v), order=1, modes=['rev'], rtol=rtol) def test_axis_env_extension_regression(self): def foo(x): @@ -1449,7 +1566,7 @@ def bar(x): @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization(self, jit, remat): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -1472,7 +1589,7 @@ def f(x): @parameterized.parameters(it.product([True, False], repeat=2)) def test_res_forwarding_optimization_complex(self, jit, remat): # like the above test, but a different function `f` - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P('i')) def f(x): @@ -1494,7 +1611,7 @@ def f(x): @parameterized.parameters([True, False]) def test_check_rep_failure_inside_rule(self, jit): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) def loss(w, x): @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) @@ -1508,7 +1625,7 @@ def f(x): jax.grad(loss)(3.0, jnp.arange(8.)) # don't crash def test_conv_general_dilated(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) dot = partial(lax.conv_general_dilated, window_strides=(), padding='VALID', dimension_numbers=('NC', 'IO', 'NC')) @@ -1524,31 +1641,31 @@ def f(x, y): self.assertAllClose(y, a @ b, check_dtypes=False, atol=1e-2, rtol=1e-2) def test_cumsum(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(8.) shard_map(jnp.cumsum, mesh=mesh, in_specs=P('i'), out_specs=P('i') )(x) # don't crash def test_custom_jvp_inside_jit(self): - mesh = jtu.create_global_mesh((4,), ('batch',)) + mesh = jtu.create_mesh((4,), ('batch',)) x = shard_map(jax.jit(jax.nn.relu), mesh=mesh, in_specs=P('batch'), out_specs=P('batch'))(jnp.arange(16.)) # don't crash def test_random_normal_rules(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) keys = jax.random.split(jax.random.key(0), 4) shard_map(lambda k: jax.random.normal(k[0], (1,)), mesh=mesh, in_specs=P('i'), out_specs=P('i'))(keys) # don't crash def test_erf_rules(self): - mesh = jtu.create_global_mesh((4,), ('i',)) + mesh = jtu.create_mesh((4,), ('i',)) x = jnp.arange(16.) 
shard_map(jax.lax.erf, mesh=mesh, in_specs=P('i'), out_specs=P('i'))(x) # don't crash def test_error_for_variable_num_args(self): - mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y')) + mesh = jtu.create_mesh((2, 2), ('x', 'y')) def f(*args): return args[0] @ args[1] @@ -1560,7 +1677,7 @@ def f(*args): shard_f(jnp.ones((8, 8)), jnp.ones((8, 8))) def test_custom_vjp_replication_error_message_hint(self): - mesh = Mesh(np.array(jax.devices()[:4]), ('i',)) + mesh = jtu.create_mesh((4,), 'i') @jax.custom_vjp def f(x): @@ -1582,8 +1699,8 @@ def g(x): self.assertAllClose(grad, jnp.ones(4) * 4 * 4, check_dtypes=False) def test_repeated_psum_allowed(self): - # https://github.com/google/jax/issues/19175 - mesh = Mesh(jax.devices()[:4], ('i',)) + # https://github.com/jax-ml/jax/issues/19175 + mesh = jtu.create_mesh((4,), 'i') @partial(shard_map, mesh=mesh, in_specs=P('i'), out_specs=P()) def g(x): @@ -1632,7 +1749,7 @@ def f(inputs): modes=['rev'], atol=1e-3, rtol=1e-3) def test_partial_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1651,14 +1768,38 @@ def f(x): v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) - self.assertIn( - 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}', - f.lower(v).as_text('hlo'), - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'in_shardings=[<@mesh, [{"i"}, {}]>] out_shardings=[<@mesh, [{"i"},' + ' {}]>] manual_axes={"i"}', + f.lower(v).as_text(), + ) + else: + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual,' + ' replicated}}', + f.lower(v).as_text('hlo'), + ) self.assertAllClose(v*v, f(v), check_dtypes=False) + def test_sharded_prng_with_abstract_mesh(self): + shape = (8, 2, 2) + mesh = jtu.create_mesh((2, 2, 2), ('x', 'y', 'z')) + + np_inp = np.arange(math.prod(shape), dtype=np.uint32).reshape(shape) + key = prng.random_seed(np_inp, impl=prng.threefry_prng_impl) + key = jax.device_put(key, NamedSharding(mesh, P())) + + @jax.jit + def shard_key(key): + return shard_map( + lambda x: x, mesh=mesh.abstract_mesh, in_specs=P(), out_specs=P())(key) + + out = shard_key(key) + self.assertEqual(out.sharding, NamedSharding(mesh, P())) + def test_partial_auto_error_wsc_manual(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1681,7 +1822,7 @@ def f(x): f(v) def test_partial_auto_error_invalid_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1704,7 +1845,7 @@ def f(x): f(v) def test_partial_auto_error_wrong_in_specs(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): x = jax.lax.with_sharding_constraint( @@ -1727,7 +1868,7 @@ def f(x): f(v) def test_nested_partial_auto(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def g(x): return x * x @@ -1750,7 +1891,7 @@ def f(x): self.assertAllClose(v*v, f(v), check_dtypes=False) def test_axis_size_1_partial_auto(self): - mesh = jtu.create_global_mesh((1, 2, 2), ('i', 'j', 'k')) + mesh = jtu.create_mesh((1, 2, 2), ('i', 'j', 'k')) def h(x): return x * x @@ -1768,7 +1909,7 @@ def f(x): self.assertAllClose(v*v, f(v), check_dtypes=False) def 
test_partial_auto_of_pjit(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) def h(): def _make_zeros(): @@ -1785,7 +1926,12 @@ def f(): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_partial_auto_of_pjit_different_mesh(self): - mesh = jtu.create_global_mesh((2, 2), ('i', 'j')) + if config.use_shardy_partitioner.value: + self.skipTest( + 'Shardy requires the mesh axis names to be the same across ' + 'the entire computation.' + ) + mesh = jtu.create_mesh((2, 2), ('i', 'j')) mesh2 = jax.sharding.Mesh(mesh.devices, ('k', 'l')) def h(): @@ -1803,8 +1949,8 @@ def f(): self.assertAllClose(jax.jit(f)(), jnp.zeros((2,))) def test_vmap_grad_shmap_spmd_axis_name_residuals(self): - # https://github.com/google/jax/pull/21032 - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + # https://github.com/jax-ml/jax/pull/21032 + mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial( shard_map, @@ -1820,8 +1966,8 @@ def f(x): jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash def test_vmap_grad_remat_shmap_spmd_axis_name_residuals(self): - # https://github.com/google/jax/pull/21056 - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + # https://github.com/jax-ml/jax/pull/21056 + mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial(jax.remat, policy=lambda *_, **__: True) @partial( @@ -1838,8 +1984,8 @@ def f(x): jax.vmap(jax.grad(lambda x: f(x).sum()), spmd_axis_name='i')(xs) # don't crash def test_grad_shmap_residuals_axis_names_in_mesh_order(self): - # https://github.com/google/jax/issues/21236 - mesh = jtu.create_global_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) + # https://github.com/jax-ml/jax/issues/21236 + mesh = jtu.create_mesh((4, 2, 1, 1), ('i', 'j', 'k', 'a')) @partial( shard_map, @@ -1853,13 +1999,17 @@ def f(x): xs = jnp.arange(16.) 
ir = jax.jit(jax.grad(lambda x: f(x).sum())).lower(xs) - self.assertIn( - '{jax.result_info = "[(\'i\', \'j\', \'k\', \'a\')]"}', - ir.as_text() - ) + if config.use_shardy_partitioner.value: + self.assertIn( + 'out_shardings=[<@mesh, [{"i", "j", "k", "a"}]>]', ir.as_text() + ) + else: + self.assertIn( + "{jax.result_info = \"[('i', 'j', 'k', 'a')]\"}", ir.as_text() + ) def test_vmap_spmd_axis_name_error(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) @partial( shard_map, @@ -1889,7 +2039,7 @@ def g(x): jax.vmap(g, spmd_axis_name='i')(xs) def test_in_spec_none(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(8).reshape(4, 2) @@ -1940,7 +2090,7 @@ def f4(o1, o2, x, o3): self.assertAllClose(y, jnp.sin(x), check_dtypes=False) def test_in_spec_none_divisibility_errors(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4).reshape(2, 2) with self.assertRaisesRegex(ValueError, 'divisible'): @@ -1962,7 +2112,7 @@ def test_in_spec_none_divisibility_errors(self): )((object(), object()), x) def test_in_spec_none_rank_errors(self): - mesh = jtu.create_global_mesh((4, 2), ('i', 'j')) + mesh = jtu.create_mesh((4, 2), ('i', 'j')) x = jnp.arange(4) with self.assertRaisesRegex(ValueError, 'rank'): @@ -1984,8 +2134,8 @@ def test_in_spec_none_rank_errors(self): )((object(), object()), x) def test_custom_linear_solve_rep_rules(self): - # https://github.com/google/jax/issues/20162 - mesh = jtu.create_global_mesh((1,), ('i',)) + # https://github.com/jax-ml/jax/issues/20162 + mesh = jtu.create_mesh((1,), ('i',)) a = jnp.array(1).reshape(1, 1) b = jnp.array(1).reshape(1) @@ -1997,7 +2147,7 @@ def f(a, b): _ = f(a, b) # don't crash def test_temporary_error_suppression_flag(self): - mesh = jtu.create_global_mesh((2,), ('i',)) + mesh = jtu.create_mesh((2,), ('i',)) def f(x, y): z = shard_map(lambda x, y: x + jax.lax.all_gather(y, 'i', tiled=True), @@ -2014,6 +2164,63 @@ def f(x, y): with config.disable_vmap_shmap_error(): _ = jax.vmap(f, in_axes=(0, None), spmd_axis_name='i')(xs, y) + def test_in_spec_none_hashability(self): + mesh = jtu.create_mesh((2,), ('i',)) + + class A: + def __hash__(self): + raise Exception + + @partial(shard_map, mesh=mesh, in_specs=(None,), out_specs=()) + def f(a): + return () + + f(A()) # don't crash + + def test_get_check_rep(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + + def f(x, reduce_along, use_jit): + out_spec = P(*(n for n in ('x', 'y') if n not in reduce_along)) + + @partial(shard_map, mesh=mesh, in_specs=P('x', 'y'), out_specs=out_spec) + def g(x): + result = lax.psum(x, axis_name=reduce_along) + def check_rep(result): + self.assertEqual( + jax.experimental.shard_map.get_replication(result), + set(reduce_along)) + return result + result = check_rep(result) + result = jax.vmap(check_rep)(result) + return result + if use_jit: + return jax.jit(g)(x) + else: + return g(x) + + for use_jit in [True, False]: + x = np.zeros((8, 8), dtype=np.float32) + f(x, reduce_along=('y',), use_jit=use_jit) + f(x, reduce_along=('x',), use_jit=use_jit) + f(x, reduce_along=('x', 'y'), use_jit=use_jit) + + def test_pmin(self): + mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + y = shard_map(lambda x: jax.lax.pmin(x, 'i'), + mesh=mesh, in_specs=P('i'), out_specs=P() + )(x) # don't crash + self.assertArraysEqual(y, np.array([0, 1], dtype=np.float32)) + + def test_pmax(self): + 
mesh = jtu.create_mesh((4,), ('i',)) + x = jnp.arange(8., dtype=np.float32) + y = shard_map(lambda x: jax.lax.pmax(x, 'i'), + mesh=mesh, in_specs=P('i'), out_specs=P() + )(x) # don't crash + self.assertArraysEqual(y, np.array([6, 7], dtype=np.float32)) + class FunSpec(NamedTuple): name: str @@ -2259,7 +2466,7 @@ class ShardMapSystematicTest(jtu.JaxTestCase): @staticmethod def make_mesh(mesh_shape): - return jtu.create_global_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) + return jtu.create_mesh(tuple(mesh_shape.values()), tuple(mesh_shape)) @parameterized.named_parameters( sample(jtu.NUM_GENERATED_CASES.value, sample_shmap)) @@ -2428,5 +2635,27 @@ def fwd(a): self.assertEqual(c.addressable_data(0).shape, (4, 2)) +@jtu.with_config(jax_use_shardy_partitioner=True) +class SdyIntegrationTest(jtu.JaxTestCase): + # Verify we can lower to a `ManualComputationOp`. + def test_shardy_collective_permute(self): + mesh = jtu.create_mesh((2,), ('x',)) + a = jax.device_put( + jnp.arange(8 * 8).reshape((8, 8)), + jax.sharding.NamedSharding(mesh, P('x', None)), + ) + + @jax.jit + @partial( + shard_map, mesh=mesh, in_specs=(P('x', None),), out_specs=P('x', None) + ) + def fwd(a): + axis_size = lax.psum(1, 'x') + perm = [(j, (j + 1) % axis_size) for j in range(axis_size)] + return lax.ppermute(a, 'x', perm=perm) + + self.assertIn('sdy.manual_computation', jax.jit(fwd).lower(a).as_text()) + + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/sparse_bcoo_bcsr_test.py b/tests/sparse_bcoo_bcsr_test.py index 545d73bff291..12088db7fe18 100644 --- a/tests/sparse_bcoo_bcsr_test.py +++ b/tests/sparse_bcoo_bcsr_test.py @@ -310,7 +310,7 @@ def test_bcoo_extract_duplicate_indices_n_sparse_0(self): self.assertArraysEqual(data2, jnp.array([[1, 0], [5, 0], [9, 0]])) def test_bcoo_extract_batching(self): - # https://github.com/google/jax/issues/9431 + # https://github.com/jax-ml/jax/issues/9431 indices = jnp.zeros((4, 1, 1), dtype=int) mat = jnp.arange(4.).reshape((4, 1)) @@ -353,7 +353,7 @@ def test_bcoo_extract_ad(self, shape, dtype, n_batch, n_dense): self.assertEqual(hess.shape, data.shape + 2 * M.shape) def test_bcoo_extract_zero_nse(self): - # Regression test for https://github.com/google/jax/issues/13653 + # Regression test for https://github.com/jax-ml/jax/issues/13653 # (n_batch, n_sparse, n_dense) = (1, 0, 0), nse = 2 args_maker = lambda: (jnp.zeros((3, 2, 0), dtype='int32'), jnp.arange(3)) @@ -973,8 +973,9 @@ def test_bcoo_spdot_general_nse(self, lhs_shape, rhs_shape): self.assertArraysAllClose(out.todense(), expected_out) self.assertEqual(out.nse, expected_nse) + @jtu.ignore_warning(message="bcoo_dot_general cusparse/hipsparse lowering not available") def test_bcoo_spdot_general_ad_bug(self): - # Regression test for https://github.com/google/jax/issues/10163 + # Regression test for https://github.com/jax-ml/jax/issues/10163 A_indices = jnp.array([[0, 1], [0, 2], [1, 1], [1, 2], [1, 0]]) A_values = jnp.array([-2.0, 1.0, -1.0, 0.5, 2.0]) A_shape = (2, 3) @@ -1287,7 +1288,7 @@ def test_bcoo_sum_duplicates_remove_zeros(self): self.assertEqual(y2.nse, x.nse) def test_bcoo_sum_duplicates_padding(self): - # Regression test for https://github.com/google/jax/issues/8163 + # Regression test for https://github.com/jax-ml/jax/issues/8163 size = 3 data = jnp.array([1, 0, 0]) indices = jnp.array([1, size, size])[:, None] @@ -1606,7 +1607,7 @@ def test_bcoo_mul_sparse(self, lhs_shape, lhs_dtype, rhs_shape, rhs_dtype, lhs_n self._CheckAgainstDense(operator.mul, operator.mul, 
args_maker, tol=tol) def test_bcoo_mul_sparse_with_duplicates(self): - # Regression test for https://github.com/google/jax/issues/8888 + # Regression test for https://github.com/jax-ml/jax/issues/8888 indices = jnp.array([[0, 1, 0, 0, 1, 1], [1, 0, 1, 2, 0, 2]]).T data = jnp.array([1, 2, 3, 4, 5, 6]) @@ -1940,7 +1941,7 @@ def test_bcsr_concatenate(self, shape, dtype, n_batch, n_dense, dimension): self._CheckGradsSparse(dense_func, sparse_func, args_maker) def test_bcoo_spdot_abstract_eval_bug(self): - # Regression test for https://github.com/google/jax/issues/21921 + # Regression test for https://github.com/jax-ml/jax/issues/21921 lhs = sparse.BCOO( (jnp.float32([[1]]), lax.broadcasted_iota(jnp.int32, (10, 1, 1), 0)), shape=(10, 10)) diff --git a/tests/sparse_test.py b/tests/sparse_test.py index 616396222ec6..eb8d70be1f05 100644 --- a/tests/sparse_test.py +++ b/tests/sparse_test.py @@ -323,7 +323,7 @@ def test_coo_matmat(self, shape, dtype, transpose): self.assertAllClose(op(M) @ B, jit(matmat)(*args), rtol=sptu.MATMUL_TOL) def test_coo_matmat_layout(self): - # Regression test for https://github.com/google/jax/issues/7533 + # Regression test for https://github.com/jax-ml/jax/issues/7533 d = jnp.array([1.0, 2.0, 3.0, 4.0]) i = jnp.array([0, 0, 1, 2]) j = jnp.array([0, 2, 0, 0]) diff --git a/tests/sparsify_test.py b/tests/sparsify_test.py index 46086511d8b5..46c2f5aafbf6 100644 --- a/tests/sparsify_test.py +++ b/tests/sparsify_test.py @@ -610,7 +610,7 @@ def func(M): self.assertArraysEqual(jit(func)(Msp).todense(), expected) def testWeakTypes(self): - # Regression test for https://github.com/google/jax/issues/8267 + # Regression test for https://github.com/jax-ml/jax/issues/8267 M = jnp.arange(12, dtype='int32').reshape(3, 4) Msp = BCOO.fromdense(M) self.assertArraysEqual( diff --git a/tests/stax_test.py b/tests/stax_test.py index 6850f36a02ea..e21300ddd119 100644 --- a/tests/stax_test.py +++ b/tests/stax_test.py @@ -216,7 +216,7 @@ def testBatchNormShapeNHWC(self): def testBatchNormShapeNCHW(self): key = random.PRNGKey(0) - # Regression test for https://github.com/google/jax/issues/461 + # Regression test for https://github.com/jax-ml/jax/issues/461 init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3)) input_shape = (4, 5, 6, 7) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index 23ddf73904b5..c5342a99365d 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,6 +24,7 @@ import jax from jax import flatten_util from jax import tree_util +from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -343,7 +344,7 @@ def f(a, b, c): pass self.assertEqual(h.args, (3,)) def testPartialFuncAttributeHasStableHash(self): - # https://github.com/google/jax/issues/9429 + # https://github.com/jax-ml/jax/issues/9429 fun = functools.partial(print, 1) p1 = tree_util.Partial(fun, 2) p2 = tree_util.Partial(fun, 2) @@ -359,7 +360,7 @@ def testChildren(self): self.assertEqual([c0, c1], tree.children()) def testTreedefTupleFromChildren(self): - # https://github.com/google/jax/issues/7377 + # https://github.com/jax-ml/jax/issues/7377 tree = ((1, 2, (3, 4)), (5,)) leaves, treedef1 = tree_util.tree_flatten(tree) treedef2 = tree_util.treedef_tuple(treedef1.children()) @@ -368,7 +369,7 @@ def testTreedefTupleFromChildren(self): self.assertEqual(treedef1.num_nodes, treedef2.num_nodes) def testTreedefTupleComparesEqual(self): - # https://github.com/google/jax/issues/9066 
+ # https://github.com/jax-ml/jax/issues/9066 self.assertEqual(tree_util.tree_structure((3,)), tree_util.treedef_tuple((tree_util.tree_structure(3),))) @@ -395,6 +396,7 @@ def testFlattenOrder(self): ({"a": 1, "b": (2, 3)}, {"a": [7], "b": ([8], (9,))}, [[7], [8], (9,)]), ({"a": 1}, {"a": (7,)}, [(7,)]), ({"a": 1}, {"a": {"a": 7}}, [{"a": 7}]), + (None, None, []) ) def testFlattenUpTo(self, tree, xs, expected): _, tree_def = tree_util.tree_flatten(tree) @@ -483,6 +485,11 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), + *( + [] + if xla_extension_version < 288 + else [(None, [2], re.escape("Expected None, got [2]."))] + ), ) def testFlattenUpToErrors(self, tree, xs, error): _, tree_def = tree_util.tree_flatten(tree) @@ -978,7 +985,7 @@ def testEmpty(self): self.assertAllClose(tree, tree_, atol=0., rtol=0.) def testDtypePolymorphicUnravel(self): - # https://github.com/google/jax/issues/7809 + # https://github.com/jax-ml/jax/issues/7809 x = jnp.arange(10, dtype=jnp.float32) x_flat, unravel = flatten_util.ravel_pytree(x) y = x_flat < 5.3 @@ -987,7 +994,7 @@ def testDtypePolymorphicUnravel(self): @jax.numpy_dtype_promotion('standard') # Explicitly exercises implicit dtype promotion. def testDtypeMonomorphicUnravel(self): - # https://github.com/google/jax/issues/7809 + # https://github.com/jax-ml/jax/issues/7809 x1 = jnp.arange(10, dtype=jnp.float32) x2 = jnp.arange(10, dtype=jnp.int32) x_flat, unravel = flatten_util.ravel_pytree((x1, x2)) @@ -1257,5 +1264,63 @@ def test_tree_unflatten(self): ) +class RegistrationTest(jtu.JaxTestCase): + + def test_register_dataclass_missing_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: int + z: float = dataclasses.field(init=False) + + with self.assertRaisesRegex( + ValueError, + "data_fields and meta_fields must include all dataclass fields.*" + "Missing fields: {'y'}", + ): + tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[]) + + # ``z`` is not required, because it's not included in ``__init__``. + tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=["y"]) + + def test_register_dataclass_unexpected_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: float + + with self.assertRaisesRegex( + ValueError, + "data_fields and meta_fields must include all dataclass fields.*" + "Unexpected fields: {'z'}", + ): + tree_util.register_dataclass( + Foo, data_fields=["x"], meta_fields=["y", "z"] + ) + + def test_register_dataclass_drop_fields(self): + @dataclasses.dataclass + class Foo: + x: int + y: int = dataclasses.field(default=42) + + # ``y`` is explicitly excluded. + tree_util.register_dataclass( + Foo, data_fields=["x"], meta_fields=[], drop_fields=["y"] + ) + + def test_register_dataclass_invalid_plain_class(self): + class Foo: + x: int + y: int + + def __init__(self, x, y): + self.x = x + self.y = y + + # ``y`` is missing, but no validation is done for plain classes. 
+ tree_util.register_dataclass(Foo, data_fields=["x"], meta_fields=[]) + + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/tests/x64_context_test.py b/tests/x64_context_test.py index 58cf4a2baae3..d4403b7e5e30 100644 --- a/tests/x64_context_test.py +++ b/tests/x64_context_test.py @@ -128,7 +128,7 @@ def test_jit_cache(self): @unittest.skip("test fails, see #8552") def test_convert_element_type(self): - # Regression test for part of https://github.com/google/jax/issues/5982 + # Regression test for part of https://github.com/jax-ml/jax/issues/5982 with enable_x64(): x = jnp.int64(1) self.assertEqual(x.dtype, jnp.int64) diff --git a/tests/xla_metadata_test.py b/tests/xla_metadata_test.py new file mode 100644 index 000000000000..38bd7e05533e --- /dev/null +++ b/tests/xla_metadata_test.py @@ -0,0 +1,290 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests whether the frontend attributes added by the context manager are + +correctly propagated to the jaxpr and mlir. +""" + +from absl.testing import absltest +import jax +from jax._src import config +from jax._src import dispatch +from jax._src import test_util as jtu +from jax._src.lax import lax +from jax.experimental.xla_metadata import set_xla_metadata +import jax.numpy as jnp + +config.parse_flags_with_absl() + + +class XlaMetadataTest(jtu.JaxTestCase): + + def test_f_jitted(self): + @jax.jit + def f(a, b): + with set_xla_metadata(a="b"): + return a + b + + f_jaxpr = jax.make_jaxpr(f)(1, 2) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "b"}) + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "b"}', f_lowered_text) + + def test_f_jitted_bool_attributes(self): + @jax.jit + def f(a, b): + with set_xla_metadata(a=True): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "true"}', f_lowered_text) + + def test_f_jitted_int_attributes(self): + @jax.jit + def f(a, b): + with set_xla_metadata(a=10): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "10"}', f_lowered_text) + + def test_f_nonjitted(self): + def f_add(a, b): + return dispatch.apply_primitive(lax.add_p, a, b) + + arg1 = jnp.arange(2) + with set_xla_metadata(a="b"): + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', + jax.jit(f_add).lower(arg1, arg1).as_text(), + ) + + def test_f_attributes_overwrite(self): + @jax.jit + def g(a, b): + return a * b + + with set_xla_metadata(a="b"): + + @jax.jit + def f(a, b): + with set_xla_metadata(a="c"): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn('mhlo.frontend_attributes = {a = "c"}', f_lowered_text) + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', g.lower(1.0, 2.0).as_text() + ) + self.assertNotIn("mhlo.frontend_attributes", g.lower(1.0, 2.0).as_text()) + + def test_f_attributes_merge(self): + with 
set_xla_metadata(key1="val1"): + + @jax.jit + def f(a, b): + with set_xla_metadata(key2="val2"): + return a + b + + f_lowered_text = f.lower(1.0, 2.0).as_text() + self.assertIn( + 'mhlo.frontend_attributes = {key1 = "val1", key2 = "val2"}', + f_lowered_text, + ) + + def test_attr_caching_jit(self): + @jax.jit + def f_add_jit(a, b): + return a + b + + with set_xla_metadata(b="c"): + f_add_lowered1 = f_add_jit.lower(2.0, 3.0).as_text() + # Expect no attributes in the mlir. + f_add_lowered2 = f_add_jit.lower(1.0, 2.0).as_text() + with set_xla_metadata(c="d"): + f_add_lowered3 = f_add_jit.lower(4.0, 5.0).as_text() + self.assertIn('mhlo.frontend_attributes = {b = "c"}', f_add_lowered1) + self.assertNotIn("mhlo.frontend_attributes = {}", f_add_lowered2) + self.assertNotIn('mhlo.frontend_attributes = {b = "c"}', f_add_lowered2) + self.assertNotIn('mhlo.frontend_attributes = {c = "d"}', f_add_lowered2) + self.assertIn('mhlo.frontend_attributes = {c = "d"}', f_add_lowered3) + + def test_attr_caching_nonjit(self): + def f_add(a, b): + return dispatch.apply_primitive(lax.add_p, a, b) + + arg1 = jnp.arange(2) + arg2 = jnp.arange(2) + 1 + arg3 = jnp.arange(2) + 2 + with set_xla_metadata(b="c"): + self.assertIn( + 'mhlo.frontend_attributes = {b = "c"}', + jax.jit(f_add).lower(arg1, arg1).as_text(), + ) + # Expect no attributes in the jaxpr. + self.assertNotIn( + "mhlo.frontend_attributes", + jax.jit(f_add).lower(arg2, arg2).as_text(), + ) + + with set_xla_metadata(c="d"): + self.assertIn( + 'mhlo.frontend_attributes = {c = "d"}', + jax.jit(f_add).lower(arg3, arg3).as_text(), + ) + + def test_axpy(self): + @jax.jit + def axpy(a, x, y): + with set_xla_metadata(a="b"): + return a * x + y + + for line in axpy.lower(1.0, 2.0, 3.0).as_text().split("\n"): + if "stablehlo.multiply" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + if "stablehlo.add" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + + def test_while(self): + @jax.jit + def f(a): + with set_xla_metadata(a="b"): + return jax.lax.while_loop(lambda x: x < 10, lambda x: x + 1, a) + + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', f.lower(1.0).as_text() + ) + + def test_while_condition_body(self): + @jax.jit + def f_condition(x): + with set_xla_metadata(a="b"): + return x < 10 + + @jax.jit + def f_body(x): + with set_xla_metadata(a="c"): + return x + 1 + + @jax.jit + def while_fn(a): + return jax.lax.while_loop(f_condition, f_body, a) + + for line in while_fn.lower(1.0).as_text().split("\n"): + if "stablehlo.compare" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + if "stablehlo.add" in line: + self.assertIn('mhlo.frontend_attributes = {a = "c"}', line) + + def test_nested_jit(self): + @jax.jit + def f(x, y): + with set_xla_metadata(a="b"): + z = x * y + + @jax.jit + def g(z): + with set_xla_metadata(c="d"): + return z**2 + 1 + + return g(z) + + self.assertIn( + 'mhlo.frontend_attributes = {a = "b", c = "d"}', + f.lower(1.0, 2.0).as_text(), + ) + + def test_grad(self): + @jax.jit + def f(x, y): + with set_xla_metadata(a="b"): + return jax.grad(lambda x: x**3 + y**2 + jnp.sin(x))(x) + + f_jaxpr = jax.make_jaxpr(f)(1.0, 2.0) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "b"}) + + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', f.lower(1.0, 2.).as_text() + ) + + def test_grad_outside_ctx(self): + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return x**3 + x**2 + jnp.sin(x) + + grad_fn = jax.jit(jax.grad(f)) + for 
line in grad_fn.lower(1.0).as_text().split("\n"): + if "stablehlo.cosine" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + if "call @integer_pow" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + + def test_vmap(self): + dct = {"a": 0.0, "b": jnp.arange(5.0)} + + @jax.jit + def f(dct, x): + with set_xla_metadata(a="b"): + return dct["a"] + dct["b"] + x + + with set_xla_metadata(a="d"): + f_vmap = jax.vmap(f, in_axes=({"a": None, "b": 0}, None)) + f_jaxpr = jax.make_jaxpr(f_vmap)(dct, 1.0) + eqns = f_jaxpr.eqns + for eq in eqns[1:]: + self.assertDictEqual(eq.ctx.attributes, {"a": "d"}) + @jax.jit + def f2(x, y): + with set_xla_metadata(a="b"): + return (x + y, y * 2.0) + + f_vmap_jaxpr = jax.make_jaxpr(jax.vmap(f2, in_axes=(0, None))) + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', + f_vmap_jaxpr.lower(jnp.arange(5.0), 1.0).as_text(), + ) + + def test_multiple_instructions(self): + @jax.jit + def f(x, a): + y = jnp.matmul(x, x) + with set_xla_metadata(a="b"): + return y + a + + for line in f.lower(jnp.arange(5.0), 1.0).as_text().split("\n"): + # matmul doesn't have attributes + if "stablehlo.dot_general" in line: + self.assertNotIn('mhlo.frontend_attributes = {a = "b"}', line) + if "stablehlo.add" in line: + self.assertIn('mhlo.frontend_attributes = {a = "b"}', line) + + def test_softmax(self): + @jax.jit + def f(x): + with set_xla_metadata(a="b"): + return jax.nn.softmax(x) + self.assertIn( + 'mhlo.frontend_attributes = {a = "b"}', f.lower(jnp.arange(5.0)).as_text() + ) + + +if __name__ == "__main__": + absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl index f49cf2eb34ba..ca84599d6cff 100644 --- a/third_party/xla/workspace.bzl +++ b/third_party/xla/workspace.bzl @@ -21,8 +21,8 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls") # curl -L https://github.com/openxla/xla/archive/.tar.gz | sha256sum # and update XLA_SHA256 with the result. -XLA_COMMIT = "08b8d938eb56928970e65639b126794c01b75c3d" -XLA_SHA256 = "365d9b42b6da10c9f0b53f01e075e8b1513e431ad596183d8ec1c2c27e1d7973" +XLA_COMMIT = "0e732d65bdf8fb158c7b01e18139e5ba59ca7025" +XLA_SHA256 = "16e4aeca04ce94bd0fcfa32990d76be3779c026c2b649478bf27d0db0679e65c" def repo(): tf_http_archive(