Skip to content

Commit

Permalink
fix mismatches
Browse files Browse the repository at this point in the history
  • Loading branch information
micmelesse committed Dec 9, 2024
1 parent 94b2da3 commit 7434112
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 31 deletions.
3 changes: 0 additions & 3 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import triton.language as tl
from .utils import DEBUG, DROPOUT_USE_PYTORCH, DROPOUT_DUMP, AUTOTUNE, get_shape_from_layout, get_strides_from_layout, is_cdna, is_rdna, write_dropout_mask, create_dropout_mask, create_scale_tensors


# NOTE: triton fails to import tl.constexprs so create them here for the file
tl_DROPOUT_USE_PYTORCH: tl.constexpr = DROPOUT_USE_PYTORCH
tl_DROPOUT_DUMP: tl.constexpr = DROPOUT_DUMP
Expand Down Expand Up @@ -576,8 +575,6 @@ def attention_prefill_forward_triton_impl(
torch.float8_e5m2fnuz,
}
is_fp8 = q.dtype in fp8_types
is_fp8 = False


# if qkv are fp8, then find scaling factor for quantization
q_scale, k_scale, v_scale = create_scale_tensors(q, k, v, SCALE_PER_HEAD=True, layout=layout) # TODO: if SCALE_PER_HEAD: within the kernel itself just compute qkv_scale = tl.max(q or k or v)
Expand Down
58 changes: 30 additions & 28 deletions flash_attn/flash_attn_triton_amd/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,34 +353,34 @@ def test_op_bwd(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, torch_sdpa_test, use_ali
(1, 2, 2, 4, 4, 16),
(2, 1, 1, 4, 4, 16),
(2, 2, 2, 4, 4, 16),
(1, 1, 1, 128, 64, 16),
(2, 2, 2, 2, 128, 1),
(2, 3, 3, 2, 128, 16),
(3, 2, 2, 256, 512, 16),
(3, 3, 3, 128, 128, 64),
(2, 4, 4, 1024, 1024, 64),
(4, 6, 6, 108, 256, 224),
(4, 8, 8, 2048, 2048, 128),
(4, 16, 16, 4096, 4096, 64),
(2, 4, 4, 8192, 8192, 32),
# fa configs
(4, 6, 1, 113, 203, 256),
(4, 6, 1, 128, 217, 256),
(4, 6, 2, 113, 211, 128),
(4, 6, 2, 108, 256, 128),
(4, 6, 1, 256, 512, 64),
(4, 6, 1, 512, 256, 64),
(4, 6, 2, 1024, 1024, 32),
(4, 6, 2, 1023, 1024, 32),
(4, 6, 6, 1024, 1023, 32),
(4, 6, 6, 2048, 2048, 32),
# (1, 1, 1, 128, 64, 16),
# (2, 2, 2, 2, 128, 1),
# (2, 3, 3, 2, 128, 16),
# (3, 2, 2, 256, 512, 16),
# (3, 3, 3, 128, 128, 64),
# (2, 4, 4, 1024, 1024, 64),
# (4, 6, 6, 108, 256, 224),
# (4, 8, 8, 2048, 2048, 128),
# (4, 16, 16, 4096, 4096, 64),
# (2, 4, 4, 8192, 8192, 32),
# # fa configs
# (4, 6, 1, 113, 203, 256),
# (4, 6, 1, 128, 217, 256),
# (4, 6, 2, 113, 211, 128),
# (4, 6, 2, 108, 256, 128),
# (4, 6, 1, 256, 512, 64),
# (4, 6, 1, 512, 256, 64),
# (4, 6, 2, 1024, 1024, 32),
# (4, 6, 2, 1023, 1024, 32),
# (4, 6, 6, 1024, 1023, 32),
# (4, 6, 6, 2048, 2048, 32),
],
)
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('causal', [False])
@pytest.mark.parametrize('dropout_p', [0.0])
@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"])
@pytest.mark.parametrize('use_exp2', [True, False]) # works when use_exp2 is false
@pytest.mark.parametrize('dtype', [torch.float8_e4m3fnuz, torch.float16])
@pytest.mark.parametrize('layout', ["bhsd"])
@pytest.mark.parametrize('use_exp2', [False]) # works when use_exp2 is false
@pytest.mark.parametrize('dtype', [torch.float8_e4m3fnuz])
@pytest.mark.parametrize('DEBUG_INPUT', [False]) # NOTE: debug input can overflow when the tensors are large. Just use to figure out issues
def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropout_p, layout, use_exp2, dtype, DEBUG_INPUT):
torch.manual_seed(0)
Expand Down Expand Up @@ -436,9 +436,7 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
metadata.use_exp2)

output_ref, softmax_lse_ref, sd_mask_ref = attention_forward_pytorch_ref_impl(
q_fp32,
k_fp32,
v_fp32,
q_fp32, k_fp32, v_fp32,
metadata.sm_scale,
causal,
layout,
Expand Down Expand Up @@ -468,6 +466,10 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
print("softmax_lse_ref:", softmax_lse_ref, softmax_lse_ref.shape)
torch.testing.assert_close(softmax_lse_triton, softmax_lse_ref, atol=ATOL, rtol=RTOL)

# if triton is fp8, cast to fp16 in order to compare with ref
if output_triton.dtype in {torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}:
output_triton = output_triton.to(torch.float16)

if DEBUG:
print("output_triton:", output_triton, output_triton.shape)
print("output_ref:", output_ref, output_ref.shape)
Expand Down

0 comments on commit 7434112

Please sign in to comment.