Skip to content

Commit

Permalink
Update fmha_bwd_traits template arguments
Browse files (browse the repository at this point in the history)
  • Loading branch information
poyenc committed Dec 28, 2024
1 parent 5a93a97 commit e471acc
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions csrc/flash_attn_ck/mha_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask,
false, // s_randval
deterministic,
true, // uses_ext_asm
- head_size != 64, // is_v3_atomic_fp32
+ true, // is_v3_atomic_fp32
false, // is_v3_spec
1}; // how_v3_bf16_cvt 0:RTNE; 1:RTNA; 2:RTZ
}
Expand Down Expand Up @@ -338,10 +338,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
dv_expanded = dv;
}

- if (head_size == 64) {
- dq.zero_();
- }

auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
gen_, at::cuda::detail::getDefaultCUDAGenerator());

Expand Down

0 comments on commit e471acc

Please sign in to comment.