diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml
new file mode 100644
index 000000000000..7f9cd760aee5
--- /dev/null
+++ b/.github/workflows/rocm-ci.yml
@@ -0,0 +1,63 @@
+name: ROCm GPU CI
+
+on:
+  # Trigger the workflow on push or pull request,
+  # but only for the rocm-main branch
+  push:
+    branches:
+      - rocm-main
+  pull_request:
+    branches:
+      - rocm-main
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  build-jax-in-docker: # strategy and matrix come here
+    runs-on: mi-250
+    env:
+      BASE_IMAGE: "ubuntu:22.04"
+      TEST_IMAGE: ubuntu-jax-${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
+      PYTHON_VERSION: "3.10"
+      ROCM_VERSION: "6.2.4"
+      WORKSPACE_DIR: workdir_${{ github.run_id }}_${{ github.run_number }}_${{ github.run_attempt }}
+    steps:
+      - name: Clean up old runs
+        run: |
+          ls
+          # Make sure that we own all of the files so that we have permissions to delete them
+          docker run -v "./:/jax" ubuntu /bin/bash -c "chown -R $UID /jax/workdir_* || true"
+          # Remove any old work directories from this machine
+          rm -rf workdir_*
+          ls
+      - name: Print system info
+        run: |
+          whoami
+          printenv
+          df -h
+          rocm-smi
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          path: ${{ env.WORKSPACE_DIR }}
+      - name: Build JAX
+        run: |
+          pushd $WORKSPACE_DIR
+          python3 build/rocm/ci_build \
+            --rocm-version $ROCM_VERSION \
+            --base-docker $BASE_IMAGE \
+            --python-versions $PYTHON_VERSION \
+            --compiler=clang \
+            dist_docker \
+            --image-tag $TEST_IMAGE
+      - name: Archive jax wheels
+        uses: actions/upload-artifact@v4
+        with:
+          name: rocm_jax_r${{ env.ROCM_VERSION }}_py${{ env.PYTHON_VERSION }}_id${{ github.run_id }}
+          path: ./dist/*.whl
+      - name: Run tests
+        run: |
+          cd $WORKSPACE_DIR
+          python3 build/rocm/ci_build test $TEST_IMAGE
+
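A note on the "Clean up old runs" step above: the build containers write files as root, so the runner user cannot delete leftover `workdir_*` trees directly. The step first chowns them from inside a throwaway Ubuntu container, then removes them. A rough Python sketch of the same idea, assuming Docker is on PATH and the current UID is the runner's (the `clean_old_workdirs` helper is illustrative, not part of this change):

```python
import glob
import os
import shutil
import subprocess

def clean_old_workdirs(root="."):
    # Chown from inside a container so root-owned build outputs become
    # deletable by the runner user; check=False mirrors the `|| true`.
    subprocess.run(
        ["docker", "run", "-v", f"{os.path.abspath(root)}:/jax", "ubuntu",
         "/bin/bash", "-c", f"chown -R {os.getuid()} /jax/workdir_* || true"],
        check=False,
    )
    # Now the runner user owns the files and can delete them.
    for workdir in glob.glob(os.path.join(root, "workdir_*")):
        shutil.rmtree(workdir, ignore_errors=True)
```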
diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
index 6e610e711c77..3e6333d6627f 100644
--- a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
+++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm
@@ -5,7 +5,11 @@ ARG ROCM_BUILD_JOB
 ARG ROCM_BUILD_NUM
 
 # Install system GCC and C++ libraries.
-RUN yum install -y gcc-c++.x86_64
+# (charleshofer) This is not ideal, as we should already have GCC and C++ libraries in the
+# manylinux base image. However, adding this does fix an issue where Bazel isn't able
+# to find them.
+RUN --mount=type=cache,target=/var/cache/dnf \
+    dnf install -y gcc-c++-8.5.0-22.el8_10.x86_64
 
 RUN --mount=type=cache,target=/var/cache/dnf \
     --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
@@ -20,3 +24,6 @@ RUN --mount=type=cache,target=/var/cache/dnf \
 RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
     mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
    make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
+
+# Stop git from erroring out when we don't own the repo
+RUN git config --global --add safe.directory '*'
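The final `git config --global --add safe.directory '*'` line matters because recent git versions refuse to operate on a repository owned by a different user, and the JAX checkout is mounted into this image with the host user's ownership while the build runs as root. A minimal sketch of the kind of call that would otherwise fail; the `/jax` mount path is an assumption based on how these build scripts mount the source tree:

```python
import subprocess

# Without the safe.directory setting, git inside the container aborts with a
# "dubious ownership" error when /jax belongs to a different UID than ours.
commit = subprocess.check_output(
    ["git", "-C", "/jax", "rev-parse", "HEAD"], text=True
).strip()
print(commit)
```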
diff --git a/build/rocm/ci_build b/build/rocm/ci_build
index 849c082dca38..2556633482ff 100755
--- a/build/rocm/ci_build
+++ b/build/rocm/ci_build
@@ -21,11 +21,15 @@
 
 import argparse
+import logging
 import os
 import subprocess
 import sys
 
 
+LOG = logging.getLogger("ci_build")
+
+
 def image_by_name(name):
     cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
     out = subprocess.check_output(cmd)
@@ -33,27 +37,8 @@
     return image_id
 
 
-def dist_wheels(
-    rocm_version,
-    python_versions,
-    xla_path,
-    rocm_build_job="",
-    rocm_build_num="",
-    compiler="gcc",
-):
-    if xla_path:
-        xla_path = os.path.abspath(xla_path)
-
-    # create manylinux image with requested ROCm installed
-    image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
-
-    # Try removing the Docker image.
-    try:
-        subprocess.run(["docker", "rmi", image], check=True)
-        print(f"Image {image} removed successfully.")
-    except subprocess.CalledProcessError as e:
-        print(f"Failed to remove Docker image {image}: {e}")
-
+def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
+    image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
     cmd = [
         "docker",
         "build",
@@ -62,12 +47,29 @@
         "--build-arg=ROCM_VERSION=%s" % rocm_version,
         "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
         "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
-        "--tag=%s" % image,
+        "--tag=%s" % image_name,
         ".",
     ]
 
-    if not image_by_name(image):
-        _ = subprocess.run(cmd, check=True)
+    LOG.info("Creating manylinux build image. Running: %s", cmd)
+    _ = subprocess.run(cmd, check=True)
+    return image_name
+
+
+def dist_wheels(
+    rocm_version,
+    python_versions,
+    xla_path,
+    rocm_build_job="",
+    rocm_build_num="",
+    compiler="gcc",
+):
+    # We want to make sure the wheels we build are manylinux compliant. We'll
+    # do the build in a container. Build the image for this.
+    image_name = create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num)
+
+    if xla_path:
+        xla_path = os.path.abspath(xla_path)
 
     # use image to build JAX/jaxlib wheels
     os.makedirs("wheelhouse", exist_ok=True)
@@ -114,13 +116,14 @@
         [
             "--init",
             "--rm",
-            image,
+            image_name,
             "bash",
             "-c",
             " ".join(bw_cmd),
         ]
     )
 
+    LOG.info("Running: %s", cmd)
     _ = subprocess.run(cmd, check=True)
@@ -141,10 +144,16 @@
 
     jax_version = subprocess.check_output(cmd, env=env)
 
+    def safe_decode(x):
+        if isinstance(x, str):
+            return x
+        else:
+            return x.decode("utf8")
+
     return {
-        "jax_version": jax_version.decode("utf8").strip(),
-        "jax_commit": jax_commit.decode("utf8").strip(),
-        "xla_commit": xla_commit.decode("utf8").strip(),
+        "jax_version": safe_decode(jax_version).strip(),
+        "jax_commit": safe_decode(jax_commit).strip(),
+        "xla_commit": safe_decode(xla_commit).strip(),
     }
@@ -211,10 +220,12 @@ def test(image_name):
     cmd = [
         "docker",
         "run",
-        "-it",
         "--rm",
     ]
 
+    if os.isatty(sys.stdout.fileno()):
+        cmd.append("-it")
+
     # NOTE(mrodden): we need jax source dir for the unit test code only,
     # JAX and jaxlib are already installed from wheels
     mounts = [
@@ -298,6 +309,7 @@ def parse_args():
 
 
 def main():
+    logging.basicConfig(level=logging.INFO)
     args = parse_args()
 
     if args.action == "dist_wheels":
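The `test()` change above drops the unconditional `-it`: `docker run -it` needs a real terminal, so it fails in CI where stdout is a pipe. The flag is now added only when stdout is a TTY. A standalone sketch of the guard:

```python
import os
import subprocess
import sys

cmd = ["docker", "run", "--rm"]
# Only request an interactive TTY when we actually have one; in CI,
# stdout is a pipe and `docker run -it` would refuse to start.
if os.isatty(sys.stdout.fileno()):
    cmd.append("-it")
cmd += ["ubuntu", "echo", "hello"]
subprocess.run(cmd, check=True)
```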
diff --git a/build/rocm/run_single_gpu.py b/build/rocm/run_single_gpu.py
index e1fa26c72872..14a1e9037989 100755
--- a/build/rocm/run_single_gpu.py
+++ b/build/rocm/run_single_gpu.py
@@ -25,179 +25,205 @@
 LAST_CODE = 0
 base_dir = "./logs"
 
+
 def extract_filename(path):
-  base_name = os.path.basename(path)
-  file_name, _ = os.path.splitext(base_name)
-  return file_name
+    base_name = os.path.basename(path)
+    file_name, _ = os.path.splitext(base_name)
+    return file_name
 
 
 def combine_json_reports():
-  all_json_files = [f for f in os.listdir(base_dir) if f.endswith('_log.json')]
-  combined_data = []
-  for json_file in all_json_files:
-    with open(os.path.join(base_dir, json_file), 'r') as infile:
-      data = json.load(infile)
-      combined_data.append(data)
-  combined_json_file = f"{base_dir}/final_compiled_report.json"
-  with open(combined_json_file, 'w') as outfile:
-    json.dump(combined_data, outfile, indent=4)
+    all_json_files = [f for f in os.listdir(base_dir) if f.endswith("_log.json")]
+    combined_data = []
+    for json_file in all_json_files:
+        with open(os.path.join(base_dir, json_file), "r") as infile:
+            data = json.load(infile)
+            combined_data.append(data)
+    combined_json_file = f"{base_dir}/final_compiled_report.json"
+    with open(combined_json_file, "w") as outfile:
+        json.dump(combined_data, outfile, indent=4)
 
 
 def combine_csv_reports():
-  all_csv_files = [f for f in os.listdir(base_dir) if f.endswith('_log.csv')]
-  combined_csv_file = f"{base_dir}/final_compiled_report.csv"
-  with open(combined_csv_file, mode='w', newline='') as outfile:
-    csv_writer = csv.writer(outfile)
-    for i, csv_file in enumerate(all_csv_files):
-      with open(os.path.join(base_dir, csv_file), mode='r') as infile:
-        csv_reader = csv.reader(infile)
-        if i == 0:
-          # write headers only once
-          csv_writer.writerow(next(csv_reader))
-        for row in csv_reader:
-          csv_writer.writerow(row)
+    all_csv_files = [f for f in os.listdir(base_dir) if f.endswith("_log.csv")]
+    combined_csv_file = f"{base_dir}/final_compiled_report.csv"
+    with open(combined_csv_file, mode="w", newline="") as outfile:
+        csv_writer = csv.writer(outfile)
+        for i, csv_file in enumerate(all_csv_files):
+            with open(os.path.join(base_dir, csv_file), mode="r") as infile:
+                csv_reader = csv.reader(infile)
+                if i == 0:
+                    # write headers only once
+                    csv_writer.writerow(next(csv_reader))
+                for row in csv_reader:
+                    csv_writer.writerow(row)
 
 
 def generate_final_report(shell=False, env_vars={}):
-  env = os.environ
-  env = {**env, **env_vars}
-  cmd = ["pytest_html_merger", "-i", f'{base_dir}', "-o", f'{base_dir}/final_compiled_report.html']
-  result = subprocess.run(cmd,
-                          shell=shell,
-                          capture_output=True,
-                          env=env)
-  if result.returncode != 0:
-    print("FAILED - {}".format(" ".join(cmd)))
-    print(result.stderr.decode())
-
-  # Generate json reports.
-  combine_json_reports()
-  # Generate csv reports.
-  combine_csv_reports()
+    env = os.environ
+    env = {**env, **env_vars}
+    cmd = [
+        "pytest_html_merger",
+        "-i",
+        f"{base_dir}",
+        "-o",
+        f"{base_dir}/final_compiled_report.html",
+    ]
+    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
+    if result.returncode != 0:
+        print("FAILED - {}".format(" ".join(cmd)))
+        print(result.stderr.decode())
+
+    # Generate json reports.
+    combine_json_reports()
+    # Generate csv reports.
+    combine_csv_reports()
 
 
 def run_shell_command(cmd, shell=False, env_vars={}):
-  env = os.environ
-  env = {**env, **env_vars}
-  result = subprocess.run(cmd,
-                          shell=shell,
-                          capture_output=True,
-                          env=env)
-  if result.returncode != 0:
-    print("FAILED - {}".format(" ".join(cmd)))
-    print(result.stderr.decode())
+    env = os.environ
+    env = {**env, **env_vars}
+    result = subprocess.run(cmd, shell=shell, capture_output=True, env=env)
+    if result.returncode != 0:
+        print("FAILED - {}".format(" ".join(cmd)))
+        print(result.stderr.decode())
 
-  return result.returncode, result.stderr.decode(), result.stdout.decode()
+    return result.returncode, result.stderr.decode(), result.stdout.decode()
 
 
 def parse_test_log(log_file):
-  """Parses the test module log file to extract test modules and functions."""
-  test_files = set()
-  with open(log_file, "r") as f:
-    for line in f:
-      report = json.loads(line)
-      if "nodeid" in report:
-        module = report["nodeid"].split("::")[0]
-        if module and ".py" in module:
-          test_files.add(os.path.abspath(module))
-  return test_files
+    """Parses the test module log file to extract test modules and functions."""
+    test_files = set()
+    with open(log_file, "r") as f:
+        for line in f:
+            report = json.loads(line)
+            if "nodeid" in report:
+                module = report["nodeid"].split("::")[0]
+                if module and ".py" in module:
+                    test_files.add(os.path.abspath(module))
+    return test_files
 
 
 def collect_testmodules():
-  log_file = f"{base_dir}/collect_module_log.jsonl"
-  return_code, stderr, stdout = run_shell_command(
-      ["python3", "-m", "pytest", "--collect-only", "tests", f"--report-log={log_file}"])
-  if return_code != 0:
-    print("Test module discovery failed.")
-    print("STDOUT:", stdout)
-    print("STDERR:", stderr)
-    exit(return_code)
-  print("---------- collected test modules ----------")
-  test_files = parse_test_log(log_file)
-  print("Found %d test modules." % (len(test_files)))
-  print("--------------------------------------------")
-  print("\n".join(test_files))
-  return test_files
+    log_file = f"{base_dir}/collect_module_log.jsonl"
+    return_code, stderr, stdout = run_shell_command(
+        [
+            "python3",
+            "-m",
+            "pytest",
+            "--collect-only",
+            "tests",
+            f"--report-log={log_file}",
+        ]
+    )
+    if return_code != 0:
+        print("Test module discovery failed.")
+        print("STDOUT:", stdout)
+        print("STDERR:", stderr)
+        exit(return_code)
+    print("---------- collected test modules ----------")
+    test_files = parse_test_log(log_file)
+    print("Found %d test modules." % (len(test_files)))
+    print("--------------------------------------------")
+    print("\n".join(test_files))
+    return test_files
 
 
 def run_test(testmodule, gpu_tokens, continue_on_fail):
-  global LAST_CODE
-  with GPU_LOCK:
-    if LAST_CODE != 0:
-      return
-    target_gpu = gpu_tokens.pop()
-    env_vars = {
-        "HIP_VISIBLE_DEVICES": str(target_gpu),
-        "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
-    }
-  testfile = extract_filename(testmodule)
-  if continue_on_fail:
-    cmd = ["python3", "-m", "pytest",
-           "--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
-           f"--csv={base_dir}/{testfile}_log.csv",
-           "--csv-columns", "id,module,name,file,status,duration",
-           f"--html={base_dir}/{testfile}_log.html",
-           "--reruns", "3", "-v", testmodule]
-  else:
-    cmd = ["python3", "-m", "pytest",
-           "--json-report", f"--json-report-file={base_dir}/{testfile}_log.json",
-           f"--csv={base_dir}/{testfile}_log.csv",
-           "--csv-columns", "id,module,name,file,status,duration",
-           f"--html={base_dir}/{testfile}_log.html",
-           "--reruns", "3", "-x", "-v", testmodule]
-
-  return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
-  with GPU_LOCK:
-    gpu_tokens.append(target_gpu)
-    if LAST_CODE == 0:
-      print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
-      print(stdout)
-      print(stderr)
-      if continue_on_fail == False:
-        LAST_CODE = return_code
+    global LAST_CODE
+    with GPU_LOCK:
+        if LAST_CODE != 0:
+            return
+        target_gpu = gpu_tokens.pop()
+        env_vars = {
+            "HIP_VISIBLE_DEVICES": str(target_gpu),
+            "XLA_PYTHON_CLIENT_ALLOCATOR": "default",
+        }
+    testfile = extract_filename(testmodule)
+    if continue_on_fail:
+        cmd = [
+            "python3",
+            "-m",
+            "pytest",
+            "--json-report",
+            f"--json-report-file={base_dir}/{testfile}_log.json",
+            f"--csv={base_dir}/{testfile}_log.csv",
+            "--csv-columns",
+            "id,module,name,file,status,duration",
+            f"--html={base_dir}/{testfile}_log.html",
+            "--reruns",
+            "3",
+            "-v",
+            testmodule,
+        ]
+    else:
+        cmd = [
+            "python3",
+            "-m",
+            "pytest",
+            "--json-report",
+            f"--json-report-file={base_dir}/{testfile}_log.json",
+            f"--csv={base_dir}/{testfile}_log.csv",
+            "--csv-columns",
+            "id,module,name,file,status,duration",
+            f"--html={base_dir}/{testfile}_log.html",
+            "--reruns",
+            "3",
+            "-x",
+            "-v",
+            testmodule,
+        ]
+
+    return_code, stderr, stdout = run_shell_command(cmd, env_vars=env_vars)
+    with GPU_LOCK:
+        gpu_tokens.append(target_gpu)
+        if LAST_CODE == 0:
+            print("Running tests in module %s on GPU %d:" % (testmodule, target_gpu))
+            print(stdout)
+            print(stderr)
+            if continue_on_fail == False:
+                LAST_CODE = return_code
 
 
 def run_parallel(all_testmodules, p, c):
-  print(f"Running tests with parallelism = {p}")
-  available_gpu_tokens = list(range(p))
-  executor = ThreadPoolExecutor(max_workers=p)
-  # walking through test modules.
-  for testmodule in all_testmodules:
-    executor.submit(run_test, testmodule, available_gpu_tokens, c)
-  # waiting for all modules to finish.
-  executor.shutdown(wait=True)
+    print(f"Running tests with parallelism = {p}")
+    available_gpu_tokens = list(range(p))
+    executor = ThreadPoolExecutor(max_workers=p)
+    # walking through test modules.
+    for testmodule in all_testmodules:
+        executor.submit(run_test, testmodule, available_gpu_tokens, c)
+    # waiting for all modules to finish.
+    executor.shutdown(wait=True)
 
 
 def find_num_gpus():
-  cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
-  _, _, stdout = run_shell_command(cmd, shell=True)
-  return int(stdout)
+    cmd = [r"lspci|grep 'controller\|accel'|grep 'AMD/ATI'|wc -l"]
+    _, _, stdout = run_shell_command(cmd, shell=True)
+    return int(stdout)
 
 
 def main(args):
-  all_testmodules = collect_testmodules()
-  run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
-  generate_final_report()
-  exit(LAST_CODE)
-
-
-if __name__ == '__main__':
-  os.environ['HSA_TOOLS_LIB'] = "libroctracer64.so"
-  parser = argparse.ArgumentParser()
-  parser.add_argument("-p",
-                      "--parallel",
-                      type=int,
-                      help="number of tests to run in parallel")
-  parser.add_argument("-c",
-                      "--continue_on_fail",
-                      action='store_true',
-                      help="continue on failure")
-  args = parser.parse_args()
-  if args.continue_on_fail:
-    print("continue on fail is set")
-  if args.parallel is None:
-    sys_gpu_count = find_num_gpus()
-    args.parallel = sys_gpu_count
-    print("%d GPUs detected." % sys_gpu_count)
-
-  main(args)
+    all_testmodules = collect_testmodules()
+    run_parallel(all_testmodules, args.parallel, args.continue_on_fail)
+    generate_final_report()
+    exit(LAST_CODE)
+
+
+if __name__ == "__main__":
+    os.environ["HSA_TOOLS_LIB"] = "libroctracer64.so"
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-p", "--parallel", type=int, help="number of tests to run in parallel"
+    )
+    parser.add_argument(
+        "-c", "--continue_on_fail", action="store_true", help="continue on failure"
+    )
+    args = parser.parse_args()
+    if args.continue_on_fail:
+        print("continue on fail is set")
+    if args.parallel is None:
+        sys_gpu_count = find_num_gpus()
+        args.parallel = sys_gpu_count
+        print("%d GPUs detected." % sys_gpu_count)
+
+    main(args)
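For context on the reformatted `run_single_gpu.py`: its scheduling model hands each worker thread a GPU index from a shared token list guarded by `GPU_LOCK`, and pins the child pytest process to that device via `HIP_VISIBLE_DEVICES`. A stripped-down sketch of that pattern (the `run_on_gpu` helper and the job names are placeholders, not part of the script):

```python
import threading
from concurrent.futures import ThreadPoolExecutor

GPU_LOCK = threading.Lock()

def run_on_gpu(job, gpu_tokens):
    with GPU_LOCK:
        gpu = gpu_tokens.pop()  # claim a free GPU index
    try:
        # Pin the child process to one device, as run_test does.
        env = {"HIP_VISIBLE_DEVICES": str(gpu)}
        print(f"running {job} on GPU {gpu} with env {env}")
    finally:
        with GPU_LOCK:
            gpu_tokens.append(gpu)  # return the token for the next job

tokens = list(range(4))  # one token per GPU, like run_parallel's range(p)
with ThreadPoolExecutor(max_workers=len(tokens)) as ex:
    for job in ["tests/a.py", "tests/b.py", "tests/c.py"]:
        ex.submit(run_on_gpu, job, tokens)
```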
diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py
index 33f2e100de61..1483608fa71b 100644
--- a/build/rocm/tools/build_wheels.py
+++ b/build/rocm/tools/build_wheels.py
@@ -116,6 +116,7 @@ def build_jaxlib_wheel(
     if compiler == "clang":
         clang_path = find_clang_path()
         if clang_path:
+            LOG.info("Found clang at path: %s", clang_path)
             cmd.append("--clang_path=%s" % clang_path)
         else:
             raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
@@ -315,6 +316,21 @@
         LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
         shutil.copy(whl, wheelhouse_dir)
 
+    # Delete the 'dist' directory since it causes permissions issues
+    logging.info('Deleting dist, egg-info and cache directory')
+    shutil.rmtree(os.path.join(args.jax_path, "dist"))
+    shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info"))
+    shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__"))
+
+    # Make the wheels deletable by the runner
+    whl_house = os.path.join(args.jax_path, "wheelhouse")
+    logging.info("Changing permissions for %s" % whl_house)
+    mode = 0o664
+    for item in os.listdir(whl_house):
+        whl_path = os.path.join(whl_house, item)
+        if os.path.isfile(whl_path):
+            os.chmod(whl_path, mode)
+
 
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
diff --git a/third_party/xla/workspace.bzl b/third_party/xla/workspace.bzl
index 18c1e8f80b3d..1d007392fa12 100644
--- a/third_party/xla/workspace.bzl
+++ b/third_party/xla/workspace.bzl
@@ -21,15 +21,15 @@ load("//third_party:repo.bzl", "tf_http_archive", "tf_mirror_urls")
 # curl -L https://github.com/openxla/xla/archive/<git hash>.tar.gz | sha256sum
 # and update XLA_SHA256 with the result.
 
-XLA_COMMIT = "1a6361a734c5cd10dc93938fc6163a51fd37b82e"
-XLA_SHA256 = "01159fd52f0e402829a3823472a309562817c72d0212f81cd5555f77394c094f"
+XLA_COMMIT = "373f359cbd8d02ee850d98fed92a7bbca4a09c1b"
+XLA_SHA256 = "bccda939edabf6723fcb9e59b833288d66ff93b6f34902c28c521a0b39b52d83"
 
 def repo():
     tf_http_archive(
         name = "xla",
         sha256 = XLA_SHA256,
         strip_prefix = "xla-{commit}".format(commit = XLA_COMMIT),
-        urls = tf_mirror_urls("https://github.com/openxla/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
+        urls = tf_mirror_urls("https://github.com/rocm/xla/archive/{commit}.tar.gz".format(commit = XLA_COMMIT)),
     )
 
 # For development, one often wants to make changes to the TF repository as well
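Since the archive URL now points at the `rocm/xla` fork, any future bump of `XLA_COMMIT` needs `XLA_SHA256` recomputed against that fork. A Python equivalent of the `curl ... | sha256sum` recipe from the comment in `workspace.bzl`, offered only as a convenience sketch (the commit shown is the one pinned by this change):

```python
import hashlib
import urllib.request

XLA_COMMIT = "373f359cbd8d02ee850d98fed92a7bbca4a09c1b"
url = f"https://github.com/rocm/xla/archive/{XLA_COMMIT}.tar.gz"

# Stream the tarball and hash it: equivalent to `curl -L <url> | sha256sum`.
digest = hashlib.sha256()
with urllib.request.urlopen(url) as resp:
    for chunk in iter(lambda: resp.read(1 << 20), b""):
        digest.update(chunk)
print(digest.hexdigest())  # value to paste into XLA_SHA256
```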