Commit

misc: added note about p_scale
alexkranias-amd committed Dec 9, 2024
1 parent c65af82 commit 9297d78
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -181,7 +181,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
 m_i = m_ij

 if IS_FP8:
-    p_scale = 1
+    p_scale = 1 # NOTE: for proper scaling set this = tl.max(p) (increases error)
     p_scaled = (p / p_scale)
     acc += tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)).to(tl.float32) * v_scale * p_scale # if you want to use p_scaled: tl.dot(p_scaled.to(v.type.element_ty), v.to(v.type.element_ty)) * v_scale * p_scale
 else:
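
The scaling pattern in this hunk (divide `p` by `p_scale` before the low-precision cast, then multiply the dot product by `v_scale * p_scale`) can be sketched outside Triton in plain Python. This is only an illustration under stated assumptions: `round_to_mantissa_bits` and `fp8_style_accumulate` are hypothetical helpers, not part of the repository, and the rounding function is a rough stand-in for an FP8 cast that ignores exponent range and saturation. It shows only that the rescaling is algebraically neutral; it does not reproduce the error behavior mentioned in the NOTE.

```python
import math


def round_to_mantissa_bits(x, bits=3):
    """Hypothetical stand-in for an FP8 cast: keep `bits` explicit mantissa bits.

    Assumption: exponent range and saturation are ignored entirely.
    """
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)        # x == m * 2**e with 0.5 <= |m| < 1
    steps = 2 ** (bits + 1)     # leading bit plus `bits` mantissa bits
    return math.ldexp(round(m * steps) / steps, e)


def fp8_style_accumulate(p, v, v_scale=1.0, p_scale=1.0, bits=3):
    """Mirror the hunk: scale p down, cast, dot, then undo both scales."""
    p_scaled = [round_to_mantissa_bits(x / p_scale, bits) for x in p]
    v_cast = [round_to_mantissa_bits(x, bits) for x in v]
    dot = sum(a * b for a, b in zip(p_scaled, v_cast))
    return dot * v_scale * p_scale


# Values chosen as powers of two so the simulated cast is exact.
p = [0.5, 0.25, 0.125]
v = [1.0, 2.0, 4.0]

unscaled = fp8_style_accumulate(p, v, v_scale=2.0, p_scale=1.0)
max_scaled = fp8_style_accumulate(p, v, v_scale=2.0, p_scale=max(p))
print(unscaled, max_scaled)
```

With `p_scale = 1` the division and the final multiplication are no-ops, which matches the value the commit keeps. Setting `p_scale = max(p)` corresponds to the `tl.max(p)` option named in the NOTE; in this toy model both choices give the same result because every input is exactly representable.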