test: passes output but fails attn diff
alexkranias-amd committed Oct 15, 2024
1 parent 0131273 commit 70051f4
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions tests/test_flash_attn.py
@@ -919,11 +919,11 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-@pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("d", [4])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
-        (113, 203),
+        (256, 256),
         # (128, 217),
         # (113, 211),
         # (108, 256),
@@ -937,7 +937,7 @@ def test_flash_attn_varlen_qkvpacked(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 # @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-@pytest.mark.parametrize("dropout_p", [0.00])
+@pytest.mark.parametrize("dropout_p", [0.20])
 # @pytest.mark.parametrize("softcap", [0.0, 50.0])
 @pytest.mark.parametrize("softcap", [0.0])
 def test_flash_attn_output(
@@ -968,7 +968,7 @@ def test_flash_attn_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 4
+    batch_size = 1
     nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
@@ -1032,6 +1032,7 @@ def test_flash_attn_output(
             window_size=window_size,
         )
         dropout_mask = S_dmask_converted >= 0
+        print('torch_dropout_mask', dropout_mask)
         attn_unnorm = S_dmask_converted.abs()
         if kvpacked:
             kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
@@ -1114,6 +1115,10 @@ def test_flash_attn_output(
         upcast=False,
         reorder_ops=True,
     )
+
+    print("Output diff", (out - out_ref))
+    print("Attn diff", (attn - attn_ref))
+
 
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
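For context on what "passes output but fails attn diff" refers to: the printed max/mean diffs in this test typically feed a tolerance check that compares the kernel's error against the numerical error of a low-precision PyTorch reference. The snippet below is a self-contained sketch of that pattern using synthetic tensors; the names out, out_ref, and out_pt mirror the test's conventions, and the 2x tolerance factor is an assumption about the style of check used upstream, not something introduced by this commit.

import torch

# Hypothetical, self-contained sketch (not part of this commit) of the
# tolerance-style comparison the printed diffs feed into.
torch.manual_seed(0)
out_ref = torch.randn(1, 256, 6, 4)                   # stands in for the fp32 reference
out_pt = out_ref + 1e-3 * torch.randn_like(out_ref)   # reference run in the kernel's precision
out = out_ref + 1e-3 * torch.randn_like(out_ref)      # kernel output under test

print(f"Output max diff: {(out - out_ref).abs().max().item()}")
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")

# "Passes output" means a check of this shape holds for out; "fails attn diff"
# means the analogous check on the attention probabilities does not.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()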
