Commit 1cef817: dump qk
micmelesse committed Jan 7, 2025
1 parent 0121712
Showing 2 changed files with 29 additions and 18 deletions.
45 changes: 28 additions & 17 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -62,7 +62,7 @@ def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpo
@triton.jit
def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn, start_m,
actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, philox_ptrs, sd_mask_ptrs, dropout_mask_ptrs,
block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8: tl.constexpr,
IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr,
@@ -100,12 +100,16 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
size_n = start_n + OFFS_N[None, :]
mask = size_n < boundary_m[:, None]
qk = tl.where(mask, qk, float("-inf"))

# compute mask for scores
p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)

# -- compute qk ----
qk += tl.dot(q, k)
qk_scaled = qk * SM_SCALE
if IS_FP8:
qk_scaled *= q_scale * k_scale # descale qk after matmul if quantized
tl.store(qk_fp8_ptrs, qk_scaled, mask=p_mask)

if IS_CAUSAL:
causal_boundary = start_n + offs_n_causal
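
For reference, the descaling step above relies on a simple identity: if q and k were divided by their scales before the matmul, multiplying the product by q_scale * k_scale recovers the unscaled scores. Below is a standalone PyTorch sketch of just that arithmetic, using per-tensor scales, fp32 stand-ins, and made-up shapes rather than the kernel's per-(batch, head) fp8 path:

import torch

# Illustrative only: per-tensor scales and fp32 arithmetic stand in for the
# kernel's per-(batch, head) fp8 quantization.
torch.manual_seed(0)
q = torch.randn(16, 64)
k = torch.randn(32, 64)

type_max = 448.0                        # float8_e4m3 max magnitude (assumed value)
q_scale = q.abs().amax() / type_max
k_scale = k.abs().amax() / type_max

q_quant = q / q_scale                   # what would be stored as fp8
k_quant = k / k_scale

# matmul on quantized operands, then descale once, mirroring
# `qk_scaled *= q_scale * k_scale` above
qk_descaled = (q_quant @ k_quant.T) * q_scale * k_scale
print(torch.allclose(qk_descaled, q @ k.T, atol=1e-4))  # True
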
@@ -134,11 +138,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
p = tl.math.exp2(q_shifted * RCP_LN2)
else:
p = tl.math.exp(q_shifted)

p_mask = (OFFS_M[:, None] < actual_seqlen_q) & ((start_n + tl.arange(0, BLOCK_N))[None, :] < actual_seqlen_k)
if IS_FP8:
p *= p_inv_scale

# CAVEAT: Must update l_ij before applying dropout
l_ij = tl.sum(p, 1) # p is fp32 at this point
l_ij = tl.sum(p, 1)
if ENABLE_DROPOUT:
if tl_DROPOUT_USE_PYTORCH:
dropout_mask = tl.load(dropout_mask_ptrs, mask=p_mask)
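
As a side note on the exp2 branch a few lines up: assuming RCP_LN2 is 1/ln(2) (i.e. log2(e)), tl.math.exp2(q_shifted * RCP_LN2) and tl.math.exp(q_shifted) compute the same value, since exp(x) = 2**(x * log2(e)). A quick host-side check of that identity:

import math

RCP_LN2 = 1.0 / math.log(2.0)   # assumed definition of the kernel constant
x = -3.25
assert abs(math.exp(x) - 2.0 ** (x * RCP_LN2)) < 1e-12
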
@@ -173,10 +177,9 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
l_i = l_i * alpha + l_ij
# update m_i and l_i
m_i = m_ij
acc += tl.dot(p.to(v.type.element_ty), v)
if IS_FP8:
acc += tl.dot((p * p_inv_scale).to(v.type.element_ty), v) * p_scale * v_scale
else:
acc += tl.dot(p.to(v.type.element_ty), v)
acc *= p_scale * v_scale
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
if bias_ptrs is not None:
@@ -271,7 +274,7 @@ def attn_fwd(Q, K, V, bias,
stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah,
stride_sz, stride_sh, stride_sm, stride_sn, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, HQ: tl.constexpr,
dropout_p, philox_seed, philox_offset_base, sd_mask, dropout_mask, alibi_slopes, qk_fp8, HQ: tl.constexpr,
HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr,
MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr,
@@ -416,8 +419,11 @@ def attn_fwd(Q, K, V, bias,
# print("v_scale", v_scale)
# print("p_scale", p_scale)
# print("p_inv_scale", p_inv_scale)
qk_fp8_offset = qk_fp8 + off_z * stride_sz + off_h_q * stride_sh #+ cu_seqlens_q_start * stride_sm
qk_fp8_ptrs = qk_fp8_offset + offs_m[:, None] * stride_sm + offs_n[None, :] * stride_sn #+ cu_seqlens_q_start * stride_sm
else:
q_scale, k_scale, v_scale, p_scale, p_inv_scale = 1.0, 1.0, 1.0, 1.0, 1.0
qk_fp8_ptrs = None

# Here we compute how many full and masked blocks we have.
padded_block_k = n_extra_tokens != 0
@@ -441,7 +447,7 @@ def attn_fwd(Q, K, V, bias,
block_max = (n_blocks - masked_blocks) * BLOCK_N
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs,
# _, _, offs_n_causal, masked_blocks, n_extra_tokens, _
block_min, block_max, 0, 0, 0, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8,
@@ -471,7 +477,7 @@ def attn_fwd(Q, K, V, bias,
philox_ptrs += n_full_blocks * BLOCK_N * stride_sn
acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, stride_sn,
start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, philox_ptrs,
sd_mask_ptrs, dropout_mask_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
sd_mask_ptrs, dropout_mask_ptrs, qk_fp8_ptrs, block_min, block_max, offs_n_causal, masked_blocks,
n_extra_tokens, alibi_slope,
q_scale, k_scale, v_scale, p_scale, p_inv_scale, IS_FP8,
IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n,
@@ -612,15 +618,13 @@ def attention_prefill_forward_triton_impl(
# Set up layout-specific dimensions
if layout == "bhsd":
seqlen_loc = 2
dim_loc = 3
elif layout == "bshd":
seqlen_loc = 1
dim_loc = 3


# Compute max for each batch-head pair across seqlen and dim
q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps))
k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps))
v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, dim_loc)), torch.tensor(eps))
q_scale = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))
k_scale = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))
v_scale = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))

# Divide by type max
q_scale = q_scale / type_max
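
A host-side sketch of this per-(batch, head) scale computation for a bshd tensor; the eps and type_max values below are assumptions, not taken from this file:

import torch

eps, type_max = 1e-9, 448.0
q = torch.randn(2, 128, 8, 64)               # bshd: seqlen is dim 1, head_dim is dim 3
absmax = q.float().abs().amax(dim=(1, 3))    # one value per (batch, head)
q_scale = torch.maximum(absmax, torch.tensor(eps)) / type_max
print(q_scale.shape)                         # torch.Size([2, 8])
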
@@ -649,10 +653,15 @@ def attention_prefill_forward_triton_impl(
kv_scale_stride_z = k_scale.stride(0)
p_scale_stride_z = p_scale.stride(0)
p_inv_scale_stride_z = p_inv_scale.stride(0)

# dump intermediate results
qk_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device)
acc_fp8 = torch.zeros((batch, nheads_q, max_seqlens_q, max_seqlens_k), dtype=q.dtype, device=q.device)
else:
# For non-FP8 types, use dummy values (no scaling needed)
q_scale = k_scale = v_scale = p_scale = p_inv_scale = 1
q_scale_stride_z = kv_scale_stride_z = p_scale_stride_z = p_inv_scale_stride_z = 0
qk_fp8 = None

if DEBUG:
print("is_fp8:", is_fp8)
@@ -731,7 +740,7 @@ def attention_prefill_forward_triton_impl(
sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
*bias_strides, *alibi_strides, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
qk_fp8=qk_fp8, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, VARLEN=is_varlen,
BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
USE_ALIBI=False if alibi_slopes is None else True, ENABLE_DROPOUT=dropout_p
@@ -747,5 +756,7 @@ def attention_prefill_forward_triton_impl(
print("dropout_mask:", dropout_mask, dropout_mask.shape if dropout_mask is not None else None)
print("dropout_fraction fwd:", 1.0 - (dropout_mask.sum()/ dropout_mask.numel()).item())
write_dropout_mask(dropout_mask, "dropout_mask_fwd")
if is_fp8:
print("qk_fp8:", qk_fp8)

return o, softmax_lse, sd_mask.to(o.dtype) if return_softmax else None
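
Once the forward pass has run on fp8 inputs, the dumped buffer can be sanity-checked against an fp32 reference on the caller side. The helper below is a hypothetical sketch (the bhsd layout, names, and tolerance are assumptions, not part of this commit):

import torch

def check_dumped_qk(q, k, qk_fp8, sm_scale, atol=1e-1):
    # Assumes the dump already includes the softmax scale and the q/k descale,
    # as `qk_scaled` does above, and that q/k are laid out as (b, h, s, d).
    ref = torch.einsum("bhqd,bhkd->bhqk", q.float(), k.float()) * sm_scale
    max_err = (qk_fp8.float() - ref).abs().max().item()
    print(f"max |dumped - reference| = {max_err:.4f}")
    return max_err < atol
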
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/test.py
@@ -512,7 +512,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('layout', ["bshd"]) # expects bshd args
@pytest.mark.parametrize('DEBUG_INPUT', [False])
@pytest.mark.parametrize('DEBUG_INPUT', [True])
def test_op_prefill_fp8(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, DEBUG_INPUT):
device = "cuda"
window_size = (-1, -1)