backwards for softcapping (Dao-AILab#1033)
* check in the two ways of approaching backwards for softcapping, both functional

* prepare the softcap switch for backwards

* temporary

* cleanup to the way Tri prefers

* calculate dtanh when copying from scores -> dtanh Tensor

* no ternary operators allowed for constexpr, so just use some hack found online

* fix maybe_dtanh, restore some files

* restore another file

* move calculate_dtanh to utils and colocate with apply_softcap

* cleanup

* maybe last cleanup

* save for another pr

* remove a stray line

* fix spacing

* fix an issue, and make test_flash_attn.py ready to test softcapping backwards
lucidrains authored Jul 22, 2024
1 parent ef3e358 commit 5f1ae4a
Showing 5 changed files with 69 additions and 33 deletions.
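
Background (not part of the commit message, and only a sketch of the math the new code implements, using the reference semantics visible in the tests/test_flash_attn.py hunk below): with a softcap value $c$, the forward pass caps the attention scores as

$$\tilde{S} = c\,\tanh\!\left(\frac{S}{c}\right), \qquad \frac{\partial \tilde{S}}{\partial S} = 1 - \tanh^{2}\!\left(\frac{S}{c}\right),$$

so the backward pass only needs one extra elementwise factor applied to dS. In the kernel, apply_softcap stores $t = \tanh(\texttt{params.softcap}\cdot x)$ in scores, and calculate_dtanh recovers the derivative from those already-computed tanh values as $(1 - t^{2})\cdot\texttt{params.softcap}$, which is the factor multiplied into scaled_ds below; how the host folds the softmax scale into params.softcap is outside this diff.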
32 changes: 28 additions & 4 deletions csrc/flash_attn/src/flash_bwd_kernel.h
@@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom<Args...> const& copy_atom,

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {

using Element = typename Kernel_traits::Element;
@@ -471,10 +471,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp,
smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV);

+if constexpr (Is_softcap) {
+flash::apply_softcap(acc_s, params.softcap);
+}

// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread(32, 0)) { print(scores); }

+// Softcapping - calculating dTanh and scaling dS later with it
+auto dtanh = ([&]{
+if constexpr (Is_softcap) {
+Tensor _dtanh = make_tensor_like(scores);
+flash::calculate_dtanh(scores, _dtanh, params.softcap);
+return _dtanh;
+}
+else {
+return nullptr;
+}
+}());

// Alibi
if (Has_alibi) {
alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16);
@@ -574,7 +591,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int mi = 0; mi < size<0>(dS); ++mi) {
#pragma unroll
for (int ni = 0; ni < size<1>(dS); ++ni) {
-dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+
+float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi));
+
+if constexpr (Is_softcap) {
+scaled_ds *= dtanh(mi, ni);
+}
+
+dS(mi, ni) = scaled_ds;
}
}
// if (cute::thread0()) { print(dS); }
@@ -807,7 +831,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {

////////////////////////////////////////////////////////////////////////////////////////////////////

-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

// The block index for the batch.
@@ -817,7 +841,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {

// If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer.
for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) {
-compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
}

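The immediately-invoked lambda used for dtanh above is the "hack" the commit message alludes to: a ternary would have to produce a single common type, while if constexpr inside a lambda can return a real tensor in one branch and a nullptr placeholder in the other, and the placeholder is never touched when Is_softcap is false. A minimal standalone C++ sketch of that pattern (illustrative only; accumulate, UseBuffer, and the buffer contents are invented for this example, this is not kernel code):

```cpp
// Standalone illustration of the if-constexpr / immediately-invoked-lambda trick
// used for `dtanh` above. All names here are invented for the example.
#include <array>
#include <cstddef>
#include <cstdio>

template <bool UseBuffer>
float accumulate(const std::array<float, 4>& xs) {
    // A ternary would force both branches to one common type; the lambda lets each
    // branch return its own type: a real array, or a nullptr placeholder.
    auto buffer = ([&] {
        if constexpr (UseBuffer) {
            std::array<float, 4> b{};
            for (std::size_t i = 0; i < xs.size(); ++i) { b[i] = 2.f * xs[i]; }
            return b;
        } else {
            return nullptr;  // never dereferenced: the branch below is discarded
        }
    }());

    float sum = 0.f;
    for (std::size_t i = 0; i < xs.size(); ++i) {
        sum += xs[i];
        if constexpr (UseBuffer) { sum += buffer[i]; }  // compiled only when buffer is real
    }
    return sum;
}

int main() {
    std::array<float, 4> xs{1.f, 2.f, 3.f, 4.f};
    std::printf("%g %g\n", accumulate<true>(xs), accumulate<false>(xs));  // prints: 30 10
    return 0;
}
```

accumulate<false> compiles even though nullptr has no operator[], because the discarded if constexpr branch is never instantiated; the kernel relies on the same behavior when it multiplies by dtanh(mi, ni) only under Is_softcap.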
28 changes: 15 additions & 13 deletions csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo
#endif
}

-DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) {
+DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) {
#if defined(ARCH_SUPPORTS_FLASH)
static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false
-flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K>(params);
+flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_local, Has_alibi, Is_even_MN, Is_even_K, Is_softcap>(params);
#else
FLASH_UNSUPPORTED_ARCH
#endif
@@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream)
EVENK_SWITCH(is_even_K, IsEvenKConst, [&] {
LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] {
ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] {
-// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
-// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
-// If Is_local, set Is_causal to false
-auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst>;
-// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
-if (smem_size_dq_dk_dv >= 48 * 1024) {
-C10_CUDA_CHECK(cudaFuncSetAttribute(
-kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
-}
-kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
-C10_CUDA_KERNEL_LAUNCH_CHECK();
+SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] {
+// If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates.
+// If head dim > 128, set IsEvenMNConst to false to reduce number of templates
+// If Is_local, set Is_causal to false
+auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, Is_causal, Is_local && !Is_causal, Has_alibi, IsEvenMNConst && IsEvenKConst && !Is_local && Kernel_traits::kHeadDim <= 128, IsEvenKConst, Is_softcap>;
+// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, false, Is_causal, false, false, true, true>;
+if (smem_size_dq_dk_dv >= 48 * 1024) {
+C10_CUDA_CHECK(cudaFuncSetAttribute(
+kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
+}
+kernel<<<grid_n, Kernel_traits::kNThreads, smem_size_dq_dk_dv, stream>>>(params);
+C10_CUDA_KERNEL_LAUNCH_CHECK();
+});
});
});
});
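SOFTCAP_SWITCH itself is not part of this diff; it presumably follows the same BOOL_SWITCH pattern as the other *_SWITCH macros in csrc/flash_attn/src/static_switch.h, converting the runtime check params.softcap > 0.0 into the compile-time Is_softcap template parameter. A sketch of that dispatch pattern under that assumption (the macro body and run_kernel_stub are illustrative, not copied from the repository):

```cpp
// Illustrative BOOL_SWITCH-style dispatch in the spirit of static_switch.h;
// the real SOFTCAP_SWITCH may differ (e.g. a compile-out path).
#include <cstdio>

#define SOFTCAP_SWITCH(COND, CONST_NAME, ...)   \
  [&] {                                         \
    if (COND) {                                 \
      constexpr static bool CONST_NAME = true;  \
      return __VA_ARGS__();                     \
    } else {                                    \
      constexpr static bool CONST_NAME = false; \
      return __VA_ARGS__();                     \
    }                                           \
  }()

template <bool Is_softcap>
void run_kernel_stub() {  // stand-in for the actual templated kernel launch
    std::printf("launching with Is_softcap = %s\n", Is_softcap ? "true" : "false");
}

int main() {
    float softcap = 30.f;  // hypothetical runtime value from the params struct
    SOFTCAP_SWITCH(softcap > 0.f, Is_softcap, [&] {
        run_kernel_stub<Is_softcap>();  // runtime bool is now a template parameter
    });
    return 0;
}
```

Each extra boolean switch doubles the number of kernel instantiations, which is why the surrounding launch code collapses other flags (e.g. IsEvenMNConst) where it can, as its comments note.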
16 changes: 4 additions & 12 deletions csrc/flash_attn/src/flash_fwd_kernel.h
@@ -22,14 +22,6 @@ namespace flash {

using namespace cute;

-template <typename Engine, typename Layout>
-__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
-#pragma unroll
-for (int i = 0; i < size(tensor); ++i) {
-tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
-}
-}
-
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename ElementAccum, typename Params, int kBlockM, bool Is_even_MN>
@@ -327,7 +319,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
-apply_softcap(acc_s, params.softcap);
+flash::apply_softcap(acc_s, params.softcap);
}

mask.template apply_mask<Is_causal, Is_even_MN>(
@@ -393,7 +385,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
-apply_softcap(acc_s, params.softcap);
+flash::apply_softcap(acc_s, params.softcap);
}

flash::cp_async_wait<0>();
@@ -887,7 +879,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
);
// if (cute::thread0()) { print(acc_s); }
if constexpr (Is_softcap){
-apply_softcap(acc_s, params.softcap);
+flash::apply_softcap(acc_s, params.softcap);
}


@@ -962,7 +954,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
smem_thr_copy_Q, smem_thr_copy_K
);
if constexpr (Is_softcap){
-apply_softcap(acc_s, params.softcap);
+flash::apply_softcap(acc_s, params.softcap);
}

flash::cp_async_wait<0>();
18 changes: 18 additions & 0 deletions csrc/flash_attn/src/utils.h
@@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor<Engine0, Layout0> const &S

////////////////////////////////////////////////////////////////////////////////////////////////////

+template <typename Engine, typename Layout>
+__forceinline__ __device__ void apply_softcap(Tensor<Engine, Layout> &tensor, const float softcap){
+#pragma unroll
+for (int i = 0; i < size(tensor); ++i) {
+tensor(i) = cutlass::fast_tanh(tensor(i) * softcap);
+}
+}
+
+template <typename Engine0, typename Layout0, typename Engine1, typename Layout1>
+__forceinline__ __device__ void calculate_dtanh(Tensor<Engine0, Layout0> &src_tensor, Tensor<Engine1, Layout1> &dst_tensor, const float softcap){
+#pragma unroll
+for (int i = 0; i < size(src_tensor); ++i) {
+dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap;
+}
+}

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
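A quick host-side sanity check of the new helper (plain C++ with std::vector instead of cute::Tensor, purely illustrative): with t = tanh(c·x) already produced by apply_softcap, the factor (1 − t·t)·c returned by calculate_dtanh equals d/dx tanh(c·x), which the snippet compares against a central finite difference.

```cpp
// Host-side sanity check (plain C++, not CUDA/cute) that calculate_dtanh's
// (1 - t*t) * softcap matches the derivative of tanh(softcap * x).
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

void apply_softcap(std::vector<float>& t, float softcap) {
    for (float& x : t) { x = std::tanh(x * softcap); }
}

void calculate_dtanh(const std::vector<float>& src, std::vector<float>& dst, float softcap) {
    dst.resize(src.size());
    for (std::size_t i = 0; i < src.size(); ++i) {
        dst[i] = (1.f - src[i] * src[i]) * softcap;  // d/dx tanh(softcap * x)
    }
}

int main() {
    const float softcap = 0.5f, eps = 1e-3f;
    std::vector<float> x = {-2.f, -0.5f, 0.f, 0.7f, 3.f};

    std::vector<float> t = x;  // after apply_softcap, t plays the role of `scores`
    apply_softcap(t, softcap);

    std::vector<float> dtanh;
    calculate_dtanh(t, dtanh, softcap);

    for (std::size_t i = 0; i < x.size(); ++i) {
        // central finite difference of f(x) = tanh(softcap * x)
        float fd = (std::tanh(softcap * (x[i] + eps)) - std::tanh(softcap * (x[i] - eps))) / (2.f * eps);
        std::printf("x=% .2f  analytic=% .6f  finite diff=% .6f\n", x[i], dtanh[i], fd);
    }
    return 0;
}
```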
8 changes: 4 additions & 4 deletions tests/test_flash_attn.py
@@ -261,9 +261,9 @@ def attention_ref(
else:
scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d))
if softcap > 0:
-scores /= softcap
+scores = scores / softcap
scores = scores.tanh()
-scores *= softcap
+scores = scores * softcap
if key_padding_mask is not None:
scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
if window_size[0] >= 0 or window_size[1] >= 0:
@@ -1122,7 +1122,7 @@ def test_flash_attn_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

-if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1382,7 +1382,7 @@ def test_flash_attn_varlen_output(
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
-if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0:
+if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)):
if kvpacked:
(
dq_unpad,
