Merge pull request #114 from ROCm/ck_tile/fa3-hd64-bf16-atomic32
[CK_TILE] Enable FAv3 bwd for head_size=64 dtype=bf16 atomic32
poyenc authored Dec 28, 2024
2 parents e10bc4d + f547b58 commit 423b9c2
Showing 3 changed files with 2 additions and 12 deletions.
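
For reference, the configuration this commit enables (head_size=64, dtype=bf16, backward pass) can be exercised end to end from Python. A minimal sketch, assuming the fork exposes the upstream flash_attn_func interface; shapes and values are illustrative only:

    import torch
    from flash_attn import flash_attn_func

    batch, seqlen, nheads, head_dim = 2, 512, 8, 64
    q, k, v = [torch.randn(batch, seqlen, nheads, head_dim, device="cuda",
                           dtype=torch.bfloat16, requires_grad=True)
               for _ in range(3)]

    out = flash_attn_func(q, k, v, causal=True)   # forward
    out.sum().backward()                          # backward: the path this commit enables
    print(q.grad.shape, k.grad.shape, v.grad.shape)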
6 changes: 1 addition & 5 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -25,7 +25,7 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
                         false, // s_randval
                         deterministic,
                         true, // uses_ext_asm
-                        head_size != 64, // is_v3_atomic_fp32
+                        true, // is_v3_atomic_fp32
                         false, // is_v3_spec
                         1}; // how_v3_bf16_cvt 0:RTNE; 1:RTNA; 2:RTZ
 }
@@ -338,10 +338,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         dv_expanded = dv;
     }
 
-    if (head_size == 64) {
-        dq.zero_();
-    }
-
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());
 
6 changes: 0 additions & 6 deletions tests/test_flash_attn_ck_fa3.py
@@ -80,12 +80,6 @@ def pad_rearrange_dropout_mask_hts_to_bhss(S_dmask, cu_seqlens_q, seqlen_q_round
 def test_flash_attn_output(
     seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked
 ):
-    if d == 64 and causal and dtype is torch.bfloat16:
-        pytest.skip("hd=64,dtype=bf16 with causal mask not supported")
-
-    if d == 128 and causal:
-        pytest.skip("hd=128 with causal mask not supported")
-
     device = "cuda"
     # set seed
     torch.random.manual_seed(0)
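
With the skips removed, the d=64 / bf16 / causal configuration is now exercised by test_flash_attn_output like every other parametrization. A rough standalone check of that case against a plain PyTorch reference (illustrative only, not the repository's test; the printed error should be at bf16 noise level):

    import torch
    import torch.nn.functional as F
    from flash_attn import flash_attn_func

    torch.random.manual_seed(0)
    b, s, h, d = 1, 256, 4, 64
    q, k, v = [torch.randn(b, s, h, d, device="cuda", dtype=torch.bfloat16,
                           requires_grad=True) for _ in range(3)]

    out = flash_attn_func(q, k, v, causal=True)

    # Reference path: PyTorch SDPA expects (batch, heads, seqlen, head_dim) layout.
    q_r, k_r, v_r = [t.detach().clone().requires_grad_() for t in (q, k, v)]
    ref = F.scaled_dot_product_attention(
        q_r.transpose(1, 2), k_r.transpose(1, 2), v_r.transpose(1, 2),
        is_causal=True).transpose(1, 2)

    g = torch.randn_like(out)
    out.backward(g)
    ref.backward(g)
    print((q.grad - q_r.grad).abs().max())  # dq mismatch, expected to be small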
