You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Traceback (most recent call last):
File "/root/splash_test.py", line 38, in <module>
print(jax.jit(fn).lower(q, k, v).as_text("hlo"))
File "/opt/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 485, in lower
return trace(*args, **kwargs).lower()
File "/opt/venv/lib/python3.10/site-packages/jax/_src/stages.py", line 775, in lower
lowering = new_callable()
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1624, in _resolve_and_lower
return _pjit_lower(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1789, in _pjit_lower
return pxla.lower_sharding_computation(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
return func(*args, **kwargs)
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in lower_sharding_computation
nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1953, in _cached_lowering_to_hlo
lowering_result = mlir.lower_jaxpr_to_module(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1194, in lower_jaxpr_to_module
lower_jaxpr_to_fun(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1678, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1950, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 2068, in lower_per_platform
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1944, in _pjit_lowering
func = _pjit_cached_lower_jaxpr_to_fun(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1927, in _pjit_cached_lower_jaxpr_to_fun
func = mlir.lower_jaxpr_to_fun(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1678, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1950, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 2068, in lower_per_platform
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1944, in _pjit_lowering
func = _pjit_cached_lower_jaxpr_to_fun(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1927, in _pjit_cached_lower_jaxpr_to_fun
func = mlir.lower_jaxpr_to_fun(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1678, in lower_jaxpr_to_fun
out_vals, tokens_out = jaxpr_subcomp(
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1950, in jaxpr_subcomp
ans = lower_per_platform(rule_ctx, str(eqn.primitive),
File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 2068, in lower_per_platform
output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
File "/opt/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 584, in _tpu_gpu_device_put_lowering
return list(map(lower, xs, devices, ctx.avals_in, ctx.avals_out))
ValueError: safe_map() argument 2 is shorter than argument 1
System info (python version, jaxlib version, accelerator, etc.)
Description
Reproducer:
Error traceback:
System info (python version, jaxlib version, accelerator, etc.)
Tested on a Cloud TPU v5p.
The text was updated successfully, but these errors were encountered: