diff --git a/build/rocm/ci_build b/build/rocm/ci_build index aeb0201e27ed..3c051be8c027 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -83,11 +83,15 @@ def dist_wheels( cmd = ["docker", "run"] + # docker run fails when mounting ./ + own_path = os.path.dirname(os.path.abspath(__file__)) + repo_path = os.path.abspath(os.path.join(own_path, "..","..")) + whl_path = os.path.join(repo_path, "wheelhouse") mounts = [ "-v", - "./:/jax", + "%s:/jax" % repo_path, "-v", - "./wheelhouse:/wheelhouse", + "%s:/wheelhouse" % whl_path, ] if xla_path: @@ -130,10 +134,16 @@ def _fetch_jax_metadata(xla_path): jax_version = subprocess.check_output(cmd, env=env) + def safe_decode(x): + if isinstance(x, str): + return x + else: + return x.decode("utf8") + return { - "jax_version": jax_version.decode("utf8").strip(), - "jax_commit": jax_commit.decode("utf8").strip(), - "xla_commit": xla_commit.decode("utf8").strip(), + "jax_version": safe_decode(jax_version).strip(), + "jax_commit": safe_decode(jax_commit).strip(), + "xla_commit": safe_decode(xla_commit).strip(), } @@ -204,9 +214,12 @@ def test(image_name): # NOTE(mrodden): we need jax source dir for the unit test code only, # JAX and jaxlib are already installed from wheels + # docker run fails when mounting ./ + own_path = os.path.dirname(os.path.abspath(__file__)) + repo_path = os.path.abspath(os.path.join(own_path, "..","..")) mounts = [ "-v", - "./:/jax", + "%s:/jax" % repo_path, ] cmd.extend(mounts)