Skip to content

Commit

Permalink
Merge pull request #194 from ROCm/ci-upstream-sync-80_1
Browse files Browse the repository at this point in the history
CI: 01/07/25 upstream sync
  • Loading branch information
github-actions[bot] authored Jan 7, 2025
2 parents 4b11080 + a94ee1f commit 972f95b
Show file tree
Hide file tree
Showing 64 changed files with 850 additions and 651 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/ci-build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ jobs:
JAX_ARRAY: 1
PY_COLORS: 1
run: |
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/experimental/array_api --ignore=jax/lib/xla_extension.py
pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
pytest -n auto --tb=short --doctest-modules jax --ignore=jax/config.py --ignore=jax/experimental/jax2tf --ignore=jax/_src/lib/mlir --ignore=jax/_src/lib/triton.py --ignore=jax/_src/lib/mosaic_gpu.py --ignore=jax/interpreters/mlir.py --ignore=jax/experimental/array_serialization --ignore=jax/collect_profile.py --ignore=jax/_src/tpu_custom_call.py --ignore=jax/experimental/mosaic --ignore=jax/experimental/pallas --ignore=jax/_src/pallas --ignore=jax/lib/xla_extension.py
documentation_render:
Expand Down
13 changes: 13 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,37 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.
* {func}`jax.numpy.einsum` now defaults to `optimize='auto'` rather than
`optimize='optimal'`. This avoids exponentially-scaling trace-time in
the case of many arguments ({jax-issue}`#25214`).

* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.
* Support added for user defined state in the FFI via the new
{func}`jax.ffi.register_ffi_type_id` function.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* {func}`jax.scipy.special.lpmn` and {func}`jax.scipy.special.lpmn_values`
are deprecated, following their deprecation in SciPy v1.15.0. There are
no plans to replace these deprecated functions with new APIs.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.
* The `jax.experimental.array_api` module has been removed after being
deprecated in JAX v0.4.32. Since that release, {mod}`jax.numpy` supports
the array API directly.

## jax 0.4.38 (Dec 17, 2024)

Expand Down
7 changes: 7 additions & 0 deletions build/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,9 @@ async def main():
# Enable clang settings that are needed for the build to work with newer
# versions of Clang.
wheel_build_command_base.append("--config=clang")
if clang_major_version < 19:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")

else:
gcc_path = args.gcc_path or utils.get_gcc_path_or_exit()
logging.debug(
Expand All @@ -477,6 +480,10 @@ async def main():
wheel_build_command_base.append(f"--repo_env=CC=\"{gcc_path}\"")
wheel_build_command_base.append(f"--repo_env=BAZEL_COMPILER=\"{gcc_path}\"")

gcc_major_version = utils.get_gcc_major_version(gcc_path)
if gcc_major_version < 13:
wheel_build_command_base.append("--define=xnn_enable_avxvnniint8=false")

if not args.disable_mkl_dnn:
logging.debug("Enabling MKL DNN")
if target_cpu == "aarch64":
Expand Down
5 changes: 2 additions & 3 deletions build/rocm/tools/get_rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,8 @@ def install_packages(self, package_specs):
env = dict(os.environ)
if self.pkgbin == "apt":
env["DEBIAN_FRONTEND"] = "noninteractive"

# Update indexes.
subprocess.check_call(["apt-get", "update"])
# Update indexes.
subprocess.check_call(["apt-get", "update"])

LOG.info("Running %r" % cmd)
subprocess.check_call(cmd, env=env)
Expand Down
12 changes: 12 additions & 0 deletions build/tools/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,18 @@ def get_clang_major_version(clang_path):

return major_version

def get_gcc_major_version(gcc_path: str):
gcc_version_proc = subprocess.run(
[gcc_path, "-dumpversion"],
check=True,
capture_output=True,
text=True,
)
major_version = int(gcc_version_proc.stdout)

return major_version


def get_jax_configure_bazel_options(bazel_command: list[str]):
"""Returns the bazel options to be written to .jax_configure.bazelrc."""
# Get the index of the "run" parameter. Build options will come after "run" so
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,4 +362,5 @@ def linkcode_resolve(domain, info):
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
'jax.extend.ffi.rst': 'jax.ffi.rst',
}
2 changes: 1 addition & 1 deletion docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -659,7 +659,7 @@ You can find the up-to-date command to run doctests in
E.g., you can run:

```
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md --ignore=docs/jax.experimental.array_api.rst
JAX_TRACEBACK_FILTERING=off XLA_FLAGS=--xla_force_host_platform_device_count=8 pytest -n auto --tb=short --doctest-glob='*.md' --doctest-glob='*.rst' docs --doctest-continue-on-failure --ignore=docs/multi_process.md
```

Additionally, JAX runs pytest in `doctest-modules` mode to ensure code examples in
Expand Down
Loading

0 comments on commit 972f95b

Please sign in to comment.