From 70051f44219c2294c6eb2f1738139f3d06ad64dc Mon Sep 17 00:00:00 2001
From: Alex Kranias
Date: Tue, 15 Oct 2024 10:52:40 -0500
Subject: [PATCH] test: passes output but fails attn

---
 tests/test_flash_attn.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py
index 11c68b630..840848344 100755
--- a/tests/test_flash_attn.py
+++ b/tests/test_flash_attn.py
@@ -919,11 +919,11 @@ def test_flash_attn_varlen_qkvpacked(
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
-@pytest.mark.parametrize("d", [64])
+@pytest.mark.parametrize("d", [4])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
-        (113, 203),
+        (256, 256),
         # (128, 217),
         # (113, 211),
         # (108, 256),
@@ -937,7 +937,7 @@ def test_flash_attn_varlen_qkvpacked(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)])
 # @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
-@pytest.mark.parametrize("dropout_p", [0.00])
+@pytest.mark.parametrize("dropout_p", [0.20])
 # @pytest.mark.parametrize("softcap", [0.0, 50.0])
 @pytest.mark.parametrize("softcap", [0.0])
 def test_flash_attn_output(
@@ -968,7 +968,7 @@ def test_flash_attn_output(
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
-    batch_size = 4
+    batch_size = 1
     nheads = 6 if softcap == 0.0 else 4  # softcap reference impl takes more memory
     nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2)
     assert nheads % nheads_k == 0
@@ -1032,6 +1032,7 @@ def test_flash_attn_output(
             window_size=window_size,
         )
         dropout_mask = S_dmask_converted >= 0
+        print('torch_dropout_mask', dropout_mask)
         attn_unnorm = S_dmask_converted.abs()
         if kvpacked:
             kv_rep = repeat(kv, "b s two h d -> b s two (h g) d", g=nheads // nheads_k)
@@ -1114,6 +1115,10 @@ def test_flash_attn_output(
         upcast=False,
         reorder_ops=True,
     )
+
+    print("Output diff", (out - out_ref))
+    print("Attn diff", (attn - attn_ref))
+
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
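
Note (reviewer sketch, not part of the patch): the pair dropout_mask = S_dmask_converted >= 0 and attn_unnorm = S_dmask_converted.abs() in the hunk at line 1032 relies on the test helper encoding the dropout mask in the sign of the returned scores, which is what the new torch_dropout_mask print dumps. A minimal, self-contained illustration of that convention, with all tensor names hypothetical:

import torch

# Toy stand-in for the sign-encoding convention assumed above: dropped
# entries are stored with a flipped sign, so the sign carries the dropout
# mask and abs() recovers the raw scores.
torch.manual_seed(0)
scores = torch.rand(1, 1, 4, 4) + 0.1           # strictly positive stand-in scores
keep = torch.rand_like(scores) >= 0.20          # dropout_p = 0.20, as in this patch
s_encoded = torch.where(keep, scores, -scores)  # dropped entries go negative

dropout_mask = s_encoded >= 0  # mirrors S_dmask_converted >= 0
attn_unnorm = s_encoded.abs()  # mirrors S_dmask_converted.abs()
assert torch.equal(dropout_mask, keep)
assert torch.allclose(attn_unnorm, scores)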