Got an error when offloading SplashAttention pallas call #25841

Open
hanzhi713 opened this issue Jan 10, 2025 · 0 comments
Labels
bug Something isn't working

Comments

hanzhi713 commented Jan 10, 2025

Description

Reproducer:

import jax
from jax.ad_checkpoint import Offloadable, Recompute
from jax.experimental.pallas import pallas_call_p
import jax.numpy as jnp
from jax.experimental.pallas.ops.tpu.splash_attention import (
    splash_attention_kernel,
    splash_attention_mask,
)


@jax.jit
def tpu_attn(query, key, value):
    num_heads = query.shape[1]
    mask_shape = (query.shape[2], key.shape[2])
    mask = splash_attention_mask.FullMask(mask_shape)
    kernel = splash_attention_kernel.make_splash_mha(
        mask=splash_attention_mask.MultiHeadMask(masks=[mask] * num_heads),
        block_sizes=None,
        head_shards=1,
        q_seq_shards=1,
    )
    kernel = jax.vmap(kernel)
    context = kernel(q=query, k=key, v=value)
    return context.sum()


def policy(prim, *_, **__):
    # Offload the residuals produced by pallas_call to pinned host memory;
    # recompute everything else during the backward pass.
    if prim is pallas_call_p:
        return Offloadable("device", "pinned_host")
    return Recompute


q = jnp.ones((1, 1, 128, 128))
k = jnp.ones((1, 1, 128, 128))
v = jnp.ones((1, 1, 128, 128))
fn = jax.grad(jax.remat(tpu_attn, policy=policy))
# Lowering alone (no execution needed) triggers the ValueError below.
print(jax.jit(fn).lower(q, k, v).as_text("hlo"))
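As a sanity check, the policy's dispatch itself behaves as intended when called directly, so the failure appears to be in the lowering of the resulting `device_put`, not in the policy. This is a minimal sketch, using `jax.lax.add_p` as an arbitrary stand-in for a non-pallas primitive:

```python
import jax
from jax.ad_checkpoint import Offloadable, Recompute
from jax.experimental.pallas import pallas_call_p


def policy(prim, *_, **__):
    # Offload pallas_call outputs to pinned host memory; recompute the rest.
    if prim is pallas_call_p:
        return Offloadable("device", "pinned_host")
    return Recompute


# pallas_call is marked for device -> pinned_host offloading,
# while any other primitive (e.g. add) is recomputed.
assert isinstance(policy(pallas_call_p), Offloadable)
assert policy(jax.lax.add_p) is Recompute
```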

Error traceback:

Traceback (most recent call last):
  File "/root/splash_test.py", line 38, in <module>
    print(jax.jit(fn).lower(q, k, v).as_text("hlo"))
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 485, in lower
    return trace(*args, **kwargs).lower()
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/stages.py", line 775, in lower
    lowering = new_callable()
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1624, in _resolve_and_lower
    return _pjit_lower(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1789, in _pjit_lower
    return pxla.lower_sharding_computation(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 333, in wrapper
    return func(*args, **kwargs)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2323, in lower_sharding_computation
    nreps, tuple_args, shape_poly_state) = _cached_lowering_to_hlo(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1953, in _cached_lowering_to_hlo
    lowering_result = mlir.lower_jaxpr_to_module(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1194, in lower_jaxpr_to_module
    lower_jaxpr_to_fun(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1678, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1950, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 2068, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1944, in _pjit_lowering
    func = _pjit_cached_lower_jaxpr_to_fun(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1927, in _pjit_cached_lower_jaxpr_to_fun
    func = mlir.lower_jaxpr_to_fun(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1678, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1950, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 2068, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1944, in _pjit_lowering
    func = _pjit_cached_lower_jaxpr_to_fun(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1927, in _pjit_cached_lower_jaxpr_to_fun
    func = mlir.lower_jaxpr_to_fun(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1678, in lower_jaxpr_to_fun
    out_vals, tokens_out = jaxpr_subcomp(
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 1950, in jaxpr_subcomp
    ans = lower_per_platform(rule_ctx, str(eqn.primitive),
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py", line 2068, in lower_per_platform
    output = kept_rules[0](ctx, *rule_args, **rule_kwargs)
  File "/opt/venv/lib/python3.10/site-packages/jax/_src/dispatch.py", line 584, in _tpu_gpu_device_put_lowering
    return list(map(lower, xs, devices, ctx.avals_in, ctx.avals_out))
ValueError: safe_map() argument 2 is shorter than argument 1

System info (python version, jaxlib version, accelerator, etc.)

Tested on TPU v5p.

jax                            0.4.39.dev20250110
jaxlib                         0.4.39.dev20250110
libtpu                         0.0.8.dev20250110+nightly
@hanzhi713 hanzhi713 added the bug Something isn't working label Jan 10, 2025