From c118787868e8cd1652976fbed20e516b48407d9d Mon Sep 17 00:00:00 2001
From: Alex Kranias
Date: Fri, 6 Dec 2024 16:21:53 -0600
Subject: [PATCH] issue: error caused by acc += tl.dot(p.to(v.type.element_ty), v)

---
 flash_attn/flash_attn_triton_amd/fwd_prefill.py | 3 ++-
 flash_attn/flash_attn_triton_amd/fwd_ref.py     | 2 +-
 flash_attn/flash_attn_triton_amd/test.py        | 4 ++--
 flash_attn/flash_attn_triton_amd/utils.py       | 1 +
 4 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
index 573f86380..4d5ef0843 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -592,11 +592,12 @@ def attention_prefill_forward_triton_impl(
         torch.float8_e5m2fnuz,
     }
     is_fp8 = q.dtype in fp8_types
+    is_fp8 = False

     # check if varlen
     is_varlen = layout == "thd"

     # if qkv are fp8, then find scaling factor for quantization
-    q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout)
+    q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
     q_scale_stride_z = q_scale.stride(0)
     kv_scale_stride_z = k_scale.stride(0)
diff --git a/flash_attn/flash_attn_triton_amd/fwd_ref.py b/flash_attn/flash_attn_triton_amd/fwd_ref.py
index 2ae2a3b4d..74c49aa35 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_ref.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_ref.py
@@ -105,7 +105,7 @@ def attention_forward_core_ref_impl(q, k, v, sm_scale, causal, use_exp2):
         print("softmax_lse:", softmax_lse, softmax_lse.shape)

     # Compute output
-    o = torch.matmul(softmax, v.to(torch.float32)).to(torch.float16)
+    o = torch.matmul(softmax, v.to(torch.float32))

     if DEBUG_CORE:
         print("o:", o, o.shape)
diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py
index 28b4455d3..aacdc58b7 100644
--- a/flash_attn/flash_attn_triton_amd/test.py
+++ b/flash_attn/flash_attn_triton_amd/test.py
@@ -464,7 +464,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, return
     if DEBUG:
         print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape)
         print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
-    # torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL)

     if DEBUG:
         print("exp_scores_triton", exp_scores_triton)
@@ -478,7 +478,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, return
         print("softmax_lse_triton:", softmax_lse_triton, softmax_lse_triton.shape)
         print("softmax_triton:", softmax_triton, softmax_triton.shape)
         print("softmax_ref:", softmax_ref, softmax_ref.shape)
-    # torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL)
+    torch.testing.assert_close(softmax_triton, softmax_ref, atol=ATOL, rtol=RTOL)

     # if triton is fp8, cast to fp16 in order to compare with ref
     if output_triton.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py
index 934439ac6..53ecf696a 100644
--- a/flash_attn/flash_attn_triton_amd/utils.py
+++ b/flash_attn/flash_attn_triton_amd/utils.py
@@ -305,6 +305,7 @@ def create_scale_tensors(q, k, v, SCALE_PER_HEAD=False, layout='bshd'):
         torch.float8_e5m2fnuz,
     }
     is_fp8 = q.dtype in fp8_types
+    is_fp8 = False

     if layout == 'bhsd':
         seqlen_loc = 2
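
Note (reviewer sketch, not part of the patch): the TODO above proposes computing the per-head scale inside the kernel via tl.max. For reference, below is a minimal host-side sketch of what a per-(batch, head) fp8 scale could look like for the 'bhsd' layout. The function name per_head_fp8_scale_sketch and the amax/fp8-max scaling rule are assumptions for illustration only, not the actual create_scale_tensors implementation.

    import torch

    def per_head_fp8_scale_sketch(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz) -> torch.Tensor:
        # Hypothetical helper, not from this repo.
        # x: (batch, heads, seqlen, head_dim) in the 'bhsd' layout.
        # Scale is chosen so that x / scale fits the representable fp8 range.
        fp8_max = torch.finfo(fp8_dtype).max
        amax = x.abs().amax(dim=(2, 3)).to(torch.float32)  # (batch, heads)
        return (amax / fp8_max).clamp(min=1e-12)            # avoid zero scales

Under this assumption, q_scale/k_scale/v_scale would each be a (batch, heads) tensor, which is consistent with the q_scale.stride(0) / k_scale.stride(0) strides read in fwd_prefill.py above.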