Skip to content

Commit

Permalink
Add GPU CI (#137)
Browse files Browse the repository at this point in the history
  • Loading branch information
charleshofer authored Jan 7, 2025
1 parent 972f95b commit bc06c93
Show file tree
Hide file tree
Showing 6 changed files with 301 additions and 177 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/rocm-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Builds JAX for ROCm inside Docker on the `mi-250` runner, archives the
# resulting wheels, and then runs `ci_build test` against the image that the
# build step tagged.
name: ROCm GPU CI

on:
  # Trigger the workflow on push or pull request,
  # but only for the rocm-main branch
  push:
    branches:
      - rocm-main
  pull_request:
    branches:
      - rocm-main

# No step writes to the repository, so keep the GITHUB_TOKEN read-only.
permissions:
  contents: read

# One active run per pull request (or per ref for pushes); a newer run
# cancels the one still in progress.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  cancel-in-progress: true

jobs:
  build-jax-in-docker: # strategy and matrix come here
    runs-on: mi-250
    env:
      BASE_IMAGE: "ubuntu:22.04"
      # Image tag and work directory are unique per run/attempt so that
      # reruns never collide with leftovers on the same machine.
      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
      # Quoted on purpose: an unquoted 3.10 would be read as the float 3.1.
      PYTHON_VERSION: "3.10"
      ROCM_VERSION: "6.2.4"
      WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
    steps:
      - name: Clean up old runs
        run: |
          ls
          # Make sure that we own all of the files so that we have permissions to delete them.
          # --rm stops the helper container from piling up on the runner, and the
          # absolute "$PWD" source works on Docker versions that reject relative mounts.
          docker run --rm -v "$PWD:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
          # Remove any old work directories from this machine
          rm -rf workdir_*
          ls
      - name: Print system info
        run: |
          whoami
          printenv
          df -h
          rocm-smi
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          path: ${{ env.WORKSPACE_DIR }}
      - name: Build JAX
        run: |
          pushd "$WORKSPACE_DIR"
          python3 build/rocm/ci_build \
            --rocm-version "$ROCM_VERSION" \
            --base-docker "$BASE_IMAGE" \
            --python-versions "$PYTHON_VERSION" \
            --compiler=clang \
            dist_docker \
            --image-tag "$TEST_IMAGE"
      - name: Archive jax wheels
        uses: actions/upload-artifact@v4
        with:
          name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
          # NOTE(review): this path is relative to the runner's workspace root,
          # while the checkout and build happen inside WORKSPACE_DIR. Confirm
          # the wheels really land in ./dist and not under the work directory.
          path: ./dist/*.whl
      - name: Run tests
        run: |
          cd "$WORKSPACE_DIR"
          python3 build/rocm/ci_build test "$TEST_IMAGE"
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ ARG ROCM_BUILD_JOB
ARG ROCM_BUILD_NUM

# Install system GCC and C++ libraries.
RUN yum install -y gcc-c++.x86_64
# (charleshofer) This is not ideal, as we should already have GCC and C++ libraries in the
# manylinux base image. However, adding this does fix an issue where Bazel isn't able
# to find them.
RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64

RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
Expand All @@ -20,3 +24,6 @@ RUN --mount=type=cache,target=/var/cache/dnf \
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project

# Stop git from erroring out when we don't own the repo
RUN git config --global --add safe.directory '*'
70 changes: 41 additions & 29 deletions build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,24 @@


import argparse
import logging
import os
import subprocess
import sys


LOG = logging.getLogger("ci_build")


def image_by_name(name):
    """Return the ID of the first Docker image whose reference matches ``name``.

    Runs ``docker images -q -f reference=<name>`` and takes the first line of
    its output. Returns ``None`` when that line is empty, i.e. nothing matched.
    ``subprocess.CalledProcessError`` propagates if the docker command exits
    with a non-zero status.
    """
    listing = subprocess.check_output(
        ["docker", "images", "-q", "-f", "reference=%s" % name]
    )
    first_line = listing.decode("utf8").strip().split("\n")[0]
    if first_line:
        return first_line
    return None


def dist_wheels(
rocm_version,
python_versions,
xla_path,
rocm_build_job="",
rocm_build_num="",
compiler="gcc",
):
if xla_path:
xla_path = os.path.abspath(xla_path)

# create manylinux image with requested ROCm installed
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")

# Try removing the Docker image.
try:
subprocess.run(["docker", "rmi", image], check=True)
print(f"Image {image} removed successfully.")
except subprocess.CalledProcessError as e:
print(f"Failed to remove Docker image {image}: {e}")

def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
cmd = [
"docker",
"build",
Expand All @@ -62,12 +47,29 @@ def dist_wheels(
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--tag=%s" % image,
"--tag=%s" % image_name,
".",
]

if not image_by_name(image):
_ = subprocess.run(cmd, check=True)
LOG.info("Creating manylinux build image. Running: %s", cmd)
_ = subprocess.run(cmd, check=True)
return image_name


def dist_wheels(
rocm_version,
python_versions,
xla_path,
rocm_build_job="",
rocm_build_num="",
compiler="gcc",
):
# We want to make sure the wheels we build are manylinux compliant. We'll
# do the build in a container. Build the image for this.
image_name = create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)

if xla_path:
xla_path = os.path.abspath(xla_path)

# use image to build JAX/jaxlib wheels
os.makedirs("wheelhouse", exist_ok=True)
Expand Down Expand Up @@ -114,13 +116,14 @@ def dist_wheels(
[
"--init",
"--rm",
image,
image_name,
"bash",
"-c",
" ".join(bw_cmd),
]
)

LOG.info("Running: %s", cmd)
_ = subprocess.run(cmd, check=True)


Expand All @@ -141,10 +144,16 @@ def _fetch_jax_metadata(xla_path):

jax_version = subprocess.check_output(cmd, env=env)

def safe_decode(x):
if isinstance(x, str):
return x
else:
return x.decode("utf8")

return {
"jax_version": jax_version.decode("utf8").strip(),
"jax_commit": jax_commit.decode("utf8").strip(),
"xla_commit": xla_commit.decode("utf8").strip(),
"jax_version": safe_decode(jax_version).strip(),
"jax_commit": safe_decode(jax_commit).strip(),
"xla_commit": safe_decode(xla_commit).strip(),
}


Expand Down Expand Up @@ -211,10 +220,12 @@ def test(image_name):
cmd = [
"docker",
"run",
"-it",
"--rm",
]

if os.isatty(sys.stdout.fileno()):
cmd.append("-it")

# NOTE(mrodden): we need jax source dir for the unit test code only,
# JAX and jaxlib are already installed from wheels
mounts = [
Expand Down Expand Up @@ -298,6 +309,7 @@ def parse_args():


def main():
logging.basicConfig(level=logging.INFO)
args = parse_args()

if args.action == "dist_wheels":
Expand Down
Loading

0 comments on commit bc06c93

Please sign in to comment.