issue: error caused by acc += tl.dot(p.to(v.type.element_ty), v)
alexkranias-amd committed Dec 6, 2024
1 parent 2c78a3b commit c118787
Showing 4 changed files with 6 additions and 4 deletions.
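
The commit title quotes the accumulation step that triggers the error. As context, here is a minimal sketch of the usual flash-attention inner step in Triton where such a line sits; apart from the quoted acc += tl.dot(...) line, the block shapes, pointer math, and names are illustrative assumptions, not the repository's fwd_prefill kernel.

import triton
import triton.language as tl

@triton.jit
def _pv_accumulate(P, V, Out,
                   BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_D: tl.constexpr):
    # One inner step of the attention output accumulation:
    # p is a (BLOCK_M, BLOCK_N) tile of softmax probabilities,
    # v is the matching (BLOCK_N, BLOCK_D) tile of values.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_D)
    p = tl.load(P + offs_m[:, None] * BLOCK_N + offs_n[None, :])
    v = tl.load(V + offs_n[:, None] * BLOCK_D + offs_d[None, :])
    acc = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)
    # The line named in the commit title: p is cast to v's element type before
    # the dot; with fp8 v tensors this is where the reported error shows up.
    acc += tl.dot(p.to(v.type.element_ty), v)
    tl.store(Out + offs_m[:, None] * BLOCK_D + offs_d[None, :], acc)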
flash_attn/flash_attn_triton_amd/fwd_prefill.py (2 additions, 1 deletion)
@@ -592,11 +592,12 @@ def attention_prefill_forward_triton_impl(
         torch.float8_e5m2fnuz,
     }
     is_fp8 = q.dtype in fp8_types
+    is_fp8 = False
     # check if varlen
     is_varlen = layout == "thd"
 
     # if qkv are fp8, then find scaling factor for quantization
-    q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout)
+    q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
     q_scale_stride_z = q_scale.stride(0)
     kv_scale_stride_z = k_scale.stride(0)
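The TODO added above suggests computing the quantization scale inside the kernel instead of on the host. A minimal sketch of that kind of in-kernel reduction follows; the function, its arguments, and the FP8_MAX constant are assumptions for illustration and are not part of this commit.

import triton
import triton.language as tl

@triton.jit
def _block_amax_scale(X, Scale, N, FP8_MAX: tl.constexpr, BLOCK: tl.constexpr):
    # Reduce one block of X to a single scale = max(|x|) / FP8_MAX, roughly the
    # "qkv_scale = tl.max(q or k or v)" idea the TODO hints at.
    offs = tl.arange(0, BLOCK)
    x = tl.load(X + offs, mask=offs < N, other=0.0)
    amax = tl.max(tl.abs(x), axis=0)
    tl.store(Scale, amax / FP8_MAX)
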
flash_attn/flash_attn_triton_amd/fwd_ref.py (1 addition, 1 deletion)
@@ -105,7 +105,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
         print("softmax_lse:", softmax_lse, softmax_lse.shape)
 
     # Compute output
-    o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
+    o = torch.matmul(softmax, v.to(torch.float32))
     if DEBUG_CORE:
         print("o:", o, o.shape)
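The reference implementation now keeps its output in float32 rather than truncating to float16. A tiny self-contained version of that pattern, with shapes and scaling treated as assumptions rather than the repository's fwd_ref code:

import torch

def tiny_attention_ref(q, k, v, sm_scale):
    # Compute the whole reference in float32 and return float32, mirroring the
    # change above that drops the trailing .to(torch.float16) cast.
    scores = torch.matmul(q.float(), k.float().transpose(-2, -1)) * sm_scale
    probs = torch.softmax(scores, dim=-1)
    return torch.matmul(probs, v.float())
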
flash_attn/flash_attn_triton_amd/test.py (2 additions, 2 deletions)
@@ -464,7 +464,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, return
     if DEBUG:
         print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape)
         print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
-    # torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL)
 
     if DEBUG:
         print("exp_scores_triton", exp_scores_triton)
@@ -478,7 +478,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, return
         print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape)
         print("softmax_triton:", softmax_triton, softmax_triton.shape)
         print("softmax_ref:", softmax_ref, softmax_ref.shape)
-    # torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL)
 
     # if triton is fp8, cast to fp16 in order to compare with ref
     if output_triton.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
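The two assert_close checks are re-enabled above, and the surrounding test casts an fp8 Triton output up to fp16 before comparing against the reference. A stand-alone illustration of that comparison pattern; the tolerance values are placeholders, not the test file's ATOL and RTOL constants:

import torch

ATOL, RTOL = 2e-2, 2e-2  # placeholder tolerances, not the test's constants

out_ref = torch.randn(2, 4, 128, 64, dtype=torch.float32)
out_triton = out_ref.to(torch.float16)  # stand-in for the kernel's output

# If the kernel produced fp8, cast it up before comparing with the reference.
if out_triton.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
    out_triton = out_triton.to(torch.float16)

torch.testing.assert_close(out_triton.to(torch.float32), out_ref, atol=ATOL, rtol=RTOL)
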
flash_attn/flash_attn_triton_amd/utils.py (1 addition, 0 deletions)
@@ -305,6 +305,7 @@ def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd'):
         torch.float8_e5m2fnuz,
     }
     is_fp8 = q.dtype in fp8_types
+    is_fp8 = False
 
     if layout == 'bhsd':
         seqlen_loc = 2
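create_scale_tensors is where the per-head quantization scales come from. One plausible shape for that computation, assuming SCALE_PER_HEAD means one scale per (batch, head) pair derived from the tensor's absolute maximum; this is an illustration, not the function's actual body:

import torch

FP8_MAX = 448.0  # max finite value of float8_e4m3fn; other fp8 formats differ

def per_head_scale(x, layout="bhsd"):
    # Reduce over the sequence and head-dim axes so one scale remains per
    # (batch, head) pair, then normalize by the fp8 range.
    reduce_dims = (2, 3) if layout == "bhsd" else (1, 3)
    amax = x.abs().amax(dim=reduce_dims).float()
    return (amax / FP8_MAX).clamp(min=1e-12)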
