
Commit daa7532: save

micmelesse committed Jan 10, 2025
1 parent d080d33
Showing 3 changed files with 25 additions and 161 deletions.
141 changes: 0 additions & 141 deletions flash_attn/flash_attn_triton_amd/fp8.py

This file was deleted.

7 changes: 6 additions & 1 deletion flash_attn/flash_attn_triton_amd/fwd_prefill.py
@@ -655,7 +655,12 @@ def attention_prefill_forward_triton_impl(
q_max = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))
k_max = torch.maximum(k_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))
v_max = torch.maximum(v_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))

# add unsqueeze operations to make q_max broadcastable
q_max = q_max.unsqueeze(seqlen_loc).unsqueeze(-1)
k_max = k_max.unsqueeze(seqlen_loc).unsqueeze(-1)
v_max = v_max.unsqueeze(seqlen_loc).unsqueeze(-1)

# Scale values to fp8 range
q = (q_float32 * type_max / q_max).to(q.dtype)
k = (k_float32 * type_max / k_max).to(k.dtype)
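For context, here is a minimal standalone sketch of the broadcasting fix in the hunk above. It assumes a (batch, heads, seqlen, head_dim) layout so that seqlen_loc == 2, and a PyTorch build recent enough to expose float8 dtypes; the eps value and the fp8 dtype (torch.float8_e4m3fnuz) are illustrative assumptions, not values taken from the repository:

```python
import torch

# Assumptions for illustration: bshd-style layout (batch, heads, seqlen,
# head_dim) so seqlen_loc == 2; eps and the fp8 dtype are placeholders.
eps = 1e-9
type_max = torch.finfo(torch.float8_e4m3fnuz).max  # largest representable fp8 value
seqlen_loc = 2

q = torch.randn(2, 4, 128, 64, dtype=torch.float16)
q_float32 = q.to(torch.float32)

# Per-(batch, head) absolute maximum over seqlen and head_dim; shape (2, 4).
q_max = torch.maximum(q_float32.abs().amax(dim=(seqlen_loc, 3)), torch.tensor(eps))

# Without the unsqueezes, dividing (2, 4, 128, 64) by (2, 4) fails, because
# broadcasting aligns trailing dimensions. Restoring singleton dims at the
# reduced positions gives (2, 4, 1, 1), which broadcasts cleanly.
q_max = q_max.unsqueeze(seqlen_loc).unsqueeze(-1)

q_fp8 = (q_float32 * type_max / q_max).to(torch.float8_e4m3fnuz)
print(q_fp8.shape)  # torch.Size([2, 4, 128, 64])
```

An equivalent one-step alternative is amax(dim=(seqlen_loc, 3), keepdim=True), which keeps the reduced dimensions as singletons and makes the separate unsqueeze calls unnecessary.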
38 changes: 19 additions & 19 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -476,26 +476,26 @@ def test_op_prefill_fwd_impl(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, dropou
 @pytest.mark.parametrize(
     "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
     [
-        # (1, 1, 1, 1, 1, 1),
+        (1, 1, 1, 1, 1, 1),
         (1, 1, 1, 2, 4, 16),
-        # (1, 2, 2, 2, 4, 16),
-        # (1, 4, 1, 2, 4, 16),
-        # (1, 4, 2, 2, 4, 16),
-        # # (1, 1, 1, 4, 2, 16),
-        # (1, 1, 1, 4, 4, 16),
-        # (1, 2, 2, 4, 4, 16),
-        # (2, 1, 1, 4, 4, 16),
-        # (2, 2, 2, 4, 4, 16),
-        # (1, 1, 1, 128, 64, 16),
-        # (2, 2, 2, 2, 128, 1),
-        # (2, 3, 3, 2, 128, 16),
-        # (3, 2, 2, 256, 512, 16),
-        # (3, 3, 3, 128, 128, 64),
-        # (2, 4, 4, 1024, 1024, 64),
-        # (4, 6, 6, 108, 256, 224),
-        # (4, 8, 8, 2048, 2048, 128),
-        # (4, 16, 16, 4096, 4096, 64),
-        # (2, 4, 4, 8192, 8192, 32),
+        (1, 2, 2, 2, 4, 16),
+        (1, 4, 1, 2, 4, 16),
+        (1, 4, 2, 2, 4, 16),
+        (1, 1, 1, 4, 2, 16),
+        (1, 1, 1, 4, 4, 16),
+        (1, 2, 2, 4, 4, 16),
+        (2, 1, 1, 4, 4, 16),
+        (2, 2, 2, 4, 4, 16),
+        (1, 1, 1, 128, 64, 16),
+        (2, 2, 2, 2, 128, 1),
+        (2, 3, 3, 2, 128, 16),
+        (3, 2, 2, 256, 512, 16),
+        (3, 3, 3, 128, 128, 64),
+        (2, 4, 4, 1024, 1024, 64),
+        (4, 6, 6, 108, 256, 224),
+        (4, 8, 8, 2048, 2048, 128),
+        (4, 16, 16, 4096, 4096, 64),
+        (2, 4, 4, 8192, 8192, 32),
         # # fa configs
         # (4, 6, 1, 113, 203, 256),
         # (4, 6, 1, 128, 217, 256),
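For orientation, a minimal sketch of how pytest.mark.parametrize feeds the shape tuples re-enabled above into a test. The test body and parameter semantics are assumptions following the usual flash-attention conventions, not the repository's own test:

```python
import pytest
import torch

# Placeholder sketch; parameter meanings are assumed, not stated in the diff:
# Z = batch, HQ = query heads, HK = key/value heads, N_CTX_Q/N_CTX_K =
# query/key sequence lengths, D_HEAD = head dimension.
@pytest.mark.parametrize(
    "Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD",
    [
        (1, 1, 1, 2, 4, 16),
        (1, 4, 2, 2, 4, 16),  # grouped-query case: more query heads than kv heads
    ],
)
def test_shapes_sketch(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD):
    q = torch.randn(Z, HQ, N_CTX_Q, D_HEAD)
    k = torch.randn(Z, HK, N_CTX_K, D_HEAD)
    v = torch.randn(Z, HK, N_CTX_K, D_HEAD)
    # Every re-enabled tuple above satisfies this grouped-query constraint.
    assert HQ % HK == 0
    assert q.shape[-1] == k.shape[-1] == v.shape[-1] == D_HEAD
```

Each uncommented tuple becomes one test case, so this hunk re-enables 19 shape combinations, from the 1×1×1×1×1×1 smoke test up to 8192-token sequences.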
