diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 44efa9fd5..ed4403a17 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -181,7 +181,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri m_i = m_ij if IS_FP8: - p_scale = 1 + p_scale = 1 # NOTE: for proper scaling set this = tl.max(p) (increases error) p_scaled = (p / p_scale) acc += tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale * p_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale else: