CI: 01/06/25 upstream sync #193

Merged: 61 commits, Jan 6, 2025

Commits (61)
307ea87
support head size of 256
kaixih Oct 29, 2024
75b5654
Fix a typo in documentation for `pinv` function.
liblaf Dec 23, 2024
3e7f481
[pallas:mosaic_gpu] Updated the lowering following the changes in in …
superbobry Dec 23, 2024
a51d627
[pallas:mosaic_gpu] Reduced duplication between `_ensure_fa` and `_en…
superbobry Dec 23, 2024
83e60a9
[pallas:triton] Add support for lowering `int4` load.
chr1sj0nes Dec 23, 2024
8987867
[mosaic_gpu] Include Mosaic GPU dialect fiels into jaxlib
superbobry Dec 23, 2024
76a51f8
Merge pull request #25667 from superbobry:main
Google-ML-Automation Dec 23, 2024
cb10710
Remove casting from jax.nn.one_hot
jakevdp Dec 23, 2024
68ec202
Use the right include for gmock and gtest
superbobry Dec 23, 2024
704185e
Merge pull request #24607 from kaixih:support_head_size_256
Google-ML-Automation Dec 23, 2024
6c85e54
Merge pull request #25662 from liblaf:main
Google-ML-Automation Dec 23, 2024
ccc3a29
Internal: use a single registry for abstractify APIs
jakevdp Dec 23, 2024
51b5102
Merge pull request #25651 from jakevdp:combine-abstractify
Google-ML-Automation Dec 23, 2024
23965b7
Update XLA dependency to use revision
Google-ML-Automation Dec 23, 2024
c206ae7
changelog: link to api compatibility & python version docs
jakevdp Dec 23, 2024
1087623
Merge pull request #25673 from jakevdp:changelog
Google-ML-Automation Dec 23, 2024
7da753e
Bump actions/upload-artifact from 4.4.3 to 4.5.0
dependabot[bot] Dec 23, 2024
40fe4b8
Finalize deprecation of some symbols from jax.lib.xla_client
jakevdp Dec 23, 2024
c57b49c
Merge pull request #25669 from jakevdp:undep
Google-ML-Automation Dec 23, 2024
3c79b98
[Mosaic:TPU] Vreg-slice-aligned offset changes with scratch retiling
tlongeri Dec 23, 2024
b8091a4
Switch `mlir` bindings from `pybind11` to `nanobind`
Google-ML-Automation Dec 23, 2024
4452960
[Mosaic:TPU] In infer ext rule, avoid assigning offsets outside of ds…
tlongeri Dec 23, 2024
fa9c7ed
Merge pull request #25674 from jax-ml:dependabot/github_actions/actio…
Google-ML-Automation Dec 24, 2024
4eff131
Merge pull request #25672 from jakevdp:finalize-dep
Google-ML-Automation Dec 24, 2024
44333e1
[pallas:mosaic_gpu] Addressed a todo in `broadcasted_iota` lowering
superbobry Dec 24, 2024
64511a1
Update XLA dependency to use revision
Google-ML-Automation Dec 24, 2024
b6aead6
[AutoPGLE] Explicitly disable command buffers when profiler is used.
Google-ML-Automation Dec 25, 2024
42a0d55
Update XLA dependency to use revision
Google-ML-Automation Dec 25, 2024
008c25a
Fix formatting in the docs for transposing pytrees
Mikcl Dec 26, 2024
6dbda90
Update XLA dependency to use revision
Google-ML-Automation Dec 26, 2024
7ab61b7
Update XLA dependency to use revision
Google-ML-Automation Dec 27, 2024
76ccb19
[pallas:mosaic_gpu] Added some runtime type checking to `copy_*` and …
superbobry Dec 28, 2024
8eeedd1
Update XLA dependency to use revision
Google-ML-Automation Dec 28, 2024
879fa12
Update XLA dependency to use revision
Google-ML-Automation Dec 29, 2024
97b1faa
Fixes the random key sharding in shard_map.
yliu120 Dec 29, 2024
25fff52
Update XLA dependency to use revision
Google-ML-Automation Dec 30, 2024
494c157
Merge pull request #25692 from yliu120:rng_key_sharding
Google-ML-Automation Dec 30, 2024
e37ea58
Update XLA dependency to use revision
Google-ML-Automation Dec 31, 2024
50670bd
Fix log10 and log2 for large inputs.
pearu Jan 1, 2025
4a6cfeb
Update XLA dependency to use revision
Google-ML-Automation Jan 1, 2025
213e178
tbp nightly instructions
rdyro Dec 23, 2024
dbe9ccd
Reverts 83e60a9697ec20023f4e11169edf64e910b93031
apaszke Jan 2, 2025
7c984c6
Don't use x32 mode for pallas_test
apaszke Jan 2, 2025
04a0fbe
Merge pull request #25661 from rdyro:tb-nightly-instructions
Google-ML-Automation Jan 2, 2025
82001ed
Merge pull request #25706 from pearu:pearu/log10-large
Google-ML-Automation Jan 2, 2025
ac817b4
[Mosaic:TPU][NFC] Clean up unused variable
tlongeri Jan 2, 2025
6443343
Fix OSS build for the Mosaic GPU dialect
apaszke Jan 2, 2025
68483b8
Merge pull request #25710 from apaszke:mgpu_dialect_fix
Google-ML-Automation Jan 2, 2025
726950b
Update XLA dependency to use revision
Google-ML-Automation Jan 2, 2025
800f903
Merge pull request #25686 from Mikcl:docs/working-with-pytrees-format…
Google-ML-Automation Jan 2, 2025
df36c29
Compute cost-analysis on only one HLO module.
zacmustin Jan 2, 2025
57b2154
[Mosaic] NFC: Pull out vreg related functions to util.
WindQAQ Jan 2, 2025
3306063
jax.debug.print: respect local np.printoptions
jakevdp Jan 3, 2025
0f4677b
Merge pull request #25713 from jakevdp:debug-printoptions
Google-ML-Automation Jan 3, 2025
e4278f7
Update XLA dependency to use revision
Google-ML-Automation Jan 3, 2025
9af2970
Update XLA dependency to use revision
Google-ML-Automation Jan 4, 2025
54fd738
Add SMEM as a supported Pallas output memory space.
Google-ML-Automation Jan 5, 2025
d0a92c5
Update XLA dependency to use revision
Google-ML-Automation Jan 5, 2025
a1734fd
Change to trigger CI
charleshofer Jan 6, 2025
307f0db
Skip failing tests
charleshofer Jan 6, 2025
708f48d
Skip one more test
charleshofer Jan 6, 2025
2 changes: 1 addition & 1 deletion .github/workflows/upstream-nightly.yml
@@ -85,7 +85,7 @@ jobs:
&& steps.status.outcome == 'failure'
&& github.event_name == 'schedule'
&& github.repository == 'jax-ml/jax'
uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: output-${{ matrix.python-version }}-log.jsonl
path: output-${{ matrix.python-version }}-log.jsonl
2 changes: 1 addition & 1 deletion .github/workflows/wheel_win_x64.yml
@@ -45,7 +45,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose

- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
- uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: wheels-${{ matrix.os }}-${{ matrix.pyver }}
path: ${{ github.workspace }}\dist\*.whl
2 changes: 1 addition & 1 deletion .github/workflows/windows_ci.yml
@@ -54,7 +54,7 @@ jobs:
--bazel_options=--config=win_clang `
--verbose

- uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3
- uses: actions/upload-artifact@6f51ac03b9356f520e9adb1b1b7802705f340c2b # v4.5.0
with:
name: wheels
path: ${{ github.workspace }}\jax\dist\*.whl
30 changes: 16 additions & 14 deletions CHANGELOG.md
@@ -4,6 +4,10 @@ Best viewed [here](https://jax.readthedocs.io/en/latest/changelog.html).
For the changes specific to the experimental Pallas APIs,
see {ref}`pallas-changelog`.

JAX follows Effort-based versioning; for a discussion of this and JAX's API
compatibility policy, refer to {ref}`api-compatibility`. For the Python and
NumPy version support policy, refer to {ref}`version-support-policy`.

<!--
Remember to align the itemized text with the first line of an item within a list.

@@ -12,30 +16,28 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## Unreleased

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.

* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.

* New Features
* {func}`jax.numpy.fft.fftn`, {func}`jax.numpy.fft.rfftn`,
{func}`jax.numpy.fft.ifftn`, and {func}`jax.numpy.fft.irfftn` now support
transforms in more than 3 dimensions, which was previously the limit. See
{jax-issue}`#25606` for more details.

* Deprecations
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.

* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag
is on by default.
* From `jax.lib.xla_client`, the previously-deprecated `Device` and
`XlaRuntimeError` symbols have been removed; instead use `jax.Device`
and `jax.errors.JaxRuntimeError` respectively.

## jax 0.4.38 (Dec 17, 2024)

* Changes:
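As a rough illustration of the new-features entry in the changelog above (FFT transforms over more than three dimensions), here is a minimal sketch; it is not part of this PR and only exercises the public `jax.numpy.fft.fftn` API:

```python
import numpy as np
import jax.numpy as jnp

# A 4D input: previously jnp.fft.fftn was limited to transforms over at most
# three dimensions; with this change all four axes can be transformed at once.
x = jnp.asarray(np.random.randn(2, 3, 4, 5))
y = jnp.fft.fftn(x, axes=(0, 1, 2, 3))
print(y.shape)  # (2, 3, 4, 5)
```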
1 change: 1 addition & 0 deletions build/rocm/README.md
@@ -207,3 +207,4 @@ This will generate three wheels in the `dist/` directory:
### Simplified Build Script

For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.

6 changes: 6 additions & 0 deletions docs/profiling.md
@@ -93,6 +93,12 @@ plugins" error described {ref}`below <multiple_installs>`. See
<https://www.tensorflow.org/guide/profiler> for more information on installing
TensorBoard.

The nightly version of the TensorBoard profiler requires the nightly TensorFlow and
TensorBoard packages:
```shell
pip install tf-nightly tb-nightly tbp-nightly
```

### Programmatic capture

You can instrument your code to capture a profiler trace via the
2 changes: 1 addition & 1 deletion docs/working-with-pytrees.md
@@ -490,7 +490,7 @@ This section covers some of the most common patterns with JAX pytrees.

### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`

To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func}`jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).

**Option 1:** Use {func}`jax.tree.map`. Here's an example:

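As a hedged illustration of the pattern the corrected sentence describes (not part of this diff), transposing a list of trees into a tree of lists with `jax.tree.map` looks roughly like this:

```python
import jax

# A "list of trees": one dict per step.
episode_steps = [
    {"obs": 1, "reward": 0.0},
    {"obs": 2, "reward": 1.0},
    {"obs": 3, "reward": 0.5},
]

# Option 1 from the docs: jax.tree.map zips corresponding leaves together,
# turning the list of dicts into a dict of lists.
transposed = jax.tree.map(lambda *leaves: list(leaves), *episode_steps)
print(transposed)  # {'obs': [1, 2, 3], 'reward': [0.0, 1.0, 0.5]}
```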
4 changes: 0 additions & 4 deletions jax/_src/abstract_arrays.py
@@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs):
"Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
core.shaped_abstractify_handlers[np.ma.MaskedArray] = masked_array_error


def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
@@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
core.shaped_abstractify_handlers[np.ndarray] = _make_shaped_array_for_numpy_array


def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
@@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:

for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
core.shaped_abstractify_handlers[t] = _make_shaped_array_for_numpy_scalar

core.literalable_types.update(array_types)

@@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val):

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
core.shaped_abstractify_handlers[t] = partial(_make_abstract_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
1 change: 0 additions & 1 deletion jax/_src/api.py
@@ -2564,7 +2564,6 @@ def _sds_aval_mapping(x):
x.shape, dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True),
weak_type=x.weak_type)
core.pytype_aval_mappings[ShapeDtypeStruct] = _sds_aval_mapping
core.shaped_abstractify_handlers[ShapeDtypeStruct] = _sds_aval_mapping


@api_boundary
1 change: 0 additions & 1 deletion jax/_src/array.py
@@ -1035,7 +1035,6 @@ def _get_aval_array(self):
else:
return self.aval

core.shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array

# TODO(jakevdp) replace this with true inheritance at the C++ level.
8 changes: 8 additions & 0 deletions jax/_src/compiler.py
@@ -192,6 +192,14 @@ def get_compile_options(
build_options.exec_time_optimization_effort = config.exec_time_optimization_effort.value
build_options.memory_fitting_effort = config.memory_fitting_effort.value

# This is a temporary workaround to simplify the AutoPGLE usage.
# TODO(b/376647494): Remove once the bug is fixed.
if config.enable_pgle.value and config.pgle_profiling_runs.value > 0:
logger.debug("Explicitly disabling command buffer scheduling for AutoPGLE.")
if env_options_overrides is None:
env_options_overrides = {}
env_options_overrides['xla_gpu_enable_command_buffer'] = ''

if env_options_overrides is not None:
# Some overrides are passed directly on build_options.
overrides_on_build_options = [
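For context on the workaround above, here is a minimal sketch of how AutoPGLE is typically enabled from user code. The `jax_enable_pgle` and `jax_pgle_profiling_runs` flag names are my assumption from the JAX config options referenced in this diff, not something the diff itself adds; with them set, the new branch in `get_compile_options` clears `xla_gpu_enable_command_buffer` automatically:

```python
import jax
import jax.numpy as jnp

# Assumed AutoPGLE flags; when enable_pgle is on and the run count is > 0,
# this PR injects xla_gpu_enable_command_buffer='' into env_options_overrides.
jax.config.update("jax_enable_pgle", True)
jax.config.update("jax_pgle_profiling_runs", 3)

@jax.jit
def f(x):
  return jnp.sin(x) @ jnp.cos(x).T

# The first few executions are profiled, then the module is recompiled with
# the collected profile applied.
f(jnp.ones((128, 128)))
```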
89 changes: 47 additions & 42 deletions jax/_src/core.py
@@ -656,6 +656,13 @@ def check_bool_conversion(arr: Array):
" is ambiguous. Use a.any() or a.all()")


pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}

def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify


def _aval_property(name):
return property(lambda self: getattr(self.aval, name))

@@ -918,6 +925,8 @@ def unsafe_buffer_pointer(self):
aval_property = namedtuple("aval_property", ["fget"])
aval_method = namedtuple("aval_method", ["fun"])

pytype_aval_mappings[Tracer] = lambda x: x.aval

def check_eval_args(args):
for arg in args:
if isinstance(arg, Tracer):
@@ -1400,45 +1409,51 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")


def _shaped_abstractify_slow(x):
try:
return x if isinstance(x, AbstractValue) else get_aval(x)
except TypeError:
pass

weak_type = getattr(x, 'weak_type', False)
if hasattr(x, 'dtype'):
dtype = dtypes.canonicalize_dtype(x.dtype, allow_extended_dtype=True)
else:
raise TypeError(
f"Cannot interpret value of type {type(x)} as an abstract array; it "
"does not have a dtype attribute")
return ShapedArray(np.shape(x), dtype, weak_type=weak_type)
# We have three flavors of abstractification APIs here which each used to have
# their own separate implementation. Now they're effectively the same, with the
# following differences:
#
# - abstractify returns avals for non-traced array-like objects.
# - get_aval is like abstractify, but also accepts tracers.
# - shaped_abstractify is like get_aval, but also accepts duck-typed arrays.
#
# TODO(jakevdp): can these be unified further?

# TODO(jakevdp): deduplicate this with abstractify
def shaped_abstractify(x):
# This was originally api_util.shaped_abstractify; temporarily moved
# here in order to facilitate combining it with abstractify.
handler = shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if isinstance(x, AbstractValue):
return x
if hasattr(x, '__jax_array__'):
return shaped_abstractify(x.__jax_array__())
if hasattr(x, 'dtype'):
return ShapedArray(np.shape(x), x.dtype, weak_type=getattr(x, 'weak_type', False))
raise TypeError(
f"Cannot interpret value of type {typ} as an abstract array; it "
"does not have a dtype attribute")


def abstractify(x):
for typ in type(x).__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
if isinstance(x, Tracer):
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
return get_aval(x)


def get_aval(x):
if isinstance(x, Tracer):
return x.aval
else:
return abstractify(x)
typ = type(x)
if (aval_fn := pytype_aval_mappings.get(typ)): # fast path
return aval_fn(x)
for t in typ.__mro__[1:]:
if (aval_fn := pytype_aval_mappings.get(t)):
return aval_fn(x)
if hasattr(x, '__jax_array__'):
return get_aval(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{typ}' is not a valid JAX type")

get_type = get_aval

def is_concrete(x):
return to_concrete_value(x) is not None
@@ -1831,13 +1846,6 @@ def to_tangent_aval(self):
return DShapedArray(self.shape, primal_dtype_to_tangent_dtype(self.dtype),
self.weak_type)

pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
shaped_abstractify_handlers: dict[Any, Callable[[Any], AbstractValue]] = {}

def _str_abstractify(x):
raise TypeError(f"Argument '{x}' of type {type(x)} is not a valid JAX type")
pytype_aval_mappings[str] = _str_abstractify
shaped_abstractify_handlers[str] = _str_abstractify

class DArray:
_aval: DShapedArray
@@ -1894,7 +1902,6 @@ def _darray_aval(x):
return DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)

pytype_aval_mappings[DArray] = _darray_aval
shaped_abstractify_handlers[DArray] = _darray_aval


@dataclass(frozen=True)
@@ -1924,11 +1931,10 @@ def __init__(self, aval, buf):
aval = property(lambda self: self._aval)
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __getitem__(self, idx): return self._aval._getitem(self, idx)
def __setitem__(self, idx, x): return self._aval._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
shaped_abstractify_handlers[MutableArray] = lambda x: x._aval

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
@@ -1984,7 +1990,6 @@ def __init__(self, buf):
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
shaped_abstractify_handlers[Token] = lambda _: abstract_token


# TODO(dougalm): Deprecate these. They're just here for backwards compat.
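A small sketch (mine, not part of the PR) of how the unified abstractification registry behaves after this change; it assumes `jax.core.get_aval` and `jax.core.abstractify` as the public entry points, per the changelog entry in this sync:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Non-traced array-likes resolve through core.pytype_aval_mappings.
print(jax.core.abstractify(np.ones((2, 3))))  # e.g. ShapedArray(float32[2,3])
print(jax.core.get_aval(jnp.arange(4.0)))     # e.g. ShapedArray(float32[4])

# get_aval also accepts tracers (Tracer is registered in the mapping),
# while abstractify rejects them with a TypeError.
@jax.jit
def f(x):
  return x * jax.core.get_aval(x).ndim

print(f(jnp.zeros((2, 3))))  # zeros, computed as x * 2
```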
19 changes: 14 additions & 5 deletions jax/_src/cudnn/fused_attention_stablehlo.py
@@ -348,11 +348,20 @@ def check_is_flash_attention(
)
else:
# Regular attention conditions
if not ((H <= 128 and H % 8 == 0) and
(not is_training or not has_bias or T % 2 == 0 and S % 2 == 0)):
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S} and head dim {H}."
)
# Check the head dim.
is_on_hopper = check_compute_capability("9.0")
H_max = 256 if cudnn_version >= 90500 and is_on_hopper else 128
if not (H <= H_max and H % 8 == 0):
raise NotImplementedError(
f"The head dim must be <= {H_max} and a mutiple of 8, "
f"but got {H}."
)

# Check patterns with bias, seqlen should be divisible by 2
if (is_training and has_bias and (T % 2 != 0 or S % 2 != 0)):
raise NotImplementedError(
f"Unsupported sequence length Q {T}, KV {S}."
)

def check_cudnn_version():
# check if cuDNN is installed
Expand Down
9 changes: 5 additions & 4 deletions jax/_src/debugging.py
@@ -300,8 +300,9 @@ def check_unused_args(self, used_args, args, kwargs):

formatter = _DebugPrintFormatChecker()

def _format_print_callback(fmt: str, *args, **kwargs):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")
def _format_print_callback(fmt: str, np_printoptions, *args, **kwargs):
with np.printoptions(**np_printoptions):
sys.stdout.write(fmt.format(*args, **kwargs) + "\n")

def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
"""Prints values and works in staged out JAX functions.
@@ -338,8 +339,8 @@ def debug_print(fmt: str, *args, **kwargs):
# Check that we provide the correct arguments to be formatted.
formatter.format(fmt, *args, **kwargs)

debug_callback(functools.partial(_format_print_callback, fmt), *args,
**kwargs, ordered=ordered)
debug_callback(functools.partial(_format_print_callback, fmt, np.get_printoptions()),
*args, **kwargs, ordered=ordered)


# Sharding visualization
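A small usage sketch (mine) of the behavior this diff enables: the NumPy printoptions active when `jax.debug.print` is traced are captured and reapplied inside the host callback:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
  jax.debug.print("x = {}", x)
  return x

# The options in effect at trace time are bundled into the callback, so the
# printed array honors precision/suppress even though printing happens later.
with np.printoptions(precision=2, suppress=True):
  f(jnp.array([1.23456789, 1e-8]))
```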
1 change: 0 additions & 1 deletion jax/_src/earray.py
@@ -115,7 +115,6 @@ def _earray_shard_arg_handler(xs, shardings, layouts, copy_semantics):
return pxla.shard_args(phys_shardings, layouts, copy_semantics, arrs)
pxla.shard_arg_handlers[EArray] = _earray_shard_arg_handler

core.shaped_abstractify_handlers[EArray] = lambda self: self.aval
core.pytype_aval_mappings[EArray] = lambda x: x.aval
xla.canonicalize_dtype_handlers[EArray] = lambda x: x
tree_util.dispatch_registry.register_node(
1 change: 0 additions & 1 deletion jax/_src/export/shape_poly.py
@@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")

core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.shaped_abstractify_handlers[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)

def _convertible_to_int(p: DimSize) -> bool:
5 changes: 1 addition & 4 deletions jax/_src/interpreters/partial_eval.py
@@ -1569,10 +1569,7 @@ def get_referent(self):
val = frame.constvar_to_val.get(frame.tracer_to_var.get(id(self)))
return self if val is None else get_referent(val)


def _dynamic_jaxpr_tracer_shaped_abstractify(x):
return x.aval
core.shaped_abstractify_handlers[DynamicJaxprTracer] = _dynamic_jaxpr_tracer_shaped_abstractify
core.pytype_aval_mappings[DynamicJaxprTracer] = lambda x: x.aval

def make_jaxpr_effects(constvars, invars, outvars, eqns) -> effects.Effects:
sentinel = object()