Skip to content

Commit

Permalink
Merge pull request #187 from ROCm/ci-upstream-sync-66_1
Browse files Browse the repository at this point in the history
CI: 12/19/24 upstream sync
  • Loading branch information
charleshofer authored Dec 19, 2024
2 parents ffcfc10 + b8bbb14 commit b05089d
Show file tree
Hide file tree
Showing 139 changed files with 6,147 additions and 2,325 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/asan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ on:

jobs:
asan:
# Don't execute in fork due to runner type
if: github.repository == 'jax-ml/jax'
runs-on: linux-x86-n2-64
container:
image: index.docker.io/library/ubuntu@sha256:b359f1067efa76f37863778f7b6d0e8d911e3ee8efa807ad01fbf5dc1ef9006b # ratchet:ubuntu:24.04
Expand Down
16 changes: 11 additions & 5 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
name: CI - Cloud TPU (nightly)
on:
schedule:
- cron: "0 */2 * * *" # Run every 2 hours
- cron: "0 2,14 * * *" # Run at 02:00 and 14:00 UTC (7pm and 7am PDT; 6pm and 6am PST)
workflow_dispatch: # allows triggering the workflow run manually
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
Expand All @@ -33,12 +33,11 @@ jobs:
python-version: ["3.10"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu.type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20240722
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
LIBTPU_OLDEST_VERSION_DATE: 20240922
PYTHON: python${{ matrix.python-version }}
runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
timeout-minutes: 120
timeout-minutes: 180
defaults:
run:
shell: bash -ex {0}
Expand Down Expand Up @@ -112,10 +111,17 @@ jobs:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
run: |
# We're deselecting all Pallas TPU tests in the oldest libtpu build. Mosaic TPU does not
# guarantee anything about forward compatibility (unless jax.export is used) and the 12
# week compatibility window accumulates way too many failures.
IGNORE_FLAGS=
if [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
IGNORE_FLAGS="--ignore=tests/pallas"
fi
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true $PYTHON -m pytest -n=${{ matrix.tpu.cores }} --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
--maxfail=20 -m "not multiaccelerator" $IGNORE_FLAGS tests examples
# Run Pallas printing tests, which need to run with I/O capturing disabled.
TPU_STDERR_LOG_LEVEL=0 $PYTHON -m pytest -s \
tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
Expand Down
93 changes: 93 additions & 0 deletions .github/workflows/cloud-tpu-ci-presubmit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# CI - Cloud TPU (presubmit)
#
# Builds jaxlib at head against XLA at head and then runs the TPU pytest script.
# This workflow currently runs as a non-blocking, experimental presubmit while
# it is being stabilised; the intent is to enable it as a blocking presubmit
# once it is stable.
name: CI - Cloud TPU (presubmit)

on:
  # Manual runs can optionally wait for a remote connection.
  workflow_dispatch:
    inputs:
      halt-for-connection:
        description: 'Should this workflow run wait for a remote connection?'
        type: choice
        required: true
        default: 'no'
        options:
          - 'yes'
          - 'no'
  pull_request:
    branches:
      - main

# Read-only token permissions. The project settings should enforce the same
# restriction; declaring it here documents and enforces it in the workflow too.
permissions:
  contents: read

# Runs are grouped by workflow plus PR head ref (or ref); a newer run in the
# same group cancels the one already in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  cloud-tpu-test:
    name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"

    # Only run when the event's repository is not a fork.
    if: github.event.repository.fork == false

    runs-on: ${{ matrix.tpu.runner }}
    container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
    timeout-minutes: 60

    strategy:
      # Do not cancel the remaining matrix jobs when one of them fails.
      fail-fast: false
      matrix:
        tpu:
          - type: "v5e-8"
            cores: "8"
            runner: "linux-x86-ct5lp-224-8tpu"
        python-version: ["3.10"]

    # Exported to every step; the run steps below invoke Python via $JAXCI_PYTHON.
    env:
      JAXCI_PYTHON: python${{ matrix.python-version }}
      JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}

    defaults:
      run:
        shell: bash -ex {0}

    steps:
      # Non-Google actions must be pinned to a specific commit, see
      # https://opensource.google/documentation/reference/github/services#actions.
      # The pins are managed with https://github.com/sethvargo/ratchet.
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2

      # jaxlib is built at head, so XLA is checked out at head as well.
      - name: Checkout XLA at head
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683  # v4.2.2
        with:
          repository: openxla/xla
          path: xla

      # Git commands fail unless the GitHub workspace is marked as a safe
      # directory.
      - name: Mark GitHub workspace as safe
        run: |
          git config --global --add safe.directory "$GITHUB_WORKSPACE"

      - name: Install JAX test requirements
        run: |
          $JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
          $JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt

      - name: Build jaxlib at head with latest XLA
        run: |
          # Build and install jaxlib at head
          $JAXCI_PYTHON build/build.py build --wheels=jaxlib \
            --python_version=${{ matrix.python-version }} \
            --bazel_options=--config=rbe_linux_x86_64 \
            --local_xla_path="$(pwd)/xla" \
            --verbose
          # Install libtpu
          $JAXCI_PYTHON -m pip install --pre libtpu \
            -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

      # Optional halt point for testing, driven by the `halt-for-connection`
      # dispatch input.
      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}

      - name: Install jaxlib wheel and run tests
        run: ./ci/run_pytest_tpu.sh
41 changes: 34 additions & 7 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,51 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.38
## Unreleased

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.

* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.

## jax 0.4.38 (Dec 17, 2024)

* Changes:
* `jax.tree.flatten_with_path` and `jax.tree.map_with_path` are added
as shortcuts of the corresponding `tree_util` functions.

* Deprecations
* a number of APIs in the internal `jax.core` namespace have been deprecated, including
`ClosedJaxpr`, `full_lower`, `Jaxpr`, `JaxprEqn`, `jaxpr_as_fun`, `lattice_join`,
`Literal`, `Primitive`, `raise_to_shaped`, `Token`, `Var`. Most can be replaced by
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
{mod}`jax.extend` for information on the compatibility guarantees of these
semi-public extensions.
* a number of APIs in the internal `jax.core` namespace have been deprecated.
Most were no-ops, were little-used, or can be replaced by APIs of the same
name in {mod}`jax.extend.core`; see the documentation for {mod}`jax.extend`
for information on the compatibility guarantees of these semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.
* from {mod}`jax.numpy`: `round_`.

* New Features
* {func}`jax.export.export` can be used for device-polymorphic export with
shardings constructed with {func}`jax.sharding.AbstractMesh`.
See the [jax.export documentation](https://jax.readthedocs.io/en/latest/export/export.html#device-polymorphic-export).
* Added {func}`jax.lax.split`. This is a primitive version of
{func}`jax.numpy.split`, added because it yields a more compact
transpose during automatic differentiation.

## jax 0.4.37 (Dec 9, 2024)

Expand Down
4 changes: 2 additions & 2 deletions benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from jax import lax
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src import core
from jax._src.lib import xla_client as xc
from jax.interpreters import xla
from jax._src import array
from jax._src import op_shardings
from jax._src.pjit import pjit_check_aval_sharding
Expand Down Expand Up @@ -427,7 +427,7 @@ def bench_shaped_abstractify(state):

def _run_benchmark_for_xla_abstractify(arg, state):
while state:
xla.abstractify(arg)
core.abstractify(arg)

def bench_xla_abstractify():
_abstractify_args = [
Expand Down
16 changes: 8 additions & 8 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,20 +475,20 @@ async def main():
wheel_build_command.append(f"--action_env=CLANG_COMPILER_PATH=\"{clang_path}\"")
wheel_build_command.append(f"--repo_env=CC=\"{clang_path}\"")
wheel_build_command.append(f"--repo_env=BAZEL_COMPILER=\"{clang_path}\"")

if clang_major_version >= 16:
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command.append("--config=clang")
else:
logging.debug("Use Clang: False")

# Do not apply --config=clang on Mac as these settings do not apply to
# Apple Clang.
if os_name != "darwin":
wheel_build_command.append("--config=clang")

if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":
wheel_build_command.append("--config=mkl_aarch64_threadpool")
wheel_build_command.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command.append("--config=mkl_open_source_only")
wheel_build_command.append("--config=mkl_open_source_only")

if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
Expand All @@ -501,7 +501,7 @@ async def main():
if os_name == "windows"
else "--config=avx_posix"
)
elif wheel_build_command == "native":
elif args.target_cpu_features == "native":
if os_name == "windows":
logger.warning(
"--target_cpu_features=native is not supported on Windows;"
Expand Down
2 changes: 1 addition & 1 deletion build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def get_githash():
capture_output=True,
check=True,
).stdout.strip()
except OSError:
except (subprocess.CalledProcessError, OSError):
return ""

def _parse_string_as_bool(s):
Expand Down
45 changes: 45 additions & 0 deletions ci/envs/docker.env
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# This file contains all the Docker-specific environment variables needed by
# the ci/utilities/run_docker_container.sh script.

# Lower-cased kernel name and machine architecture of the host; together they
# select the Docker image below.
os=$(uname -s | awk '{print tolower($0)}')
arch=$(uname -m)

# Path to the JAX git repository, taken from the current working directory.
export JAXCI_JAX_GIT_DIR="$(pwd)"

# Working directory inside the container, and extra Docker arguments (none by
# default).
export JAXCI_DOCKER_WORK_DIR="/jax"
export JAXCI_DOCKER_ARGS=""

# TODO(b/384533574): Replace the `latest` tagged images with sha or release
# specific tags when the new release jobs become the default.
#
# JAXCI_DOCKER_IMAGE is only exported for the platforms listed here; on any
# other os/arch combination it is left untouched.
case "$os" in
  linux)
    case "$arch" in
      x86_64)
        # Linux x86 image for building JAX artifacts, running Pytest CPU/TPU
        # tests, and Bazel tests.
        export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"
        ;;
      aarch64)
        # Linux Aarch64 image for building JAX artifacts, running Pytest CPU
        # tests, and Bazel tests.
        export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest"
        ;;
    esac
    ;;
  *msys_nt*)
    # Windows image for building JAX artifacts, running Pytest CPU tests, and
    # Bazel tests. Matches any kernel name containing "msys_nt".
    export JAXCI_DOCKER_IMAGE="us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/tf-test-windows:latest"
    ;;
esac
25 changes: 12 additions & 13 deletions ci/run_pytest_tpu.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,28 @@ source ./ci/utilities/install_wheels_locally.sh
# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"

export PY_COLORS=1
export JAX_SKIP_SLOW_TESTS=true

"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"

"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'

echo "Running TPU tests..."
# Set up all common test environment variables
export PY_COLORS=1
export JAX_PLATFORMS=tpu,cpu
# Run single-accelerator tests in parallel
export JAX_ENABLE_TPU_XDIST=true
export JAX_SKIP_SLOW_TESTS=true
# End of common test environment variable setup

"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
echo "Running TPU tests..."

# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
--maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py

# Run Pallas printing tests, which need to run with I/O capturing disabled.
export TPU_STDERR_LOG_LEVEL=0
"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest

# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py
Loading

0 comments on commit b05089d

Please sign in to comment.