Skip to content

Commit

Permalink
use triton commit 3ca2f498e98ed7249b82722587c511a5610e00c4 -- now bat…
Browse files Browse the repository at this point in the history
…ched layout passes
  • Loading branch information
alexkranias-amd committed Nov 13, 2024
1 parent 3cd9739 commit 34eab23
Showing 1 changed file with 41 additions and 33 deletions.
74 changes: 41 additions & 33 deletions tests/test_flash_attn_triton_amd.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -903,31 +903,30 @@ def test_flash_attn_varlen_qkvpacked(
# @pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
@pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
@pytest.mark.parametrize("d", [32, 40, 59, 64, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [224, 256])
@pytest.mark.parametrize("d", [32])
# @pytest.mark.parametrize("d", [32])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
(2, 2),
# (113, 203),
# (128, 217),
# (113, 211),
# (108, 256),
# (256, 512),
# (512, 256),
# (1024, 1024),
# (1023, 1024),
# (1024, 1023),
# (2048, 2048),
(113, 203),
(128, 217),
(113, 211),
(108, 256),
(256, 512),
(512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.5])
@pytest.mark.parametrize("dropout_p", [0.17])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_output(
Expand Down Expand Up @@ -1157,37 +1156,46 @@ def test_flash_attn_output(
print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

# NOTE: often is the case the the pytorch max diff is 0. This results in the test almost always
# failing since the triton kernel must have 0 error to pass. To overcome this I've created a constant that is added
# to the error. If it is within these bounds it will pass.
# VERY IMPORTANT NOTE:
# if there is an issue with the dropout mask created in the bwd pass, the max error will be on the order of magnitude of
# 10^0. Thus I have set MIN_ERROR = 10^-2. This is large enough that it will pass every test regardless of precision error,
# but will definitely fail if there is an issue with the reconstructed mask.
MIN_ERROR = 1e-2

# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
if DEBUG:
print("out:", out, out.shape)
print("out_ref:", out_ref, out_ref.shape)
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + MIN_ERROR

# if dropout_p > 0.0:
# # assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
# if not alibi:
# assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)
if dropout_p > 0.0:
# assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
# With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if (d <= MAX_HEADDIM_SM8x or dropout_p == 0) or (is_sm80 or is_sm90):
if DEBUG:
print("dv:", dv, dv.shape)
print("dv_ref:", dv_ref, dv_ref.shape)
print("dv_pt:", dv_pt, dv_pt.shape)
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + MIN_ERROR

if DEBUG:
print("dk:", dk, dk.shape)
print("dk_ref:", dk_ref, dk_ref.shape)
print("dk_pt:", dk_pt, dk_pt.shape)
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + MIN_ERROR

if DEBUG:
print("dq:", dq, dq.shape)
print("dq_ref:", dq_ref, dq_ref.shape)
print("dq_pt:", dq_pt, dq_pt.shape)
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + MIN_ERROR



Expand Down Expand Up @@ -1217,18 +1225,18 @@ def test_flash_attn_output(
# (128, 217),
# (113, 211),
# (108, 256),
# (256, 512),
(256, 512),
# (512, 256),
(1024, 1024),
(1023, 1024),
(1024, 1023),
(2048, 2048),
# (1024, 1024),
# (1023, 1024),
# (1024, 1023),
# (2048, 2048),
# (790, 790)
],
)
# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('dropout_p', [0.17])
# @pytest.mark.parametrize("softcap", [0.0, 50.0])
@pytest.mark.parametrize("softcap", [0.0])
def test_flash_attn_varlen_output(
Expand Down Expand Up @@ -1515,7 +1523,7 @@ def test_flash_attn_varlen_output(

# Check that FlashAttention's numerical error is at most twice the numerical error
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + MIN_ERROR

if dropout_p > 0.0:
# assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item()
Expand All @@ -1528,19 +1536,19 @@ def test_flash_attn_varlen_output(
print("dv:", dv, dv.shape)
print("dv_ref:", dv_ref, dv_ref.shape)
print("dv_pt:", dv_pt, dv_pt.shape)
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() + MIN_ERROR

if DEBUG:
print("dk:", dk, dk.shape)
print("dk_ref:", dk_ref, dk_ref.shape)
print("dk_pt:", dk_pt, dk_pt.shape)
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() + MIN_ERROR

if DEBUG:
print("dq:", dq, dq.shape)
print("dq_ref:", dq_ref, dq_ref.shape)
print("dq_pt:", dq_pt, dq_pt.shape)
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() + MIN_ERROR


# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
Expand Down

0 comments on commit 34eab23

Please sign in to comment.