Merge pull request #185 from ROCm/ci-upstream-sync-12-12-2024
CI: 12/12/24 upstream sync
charleshofer authored Dec 12, 2024
2 parents 6dc4dee + 02831ed commit 5cda053
Showing 103 changed files with 2,463 additions and 878 deletions.
5 changes: 5 additions & 0 deletions .bazelrc
@@ -96,6 +96,11 @@ build:avx_windows --copt=/arch:AVX

build:mkl_open_source_only --define=tensorflow_mkldnn_contraction_kernel=1

# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL).
build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true
build:mkl_aarch64_threadpool --@compute_library//:openmp=false
build:mkl_aarch64_threadpool -c opt

# Disable clang extension that rejects type definitions within offsetof.
# This was added in clang-16 by https://reviews.llvm.org/D133574.
# Can be removed once upb is updated, since a type definition is used within
12 changes: 9 additions & 3 deletions .github/workflows/bazel_cpu_rbe.yml
@@ -11,6 +11,9 @@ on:
options:
- 'yes'
- 'no'
pull_request:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -21,15 +24,18 @@ jobs:
if: github.event.repository.fork == false
strategy:
matrix:
runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"]
runner: ["linux-x86-n2-16", "linux-arm64-c4a-16"]
enable-x_64: [1, 0]

runs-on: ${{ matrix.runner }}
# TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }}
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build-arm64:latest') }}

env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}

name: "Bazel CPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"

steps:
- uses: actions/checkout@v3
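The new enable-x_64 matrix axis feeds JAXCI_ENABLE_X64 into the test environment. A minimal sketch of the switch it ultimately controls inside JAX, assuming the CI scripts map that variable onto JAX's x64 flag (that plumbing is not part of this diff):

import jax
import jax.numpy as jnp

# Equivalent to running with JAX_ENABLE_X64=1 in the environment.
jax.config.update("jax_enable_x64", True)
print(jnp.arange(3.0).dtype)  # float64 when x64 is enabled, float32 otherwise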
7 changes: 7 additions & 0 deletions .github/workflows/bazel_gpu_rbe.yml
@@ -11,6 +11,9 @@ on:
options:
- 'yes'
- 'no'
pull_request:
branches:
- main

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
@@ -22,12 +25,16 @@ jobs:
strategy:
matrix:
runner: ["linux-x86-n2-16"]
enable-x_64: [1, 0]

runs-on: ${{ matrix.runner }}
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest'

env:
JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
JAXCI_ENABLE_X64: ${{ matrix.enable-x_64 }}

name: "Bazel single accelerator GPU tests (${{ matrix.runner }}, Python 3.12, x64=${{ matrix.enable-x_64 }})"

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -36,7 +36,7 @@ repos:
- id: mypy
files: (jax/|tests/typing_test\.py)
exclude: jax/_src/basearray.py|jax/numpy/__init__.py # Use pyi instead
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy~=2.1.0]
additional_dependencies: [types-requests==2.31.0, jaxlib, numpy>=2.2.0]
args: [--config=pyproject.toml]

- repo: https://github.com/mwouts/jupytext
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -19,6 +19,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
APIs of the same name in {mod}`jax.extend.core`; see the documentation for
{mod}`jax.extend` for information on the compatibility guarantees of these
semi-public extensions.
* Several previously-deprecated APIs have been removed, including:
* from {mod}`jax.core`: `check_eqn`, `check_type`, `check_valid_jaxtype`, and
`non_negative_dim`.
* from {mod}`jax.lib.xla_bridge`: `xla_client` and `default_backend`.
* from {mod}`jax.lib.xla_client`: `_xla` and `bfloat16`.

## jax 0.4.37 (Dec 9, 2024)

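A minimal migration sketch for the removals listed above, assuming jax.default_backend() and jax.extend.core are the intended replacements (the changelog points at jax.extend for the core APIs; the exact mapping here is an assumption):

import jax
from jax.extend import core as jex_core

backend = jax.default_backend()            # replaces jax.lib.xla_bridge.default_backend()
my_prim = jex_core.Primitive("my_prim")    # former jax.core APIs now live in jax.extend.core
print(backend, my_prim)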
2 changes: 1 addition & 1 deletion README.md
@@ -47,7 +47,7 @@ are instances of such transformations. Others are
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming of multiple accelerators, with more to come.

This is a research project, not an official Google product. Expect bugs and
This is a research project, not an official Google product. Expect
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
5 changes: 4 additions & 1 deletion build/build.py
@@ -485,7 +485,10 @@ async def main():

if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
wheel_build_command.append("--config=mkl_open_source_only")
if target_cpu == "aarch64":
wheel_build_command.append("--config=mkl_aarch64_threadpool")
else:
wheel_build_command.append("--config=mkl_open_source_only")

if args.target_cpu_features == "release":
if arch in ["x86_64", "AMD64"]:
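A toy reproduction of the selection logic above, only to show which Bazel config flag each target CPU ends up with; it is not the real wheel build command:

def mkl_config_flag(target_cpu: str) -> str:
    # Mirrors the branch added in build/build.py.
    if target_cpu == "aarch64":
        return "--config=mkl_aarch64_threadpool"
    return "--config=mkl_open_source_only"

print(mkl_config_flag("aarch64"))  # --config=mkl_aarch64_threadpool
print(mkl_config_flag("x86_64"))   # --config=mkl_open_source_only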
1 change: 0 additions & 1 deletion docs/jax.lib.rst
@@ -11,7 +11,6 @@ jax.lib.xla_bridge
.. autosummary::
:toctree: _autosummary

default_backend
get_backend
get_compile_options

2 changes: 2 additions & 0 deletions docs/jax.numpy.rst
@@ -274,6 +274,7 @@ namespace; they are listed below.
mask_indices
matmul
matrix_transpose
matvec
max
maximum
mean
@@ -428,6 +429,7 @@ namespace; they are listed below.
var
vdot
vecdot
vecmat
vectorize
vsplit
vstack
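A minimal usage sketch for the newly documented jnp.matvec and jnp.vecmat, assuming they follow the NumPy 2.x semantics of the same names (batched matrix-vector and vector-matrix products over the trailing axes):

import jax.numpy as jnp

A = jnp.ones((4, 3, 2))   # a batch of four 3x2 matrices
x = jnp.ones((4, 2))      # a batch of four length-2 vectors
y = jnp.ones((4, 3))      # a batch of four length-3 vectors

print(jnp.matvec(A, x).shape)  # (4, 3): matrix @ vector over the last axes
print(jnp.vecmat(y, A).shape)  # (4, 2): vector @ matrix over the last axes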
4 changes: 2 additions & 2 deletions examples/ffi/tests/cpu_examples_test.py
@@ -37,10 +37,10 @@ def test_array_attr_jit_cache(self):
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1) # compiles once the first time
self.assertEqual(count(), 1) # compiles once the first time
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 0) # cache hit
self.assertEqual(count(), 0) # cache hit

def test_array_attr_no_jit(self):
with jax.disable_jit():
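The test change above reads the lowering counter by calling it rather than indexing it. A toy stand-in for that pattern, not jax._src.test_util's implementation, assuming the context manager yields a callable handle:

from contextlib import contextmanager

@contextmanager
def counting(events):
    start = len(events)
    yield lambda: len(events) - start   # call the handle to read the count

events = []
with counting(events) as count:
    events.append("lowering")
assert count() == 1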
2 changes: 2 additions & 0 deletions jax/BUILD
@@ -455,6 +455,7 @@ pytype_strict_library(
":dtypes",
":effects",
":mesh",
":partition_spec",
":pretty_printer",
":source_info_util",
":traceback_util",
@@ -558,6 +559,7 @@ pytype_strict_library(
":layout",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":sharding",
3 changes: 3 additions & 0 deletions jax/_src/array.py
@@ -33,6 +33,7 @@
from jax._src import dtypes
from jax._src import errors
from jax._src import profiler
from jax._src import util
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.interpreters import pxla
@@ -1131,6 +1132,7 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):


def _array_shard_arg(xs, shardings, layouts, copy_semantics):
util.test_event("_array_shard_arg")
results = []
batch_xs, batch_devs, batch_shardings, batch_indices = [], [], [], []
batch_cs = []
@@ -1168,6 +1170,7 @@ def _array_shard_arg(xs, shardings, layouts, copy_semantics):
results.append(
shard_sharded_device_array_slow_path(x, devices, indices, sharding))

util.test_event("batched_copy_array")
copy_outs = xc.batched_copy_array_to_devices_with_sharding(
batch_xs, batch_devs, batch_shardings, batch_cs)
for i, copy_out in safe_zip(batch_indices, copy_outs):
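The util.test_event calls mark points such as the shard-arg path and the batched copy so tests can observe them. A minimal sketch of what such a hook typically looks like, assuming a no-op default that tests can monkeypatch; this is not the actual jax._src.util code:

from typing import Callable, Optional

_listener: Optional[Callable[..., None]] = None

def test_event(name: str, *args) -> None:
    # No-op in normal runs; tests install a listener to record events.
    if _listener is not None:
        _listener(name, *args)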
63 changes: 29 additions & 34 deletions jax/_src/core.py
@@ -39,6 +39,7 @@
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
@@ -1599,19 +1600,34 @@ def _invalid_shape_error(shape: Shape, context: str=""):

return TypeError(msg)

# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def _maybe_modify_sharding(sharding):
if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
return sharding

new_spec = []
for s in sharding.spec:
if s is None or isinstance(s, UnconstrainedSingleton):
new_spec.append(s)
else:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
P.UNCONSTRAINED
if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
return sharding.with_spec(new_spec)


def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
from jax._src.sharding_impls import NamedSharding # type: ignore

if sharding is not None:
assert len(sharding.spec) == ndim
return sharding
return _maybe_modify_sharding(sharding)

context_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
# code.
if not context_mesh:
return None
return RuntimeError("Please set the mesh via `jax.set_mesh` API.")
assert sharding is None
return NamedSharding(context_mesh, P(*[None] * ndim))

@@ -1674,10 +1690,8 @@ def str_short(self, short_dtypes=False):
self.dtype.name)
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
axis_types = self.sharding.mesh.axis_types
axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
return f'{dt_str}[{shapestr}]{axt}'
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec) # type: ignore
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'
@@ -1689,26 +1703,13 @@ def _len(self, ignored_tracer):
raise TypeError("len() of unsized object") from err # same as numpy error


def _get_axis_type_str(axis_types):
from jax._src.mesh import AxisTypes # type: ignore

out = []
for t, axes in axis_types.items():
a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
if t == AxisTypes.Collective:
out.append(f"C:{a}")
elif t == AxisTypes.User:
out.append(f"U:{a}")
else:
assert t == AxisTypes.Auto
out.append(f"A:{a}")
return f"{{{', '.join(out)}}}"

def _get_shape_sharding_str(shape, spec):
out = []
for s1, s2 in zip(shape, spec):
if s2 is None:
out.append(f"{s1}")
elif isinstance(s2, UnconstrainedSingleton):
out.append(f"{s1}")
elif isinstance(s2, tuple):
ss = ','.join(s for s in s2)
out.append(f"{s1}@({ss})")
@@ -2655,16 +2656,10 @@ def substitute(aval: AbstractValue):
return aval
for v, x in zip(call_jaxpr.invars, in_atoms):
if not typecompat(substitute(v.aval), x.aval):
# TODO(yashkatariya): Remove this once numpy array's aval has a sharding
# on it.
if (config.sharding_in_types.value and isinstance(x, Literal) and
v.aval.sharding is not None and x.val.ndim == 0):
pass
else:
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
# TODO(mattjj): vars in error message are confusing b/c of Var.__repr__
raise JaxprTypeError(f"Call primitive {prim} passes operand {x} of type "
f"{x.aval} to jaxpr expecting type "
f"{substitute(v.aval)}")
env[v] = x if type(x) is Var else x.val

_check_jaxpr(ctx_factory, call_jaxpr)
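With the axis-type suffix dropped from str_short, avals now print only the per-dimension sharding. A toy reproduction of the visible part of _get_shape_sharding_str, just to show the resulting strings; the example spec values are assumptions:

def shape_sharding_str(shape, spec):
    out = []
    for dim, s in zip(shape, spec):
        if s is None:
            out.append(f"{dim}")
        elif isinstance(s, tuple):
            out.append(f"{dim}@({','.join(s)})")
        else:
            out.append(f"{dim}@{s}")
    return ",".join(out)

print(shape_sharding_str((8, 4), ("x", None)))         # 8@x,4  -> rendered as f32[8@x,4]
print(shape_sharding_str((8, 4), (("x", "y"), None)))  # 8@(x,y),4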
1 change: 1 addition & 0 deletions jax/_src/dispatch.py
@@ -94,6 +94,7 @@ def apply_primitive(prim, *args, **params):

@util.cache()
def xla_primitive_callable(prim: core.Primitive, **params):
util.test_event("xla_primitive_callable_cache_miss")
def prim_fun(*args):
with config.eager_constant_folding(False):
return prim.bind(*args, **params)
2 changes: 1 addition & 1 deletion jax/_src/dtypes.py
@@ -419,7 +419,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}

# Otherwise, fall back to numpy.issubdtype
return np.issubdtype(a_sctype, b_sctype)
return bool(np.issubdtype(a_sctype, b_sctype))

can_cast = np.can_cast

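A one-line illustration of the cast above: wrapping the result in bool() guarantees a plain Python bool whatever the underlying NumPy call returns (the motivation, presumably keeping the declared return type accurate for type checkers, is an assumption):

import numpy as np

r = np.issubdtype(np.float32, np.floating)
print(type(bool(r)))  # always <class 'bool'>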
1 change: 1 addition & 0 deletions jax/_src/export/_export.py
@@ -1017,6 +1017,7 @@ def _check_lowering(lowering) -> None:
"lapack_ssytrd_ffi", "lapack_dsytrd_ffi", "lapack_chetrd_ffi", "lapack_zhetrd_ffi",
"lapack_sgehrd_ffi", "lapack_dgehrd_ffi", "lapack_cgehrd_ffi", "lapack_zgehrd_ffi",
"lapack_sgees_ffi", "lapack_dgees_ffi", "lapack_cgees_ffi", "lapack_zgees_ffi",
"lapack_strsm_ffi", "lapack_dtrsm_ffi", "lapack_ctrsm_ffi", "lapack_ztrsm_ffi",
]
# These are the JAX custom call target names that are guaranteed to be stable.
# Their backwards compatibility is tested by back_compat_test.py.
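The new entries extend the list of custom-call targets that exported artifacts may contain. A minimal sketch of an export expected to hit the LAPACK trsm path on CPU, assuming solve_triangular lowers to one of the listed lapack_*trsm_ffi targets:

import jax
import jax.numpy as jnp
from jax import export
from jax.scipy.linalg import solve_triangular

def solve(a, b):
    return solve_triangular(a, b, lower=True)

a = jnp.eye(3)
b = jnp.ones((3, 2))
exported = export.export(jax.jit(solve))(a, b)
print(exported.fun_name)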
(Diff truncated: the remaining changed files are not shown in this view.)
