diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 020c1371a..58a79a6ca 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -43,9 +43,9 @@ jobs: # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.2', '2.3.0', '2.4.0.dev20240407'] - cuda-version: ['11.8.0', '12.2.2'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + torch-version: ['2.0.1', '2.1.2', '2.2.2', '2.3.1', '2.4.0.dev20240514'] + cuda-version: ['11.8.0', '12.3.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. # Without this we get import error (undefined symbol: _ZN3c105ErrorC2ENS_14SourceLocationESs) @@ -54,35 +54,13 @@ jobs: exclude: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # Pytorch < 2.2 does not support Python 3.12 - - torch-version: '1.12.1' - python-version: '3.12' - - torch-version: '1.13.1' - python-version: '3.12' - torch-version: '2.0.1' python-version: '3.12' - torch-version: '2.1.2' python-version: '3.12' - # Pytorch <= 1.12 does not support Python 3.11 - - torch-version: '1.12.1' - python-version: '3.11' - # Pytorch >= 2.0 only supports Python >= 3.8 - - torch-version: '2.0.1' - python-version: '3.7' - - torch-version: '2.1.2' - python-version: '3.7' - - torch-version: '2.2.2' - python-version: '3.7' - - torch-version: '2.3.0' - python-version: '3.7' - - torch-version: '2.4.0.dev20240407' - python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - - torch-version: '1.12.1' - cuda-version: '12.2.2' - - torch-version: '1.13.1' - cuda-version: '12.2.2' - torch-version: '2.0.1' - cuda-version: '12.2.2' + cuda-version: '12.3.2' steps: - name: Checkout @@ -97,7 +75,6 @@ jobs: run: | echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV - echo "MATRIX_PYTHON_VERSION=$(echo ${{ matrix.python-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV - name: Free up disk space if: ${{ runner.os == 'Linux' }} @@ -141,8 +118,8 @@ jobs: # see https://github.com/pytorch/pytorch/blob/main/RELEASE.md#release-compatibility-matrix # This code is ugly, maybe there's a better way to do this. 
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ - minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \ - maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \ + minv = {'2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118, '2.4': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121, '2.4': 121}[env['MATRIX_TORCH_VERSION']]; \ print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \ ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then @@ -168,8 +145,8 @@ jobs: export PATH=/usr/local/nvidia/bin:/usr/local/nvidia/lib64:$PATH export LD_LIBRARY_PATH=/usr/local/nvidia/lib64:/usr/local/cuda/lib64:$LD_LIBRARY_PATH # Limit MAX_JOBS otherwise the github runner goes OOM - # CUDA 11.8 can compile with 2 jobs, but CUDA 12.2 goes OOM - MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "122" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist + # CUDA 11.8 can compile with 2 jobs, but CUDA 12.3 goes OOM + MAX_JOBS=$([ "$MATRIX_CUDA_VERSION" == "123" ] && echo 1 || echo 2) FLASH_ATTENTION_FORCE_BUILD="TRUE" FLASH_ATTENTION_FORCE_CXX11_ABI=${{ matrix.cxx11_abi}} python setup.py bdist_wheel --dist-dir=dist tmpname=cu${MATRIX_CUDA_VERSION}torch${MATRIX_TORCH_VERSION}cxx11abi${{ matrix.cxx11_abi }} wheel_name=$(ls dist/*whl | xargs -n 1 basename | sed "s/-/+$tmpname-/2") ls dist/*whl |xargs -I {} mv {} dist/${wheel_name} diff --git a/README.md b/README.md index 67972e650..b6efd8ee3 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,42 @@ contains a partial list of places where FlashAttention is being used. FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). Please cite and credit FlashAttention if you use it. + +## FlashAttention-3 beta release +FlashAttention-3 is optimized for Hopper GPUs (e.g. H100). + +Blogpost: https://tridao.me/blog/2024/flash3/ + +Paper: https://tridao.me/publications/flash3/flash3.pdf + +![FlashAttention-3 speedup on H100 80GB SXM5 with FP16](assets/flash3_fp16_fwd.png) + +This is a beta release for testing / benchmarking before we integrate that with +the rest of the repo. + +Currently released: +- FP16 forward and backward + +Coming soon in the next couple of days / next week: +- BF16 +- Variable length (FP16, BF16) +- FP8 forward. + +Requirements: H100 / H800 GPU, CUDA >= 12.3. + +To install: +```sh +cd hopper +python setup.py install +``` +To run the test: +```sh +export PYTHONPATH=$PWD +pytest -q -s test_flash_attn.py +``` + + + ## Installation and features Requirements: @@ -314,6 +350,11 @@ Implement deterministic backward pass. Thanks to engineers from [Meituan](www.me Support paged KV cache (i.e., [PagedAttention](https://arxiv.org/abs/2309.06180)). Thanks to @beginlner for this contribution. +### 2.6: Softcapping. + +Support attention with softcapping, as used in Gemma-2 and Grok models. +Thanks to @Narsil for this contribution. + ## Performance We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). 
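As a point of reference for the softcapping support added in 2.6 above, here is a minimal PyTorch-style sketch of tanh softcapping applied to the attention logits, assuming the Gemma-2-style formulation; the function and argument names are illustrative only, not the library's API.

```python
import torch

def softcapped_attention_ref(q, k, v, softcap, softmax_scale=None):
    # q, k, v: (batch, seqlen, nheads, headdim); illustrative reference, not kernel code.
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    scores = torch.einsum("bthd,bshd->bhts", q.float(), k.float()) * softmax_scale
    # Softcapping squashes the logits into (-softcap, softcap) before the softmax.
    scores = softcap * torch.tanh(scores / softcap)
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhts,bshd->bthd", attn, v.float()).to(q.dtype)
```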
diff --git a/assets/flash3_fp16_fwd.png b/assets/flash3_fp16_fwd.png new file mode 100644 index 000000000..403d13944 Binary files /dev/null and b/assets/flash3_fp16_fwd.png differ diff --git a/csrc/cutlass b/csrc/cutlass index 7d49e6c7e..756c351b4 160000 --- a/csrc/cutlass +++ b/csrc/cutlass @@ -1 +1 @@ -Subproject commit 7d49e6c7e2f8896c47f586706e67e1fb215529dc +Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b diff --git a/csrc/flash_attn/flash_api.cpp b/csrc/flash_attn/flash_api.cpp index ac753af2c..2b79b9709 100644 --- a/csrc/flash_attn/flash_api.cpp +++ b/csrc/flash_attn/flash_api.cpp @@ -43,7 +43,9 @@ void set_params_fprop(Flash_fwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, - bool seqlenq_ngroups_swapped=false) { + const float softcap, + bool seqlenq_ngroups_swapped=false, + const bool unpadded_lse=false) { // Reset the parameters params = {}; @@ -99,8 +101,19 @@ void set_params_fprop(Flash_fwd_params ¶ms, params.d_rounded = d_rounded; // Set the different scale values. - params.scale_softmax = softmax_scale; - params.scale_softmax_log2 = softmax_scale * M_LOG2E; + #ifdef FLASHATTENTION_DISABLE_SOFTCAP + TORCH_CHECK(softcap <= 0.0, "This flash attention build does not support softcap."); + #endif + if (softcap > 0.0) { + params.softcap = softmax_scale / softcap; + params.scale_softmax = softcap; + params.scale_softmax_log2 = softcap * M_LOG2E; + } else{ + // Remove potential NaN + params.softcap = 0.0; + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + } // Set this to probability of keeping an element to simplify things. params.p_dropout = 1.f - p_dropout; @@ -135,6 +148,9 @@ void set_params_fprop(Flash_fwd_params ¶ms, #ifdef FLASHATTENTION_DISABLE_UNEVEN_K TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); #endif + + params.unpadded_lse = unpadded_lse; + params.seqlenq_ngroups_swapped = seqlenq_ngroups_swapped; } void set_params_dgrad(Flash_bwd_params ¶ms, @@ -168,7 +184,9 @@ void set_params_dgrad(Flash_bwd_params ¶ms, float softmax_scale, int window_size_left, int window_size_right, - bool deterministic) { + const float softcap, + bool deterministic, + const bool unpadded_lse) { set_params_fprop(params, b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, @@ -181,7 +199,10 @@ void set_params_dgrad(Flash_bwd_params ¶ms, p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + softcap, + false, // seqlenq_ngroups_swapped + unpadded_lse); // Set the pointers and strides. 
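The scale bookkeeping in `set_params_fprop` above splits the softcap into a factor applied inside the tanh and a factor applied afterwards, so that the kernel effectively evaluates softcap * tanh(scores_raw * softmax_scale / softcap). A small sketch of that decomposition (the helper name `fwd_scales` is hypothetical):

```python
import math

def fwd_scales(softmax_scale: float, softcap: float):
    # Mirrors the C++ logic: with softcapping, the raw Q.K^T scores are first
    # multiplied by softmax_scale / softcap, passed through tanh, and the softmax
    # then uses softcap as its scale.
    if softcap > 0.0:
        softcap_factor = softmax_scale / softcap   # params.softcap
        scale_softmax = softcap                    # params.scale_softmax
    else:
        softcap_factor = 0.0                       # keep the unused field from holding NaN
        scale_softmax = softmax_scale
    return softcap_factor, scale_softmax, scale_softmax * math.log2(math.e)
```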
params.do_ptr = dout.data_ptr(); @@ -217,11 +238,13 @@ void set_params_dgrad(Flash_bwd_params ¶ms, void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { FP16_SWITCH(!params.is_bf16, [&] { HEADDIM_SWITCH(params.d, [&] { - if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 - run_mha_fwd_(params, stream); - } else { - run_mha_fwd_splitkv_dispatch(params, stream); - } + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + if (params.num_splits <= 1 && !force_split_kernel) { // If we don't set it num_splits == 0 + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_splitkv_dispatch(params, stream); + } + }); }); }); } @@ -325,6 +348,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool return_softmax, c10::optional gen_) { @@ -359,10 +383,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size const int head_size_og = sizes[3]; const int seqlen_k = k.size(1); const int num_heads_k = k.size(2); - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + if (window_size_left >= seqlen_k) { window_size_left = -1; } if (window_size_right >= seqlen_k) { window_size_right = -1; } @@ -446,7 +472,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size p_dropout, softmax_scale, window_size_left, - window_size_right); + window_size_right, + softcap + ); set_params_splitkv(params, batch_size, num_heads, @@ -504,6 +532,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const at::Tensor &cu_seqlens_q, // b+1 const at::Tensor &cu_seqlens_k, // b+1 c10::optional &seqused_k, // b. If given, only this many elements of each batch element's keys are used. + c10::optional &leftpad_k_, // batch_size c10::optional &block_table_, // batch_size x max_num_blocks_per_seq c10::optional &alibi_slopes_, // num_heads or b x num_heads int max_seqlen_q, @@ -514,6 +543,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool return_softmax, c10::optional gen_) { @@ -562,6 +592,8 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s const int head_size_og = sizes[2]; const int num_heads_k = paged_KV ? k.size(2) : k.size(1); + if (softcap > 0.f) { TORCH_CHECK(p_dropout == 0.f, "Softcapping does not support dropout for now"); } + const int max_num_blocks_per_seq = !paged_KV ? 0 : block_table.size(1); const int num_blocks = !paged_KV ? 0 : k.size(0); const int page_block_size = !paged_KV ? 
1 : k.size(1); @@ -630,7 +662,6 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); CHECK_DEVICE(out); TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); - CHECK_SHAPE(out, total_q, num_heads, head_size_og); CHECK_SHAPE(out, sizes[0], sizes[1], head_size_og); if (seqlenq_ngroups_swapped) { out = out.reshape({batch_size, num_heads_k, ngroups, head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_k, head_size_og}); @@ -651,8 +682,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - - auto softmax_lse = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); at::Tensor p; // Only return softmax if there's dropout to reduce compilation time if (return_softmax) { @@ -683,7 +713,10 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s softmax_scale, window_size_left, window_size_right, - seqlenq_ngroups_swapped); + softcap, + seqlenq_ngroups_swapped, + /*unpadded_lse*/true); + params.total_q = total_q; if (paged_KV) { params.block_table = block_table.data_ptr(); @@ -699,6 +732,16 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s head_size_rounded, p_dropout, /*num_splits*/0, dprops, opts); } + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } + // number of times random will be generated per thread, to offset philox counter in thc random // state // We use a custom RNG that increases the offset by batch_size * nheads * 32. @@ -739,7 +782,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s out = out.reshape(size_before).transpose(1, 2).reshape(size_after); out_padded = out_padded.reshape(size_before).transpose(1, 2).reshape(size_after); q_padded = q_padded.reshape(size_before).transpose(1, 2).reshape(size_after); - softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * max_seqlen_q, 1}); + softmax_lse = softmax_lse.reshape({num_heads * max_seqlen_q, batch_size}); } return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p, rng_state}; @@ -769,6 +812,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si const bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool deterministic, c10::optional gen_, c10::optional &rng_state) { @@ -933,7 +977,9 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si softmax_scale, window_size_left, window_size_right, - deterministic); + softcap, + deterministic, + /*unpadded_lse*/false); params.dq_accum_split_stride = !deterministic ? 
0 : dq_accum.stride(0); auto launch = &run_mha_bwd; @@ -986,7 +1032,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const at::Tensor &k, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &v, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i const at::Tensor &out, // total_q x num_heads x head_size - const at::Tensor &softmax_lse, // b x h x s softmax logsumexp + const at::Tensor &softmax_lse, // h x total_q, softmax logsumexp c10::optional &dq_, // total_q x num_heads x head_size, total_q := \sum_{i=0}^{b} s_i c10::optional &dk_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i c10::optional &dv_, // total_k x num_heads_k x head_size, total_k := \sum_{i=0}^{b} s_i @@ -1001,6 +1047,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size const bool is_causal, int window_size_left, int window_size_right, + const float softcap, const bool deterministic, c10::optional gen_, c10::optional &rng_state) { @@ -1126,7 +1173,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size at::cuda::CUDAGuard device_guard{(char)q.get_device()}; auto opts = q.options(); - auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat)); at::Tensor dq_accum; if (loop) { // We don't want to allocate dq_accum of size (batch, seqlen_q_rounded, num_heads, head_size_rounded) @@ -1137,6 +1184,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size // cu_seqlens[i + 1] * 128 * i - 1. This ensures that the i-th sequence and (i + 1)-th sequence will // be at least 128 apart. It's ok for us to do atomicAdds up to 128 rows beyond what we're normally // allowed to do. So we won't have to do any bound checking, and performance should stay the same. + // Same holds for softmax_d, since LSE is stored in unpadded format. if (!deterministic) { dq_accum = torch::empty({total_q + 128 * batch_size, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); } else { @@ -1182,8 +1230,11 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size softmax_scale, window_size_left, window_size_right, - deterministic); + softcap, + deterministic, + /*unpadded_lse*/true); params.dq_accum_split_stride = !deterministic ? 0 : dq_accum.stride(0); + params.total_q = total_q; auto launch = &run_mha_bwd; @@ -1239,6 +1290,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he c10::optional &rotary_cos_, // seqlen_ro x (rotary_dim / 2) c10::optional &rotary_sin_, // seqlen_ro x (rotary_dim / 2) c10::optional &cache_batch_idx_, // indices to index into the KV cache + c10::optional &leftpad_k_, // batch_size c10::optional &block_table_, // batch_size x max_num_blocks_per_seq c10::optional &alibi_slopes_, // num_heads or batch_size x num_heads c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size @@ -1246,6 +1298,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he bool is_causal, int window_size_left, int window_size_right, + const float softcap, bool is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 int num_splits ) { @@ -1297,7 +1350,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he const int seqlen_k = !paged_KV ? 
kcache.size(1) : max_num_blocks_per_seq * page_block_size; const int num_heads_k = kcache.size(2); const int batch_size_c = !paged_KV ? kcache.size(0) : batch_size; - TORCH_CHECK(batch_size > 0, "batch size must be postive"); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256"); TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); @@ -1381,7 +1434,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he /*p_dropout=*/0.f, softmax_scale, window_size_left, - window_size_right); + window_size_right, + softcap + ); at::Tensor k, v, k_padded, v_padded; if (k_.has_value()) { @@ -1426,6 +1481,15 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he params.cu_seqlens_k = static_cast(seqlens_k.data_ptr()); } params.is_seqlens_k_cumulative = !(seqlens_k_.has_value()); + if (leftpad_k_.has_value()) { + TORCH_CHECK(!paged_KV, "We don't support Paged KV and leftpad_k running at the same time yet"); + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + params.leftpad_k = static_cast(leftpad_k.data_ptr()); + } if (rotary_cos_.has_value()) { TORCH_CHECK(k_.has_value(), "If rotary cos/sin are provided, new key / value to be appended to KV cache must also be provided"); diff --git a/csrc/flash_attn/src/alibi.h b/csrc/flash_attn/src/alibi.h index 80d297fc9..e714233e7 100644 --- a/csrc/flash_attn/src/alibi.h +++ b/csrc/flash_attn/src/alibi.h @@ -31,7 +31,7 @@ struct Alibi { const int col_idx_offset_, const int row_idx_offset, const int warp_row_stride) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; diff --git a/csrc/flash_attn/src/block_info.h b/csrc/flash_attn/src/block_info.h index 3a23a1e1f..cf60d653c 100644 --- a/csrc/flash_attn/src/block_info.h +++ b/csrc/flash_attn/src/block_info.h @@ -18,8 +18,9 @@ struct BlockInfo { , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. - , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + , leftpad_k(params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + , seqlen_k_cache((!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) - leftpad_k) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] - leftpad_k : seqlen_k_cache + (params.knew_ptr == nullptr ? 
0 : params.seqlen_knew)) { } @@ -30,13 +31,14 @@ struct BlockInfo { template __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { - return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + return sum_s_k == -1 ? bidb * batch_stride + leftpad_k * row_stride : uint32_t(sum_s_k + leftpad_k) * row_stride; } const int sum_s_q; const int sum_s_k; const int actual_seqlen_q; // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. + const int leftpad_k; const int seqlen_k_cache; const int actual_seqlen_k; }; diff --git a/csrc/flash_attn/src/flash.h b/csrc/flash_attn/src/flash.h index 88a7195fa..1a218b0d2 100644 --- a/csrc/flash_attn/src/flash.h +++ b/csrc/flash_attn/src/flash.h @@ -67,7 +67,7 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ softmax_lseaccum_ptr; // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim, total_q; // The scaling factors for the kernel. float scale_softmax; @@ -76,6 +76,7 @@ struct Flash_fwd_params : public Qkv_params { // array of length b+1 holding starting offset of each sequence. int * __restrict__ cu_seqlens_q; int * __restrict__ cu_seqlens_k; + int * __restrict__ leftpad_k; // If provided, the actual length of each k sequence. int * __restrict__ seqused_k; @@ -118,6 +119,7 @@ struct Flash_fwd_params : public Qkv_params { // Local window size int window_size_left, window_size_right; + float softcap; // Random state. at::PhiloxCudaState philox_args; @@ -138,6 +140,9 @@ struct Flash_fwd_params : public Qkv_params { void * __restrict__ alibi_slopes_ptr; index_t alibi_slopes_batch_stride; + + bool unpadded_lse; // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q]. + bool seqlenq_ngroups_swapped; // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d). }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -184,7 +189,7 @@ struct Flash_bwd_params : public Flash_fwd_params { //////////////////////////////////////////////////////////////////////////////////////////////////// -template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_bwd_kernel.h b/csrc/flash_attn/src/flash_bwd_kernel.h index 7d35209c0..00cbc081e 100644 --- a/csrc/flash_attn/src/flash_bwd_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_kernel.h @@ -76,7 +76,7 @@ make_tiled_copy_C_warpcontiguousN(Copy_Atom const& copy_atom, //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const int bidb, const int bidh, const int n_block) { using Element = typename Kernel_traits::Element; @@ -120,10 +120,9 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in + ((m_block_max - 1) * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. + (!params.deterministic ? 0 : blockIdx.x * params.dq_accum_split_stride); - const index_t row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q - + (m_block_max - 1) * kBlockM; - const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded - + (m_block_max - 1) * kBlockM; + const index_t row_offset_lse = (params.unpadded_lse? bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb): (bidb * params.h + bidh) * params.seqlen_q) + (m_block_max - 1) * kBlockM; + // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d + const index_t row_offset_dpsum = (params.unpadded_lse? bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb: (bidb * params.h + bidh) * params.seqlen_q_rounded) + (m_block_max - 1) * kBlockM; Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr) + row_offset_q), Shape, Int>{}, @@ -472,10 +471,27 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in flash::gemm(acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma_sdp, smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV); - // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + if constexpr (Is_softcap) { + flash::apply_softcap(acc_s, params.softcap); + } + + // Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); // if (cute::thread(32, 0)) { print(scores); } + // Softcapping - calculating dTanh and scaling dS later with it + auto dtanh = ([&]{ + if constexpr (Is_softcap) { + Tensor _dtanh = make_tensor_like(scores); + flash::calculate_dtanh(scores, _dtanh, params.softcap); + return _dtanh; + } + else { + return nullptr; + } + }()); + + // Alibi if (Has_alibi) { alibi.apply_alibi(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16, m_block * kBlockM + get<0>(taccScS_row(0)), AtomLayoutMS * 16); @@ -566,7 +582,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in smem_tiled_copy_QdO, smem_tiled_copy_KV, smem_thr_copy_QdO, smem_thr_copy_KV ); - // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N)) + // Reshape acc_dp from (MMA=4, MMA_N, MMA_N) to (row=(2, MMA_N), col=(2, MMA_N)) Tensor dS = make_tensor(acc_dp.data(), scores.layout()); auto pointwise_mult = [](float p, float dp, float d) { return p * (!Is_dropout || p >= 0 ? dp - d : d); @@ -575,7 +591,14 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params ¶ms, const in for (int mi = 0; mi < size<0>(dS); ++mi) { #pragma unroll for (int ni = 0; ni < size<1>(dS); ++ni) { - dS(mi, ni) = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + + float scaled_ds = pointwise_mult(scores(mi, ni), dS(mi, ni), dP_sum(mi)); + + if constexpr (Is_softcap) { + scaled_ds *= dtanh(mi, ni); + } + + dS(mi, ni) = scaled_ds; } } // if (cute::thread0()) { print(dS); } @@ -808,7 +831,7 @@ inline __device__ void compute_dq_dk_dv(const Params ¶ms) { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // The block index for the batch. 
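The backward changes above compute a `dtanh` tensor and rescale `dS` elementwise with it. The extra factor comes from the chain rule through the capped logit; a sketch of that factor, ignoring the scale bookkeeping the kernel folds in elsewhere (the helper name is illustrative):

```python
import torch

def softcap_backward_factor(scores_raw: torch.Tensor, softcap: float) -> torch.Tensor:
    # With s_capped = softcap * tanh(s_raw / softcap),
    # d s_capped / d s_raw = 1 - tanh(s_raw / softcap)**2,
    # which is the elementwise factor dS gets multiplied by.
    return 1.0 - torch.tanh(scores_raw / softcap) ** 2
```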
@@ -818,7 +841,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params ¶ms) { // If deterministic, each thread block will do atomicAdd to a different dQ_accum buffer. for (int n_block = blockIdx.x; n_block < (params.seqlen_k + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; n_block += gridDim.x) { - compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); + compute_dq_dk_dv_1colblock(params, bidb, bidh, n_block); } } diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index fd81c8844..9168914ff 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -35,10 +35,10 @@ DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_kernel, bool Is_dropout, bo #endif } -DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K) { +DEFINE_FLASH_BACKWARD_KERNEL(flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // If Is_local is true, Is_causal should be false - flash::compute_dq_dk_dv_seqk_parallel(params); + flash::compute_dq_dk_dv_seqk_parallel(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -95,17 +95,19 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream) EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !params.is_causal, Is_local, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - if (smem_size_dq_dk_dv >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); - } - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
+ // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); diff --git a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h index aa0641530..c8e307417 100644 --- a/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn/src/flash_bwd_preprocess_kernel.h @@ -79,7 +79,8 @@ inline __device__ void compute_dot_do_o(const Params ¶ms) { + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; - const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; + // Regarding 128 * params.b see a comment in mha_varlen_bwd about padding of dq_accum and softmax_d + const index_t row_offset_dpsum = (params.unpadded_lse ? (bidh * (params.total_q + 128 * params.b) + binfo.q_offset(params.seqlen_q_rounded, 1, bidb) + 128 * bidb): (bidb * params.h + bidh) * params.seqlen_q_rounded) + m_block * kBlockM; Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), Shape, Int>{}, diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..f19049b49 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu index 6ffa4126e..cb1357419 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..dfb04b78b --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu index 19b005ad9..6df16b2c3 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim128(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu new file mode 100644 index 000000000..230af9069 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu index f674f4818..cf1ffad20 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu new file mode 100644 index 000000000..1fc5ac597 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu index afd0a8a38..a9796aded 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim160(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim160(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu new file mode 100644 index 000000000..94792d4d3 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu index aa91bdd66..76d5136b1 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu new file mode 100644 index 000000000..9e5b21e02 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu index 37a965264..b4019a0be 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim192(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim192(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu new file mode 100644 index 000000000..a12a5f4ad --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu index 167a0df2b..8690bdb1a 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu new file mode 100644 index 000000000..f01dad09c --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu index 58ffe75c3..7ec1e16b7 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim224(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim224(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu new file mode 100644 index 000000000..3d816ab60 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu index 1b3701415..c6c55229c 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu new file mode 100644 index 000000000..0149abacd --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu index 9f35129c3..9c9a1715e 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim256(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu new file mode 100644 index 000000000..29097ac3a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu index 770de6fcf..cb52f34fa 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu new file mode 100644 index 000000000..7bdadefbe --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu index 8dbf8b94a..44b388161 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim32(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim32(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu new file mode 100644 index 000000000..99cd728bc --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu index 22eac8789..c11096ac1 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu new file mode 100644 index 000000000..2fbcd44e6 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu index e6da5dd2d..7b65a9c9e 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim64(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu new file mode 100644 index 000000000..6fb3cf642 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu index 9c003540c..e696b2f2c 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu new file mode 100644 index 000000000..bb3b744d1 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,10 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); +} diff --git a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu index 8108696a0..5f3accc30 100644 --- a/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu @@ -5,6 +5,6 @@ #include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { - run_mha_fwd_hdim96(params, stream); +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim96(params, stream); } diff --git a/csrc/flash_attn/src/flash_fwd_kernel.h b/csrc/flash_attn/src/flash_fwd_kernel.h index fd68cec12..788f3790e 100644 --- a/csrc/flash_attn/src/flash_fwd_kernel.h +++ b/csrc/flash_attn/src/flash_fwd_kernel.h @@ -24,7 +24,28 @@ using namespace cute; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +__forceinline__ __device__ auto get_lse_tile(const Params ¶ms, const int bidb, const int bidh, const int m_block, const BlockInfo &binfo) { + // When params.unpadded_lse is false, LSE is written as (b, h, seqlen_q) - this is non-variable seqlen path. + // Otherwise, when params.seqlenq_ngroups_swapped is true, it is written as (h, seqlen_q, b) to account for seqlen_q <-> h swapping trick. + // Otherwise, it's written as (h, b, seqlen_q). + const bool varlen_q = params.unpadded_lse && !params.seqlenq_ngroups_swapped; + auto lse_offset = varlen_q ? binfo.q_offset(params.seqlen_q, 1, bidb) : 0; + auto gmem_ptr_lse = make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + lse_offset); + + auto lse_shape = varlen_q ? make_shape(1, params.h, params.total_q) : make_shape(params.b, params.h, params.seqlen_q); + auto lse_stride = params.seqlenq_ngroups_swapped ? make_stride(1, params.seqlen_q * params.b, params.b) : ( + params.unpadded_lse ? make_stride(params.h * params.total_q, params.total_q, 1) : make_stride(params.h * params.seqlen_q, params.seqlen_q, 1) + ); + + auto lse_layout = make_layout(lse_shape, lse_stride); + Tensor mLSE = make_tensor(gmem_ptr_lse, lse_layout); + auto mLSE_slice = varlen_q ? 
mLSE(0, bidh, _) : mLSE(bidb, bidh, _); + return local_tile(mLSE_slice, Shape>{}, make_coord(m_block)); +} + + +template inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bidb, const int bidh, const int m_block) { using Element = typename Kernel_traits::Element; @@ -74,10 +95,8 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -142,7 +161,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi typename Kernel_traits::SmemLayoutKV{}); Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -299,6 +318,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 @@ -362,6 +384,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } flash::cp_async_wait<0>(); __syncthreads(); @@ -424,10 +449,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi make_stride(params.o_row_stride, params.o_head_stride, _1{})); Tensor gO = local_tile(mO(_, bidh, _), Shape, Int>{}, make_coord(m_block, 0)); // (kBlockM, kHeadDim) - Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), - make_shape(params.b, params.h, params.seqlen_q), - make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); - Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + Tensor gLSE = get_lse_tile(params, bidb, bidh, m_block, binfo); typename Kernel_traits::GmemTiledCopyO gmem_tiled_copy_O; auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(tidx); @@ -470,7 +492,7 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, const int bidb, const int bidh, const int m_block, const int n_split_idx, const int num_n_splits) { using Element = typename 
Kernel_traits::Element; @@ -586,7 +608,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons Tensor sK = make_tensor(sQ.data() + size(sQ), typename Kernel_traits::SmemLayoutKV{}); Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{}); Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{}); - Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); + Tensor sVtNoSwizzle = make_tensor(sV.data().get(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{}); typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV; auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx); @@ -660,7 +682,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe. // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache. - const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN) * (params.rotary_dim / 2); + const index_t row_offset_cossin = ((n_block_max - 1) * kBlockN + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb])) * (params.rotary_dim / 2); Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), Shape, Int>{}, make_stride(params.rotary_dim / 2, _1{})); @@ -681,9 +703,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons // if (cute::thread(8, 0)) { print_tensor(gCos); } // if (cute::thread(0, 0)) { print_tensor(tRgCos); } - const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + // const index_t row_offset_knew = binfo.k_offset(params.knew_batch_stride, params.knew_row_stride, bidb) + const index_t row_offset_knew = bidb * params.knew_batch_stride + ((n_block_max - 1) * kBlockN) * params.knew_row_stride + (bidh / params.h_h_k_ratio) * params.knew_head_stride; - const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + // const index_t row_offset_vnew = binfo.k_offset(params.vnew_batch_stride, params.vnew_row_stride, bidb) + const index_t row_offset_vnew = bidb * params.vnew_batch_stride + ((n_block_max - 1) * kBlockN) * params.vnew_row_stride + (bidh / params.h_h_k_ratio) * params.vnew_head_stride; // Subtract seqlen_k_cache * row stride so that conceptually gK and gKnew "line up". When we access them, // e.g. if gK has 128 rows and gKnew has 64 rows, we access gK[:128] and gKNew[128:128 + 64]. @@ -761,7 +785,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons flash::copy(gmem_tiled_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM); } else { - const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); + const index_t row_offset_cossin = (binfo.seqlen_k_cache + (params.leftpad_k == nullptr ? 0 : params.leftpad_k[bidb]) + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2); // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache. // We do this by setting the row stride of gCos / gSin to 0. 
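Several of the offsets above (in `BlockInfo`, the rotary cos/sin positions, and the KV row offsets) are shifted by `leftpad_k`. Conceptually, the valid keys for batch element `b` start at index `leftpad_k[b]` inside the cache row rather than at 0, with `seqused_k` counted from the start of the row as in `BlockInfo`. A rough sketch of the indexing (array names follow the C++ parameters; the helper itself is hypothetical):

```python
def effective_kv_slice(k_cache, leftpad_k, seqused_k, b):
    # k_cache: (batch, seqlen_k, nheads_k, headdim); leftpad_k, seqused_k: per-batch int arrays.
    start = leftpad_k[b]          # keys before this index are left padding
    end = seqused_k[b]            # counted from index 0 of the cache row
    return k_cache[b][start:end]  # actual_seqlen_k == end - start
```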
Tensor gCos = make_tensor(make_gmem_ptr(reinterpret_cast(params.rotary_cos_ptr) + row_offset_cossin), @@ -854,6 +878,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons smem_thr_copy_Q, smem_thr_copy_K ); // if (cute::thread0()) { print(acc_s); } + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } + mask.template apply_mask( acc_s, n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16 @@ -925,6 +953,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons acc_s, tSrQ, tSrK, tSsQ, tSsK, tiled_mma, smem_tiled_copy_Q, smem_tiled_copy_K, smem_thr_copy_Q, smem_thr_copy_K ); + if constexpr (Is_softcap){ + flash::apply_softcap(acc_s, params.softcap); + } flash::cp_async_wait<0>(); __syncthreads(); @@ -986,7 +1017,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; const index_t row_offset_oaccum = (((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM) * params.d_rounded; - const index_t row_offset_lseaccum = ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q + m_block * kBlockM; + const index_t row_offset_lseaccum = (Split || !params.unpadded_lse ? + ((n_split_idx * params.b + bidb) * params.h + bidh) * params.seqlen_q : bidh * params.total_q + binfo.q_offset(params.seqlen_q, 1, bidb) + ) + m_block * kBlockM; Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast(Split ? params.oaccum_ptr : params.o_ptr) + (Split ? row_offset_oaccum : row_offset_o)), Shape, Int>{}, @@ -1036,7 +1069,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1052,12 +1085,12 @@ inline __device__ void compute_attn(const Params ¶ms) { // the attention matrix. This way, as long as we have the batch, head, and the location of // the 16 x 32 block within the attention matrix, we can generate the exact same dropout pattern. - flash::compute_attn_1rowblock(params, bidb, bidh, m_block); + flash::compute_attn_1rowblock(params, bidb, bidh, m_block); } //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int m_block = blockIdx.x; // The block index for the batch. @@ -1066,7 +1099,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) { const int bidh = Split ? blockIdx.z - bidb * params.h : blockIdx.z; const int n_split_idx = Split ? blockIdx.y : 0; const int num_n_splits = Split ? 
gridDim.y : 1; - flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); + flash::compute_attn_1rowblock_splitkv(params, bidb, bidh, m_block, n_split_idx, num_n_splits); } //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1092,21 +1125,36 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { const int tidx = threadIdx.x; const int bidx = blockIdx.x; + const index_t lse_size = params.b * params.h * params.seqlen_q; + const index_t row_offset_lse = bidx * kBlockM; Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lseaccum_ptr) + row_offset_lse), Shape, Int>{}, - make_stride(params.b * params.h * params.seqlen_q, _1{})); + make_stride(lse_size, _1{})); + + // LSE format is different depending on params.unpadded_lse and params.seqlenq_ngroups_swapped, see comment in get_lse_tile. + // This tensor's layout maps row_offset_lse to {bidb, bidh, q_offset}. Tensor gLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr) + row_offset_lse), Shape>{}, Stride<_1>{}); + + // This layout maps row_offset_lse to {bidh, q_offset, bidb} or {bidh, bidb, q_offset}. + Layout flat_layout = make_layout(lse_size); + Layout orig_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b)); + auto transposed_stride = params.seqlenq_ngroups_swapped ? make_stride(params.b, params.seqlen_q * params.b, 1) : make_stride(1, params.seqlen_q * params.b, params.seqlen_q); + Layout remapped_layout = make_layout(make_shape(params.seqlen_q, params.h, params.b), transposed_stride); + Layout final_layout = cute::composition(remapped_layout, cute::composition(orig_layout, flat_layout)); + + Tensor gLSE_unpadded = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), final_layout); + constexpr int kNLsePerThread = (kMaxSplits * kBlockM + kNThreads - 1) / kNThreads; - // Read the LSE values from gmem and store them in shared memory, then tranpose them. + // Read the LSE values from gmem and store them in shared memory, then transpose them. constexpr int kRowsPerLoadLSE = kNThreads / kBlockM; #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { const int row = l * kRowsPerLoadLSE + tidx / kBlockM; const int col = tidx % kBlockM; - ElementAccum lse = (row < params.num_splits && col < params.b * params.h * params.seqlen_q - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; + ElementAccum lse = (row < params.num_splits && col < lse_size - bidx * kBlockM) ? gLSEaccum(row, col) : -INFINITY; if (row < kMaxSplits) { sLSE[row][col] = lse; } // if (bidx == 0 && tidx < 32) { printf("tidx = %d, row = %d, col = %d, lse = %f\n", tidx, row, col, lse); } } @@ -1145,7 +1193,16 @@ inline __device__ void combine_attn_seqk_parallel(const Params ¶ms) { // lse_logsum is log(0.0) = -INFINITY and we get NaN when we do lse_accum(l) - lse_logsum. ElementAccum lse_logsum = (lse_sum == 0.f || lse_sum != lse_sum) ? 
INFINITY : logf(lse_sum) + lse_max; // if (bidx == 0 && tidx < 32) { printf("tidx = %d, lse = %f, lse_max = %f, lse_logsum = %f\n", tidx, lse_accum(0), lse_max, lse_logsum); } - if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; } + if (tidx % kRowsPerLoadTranspose == 0 && tidx / kRowsPerLoadTranspose < kBlockM) { + if (params.unpadded_lse) { + const index_t lse_offset = row_offset_lse + tidx / kRowsPerLoadTranspose; + if (lse_offset < lse_size) { + gLSE_unpadded(lse_offset) = lse_logsum; + } + } else { + gLSE(tidx / kRowsPerLoadTranspose) = lse_logsum; + } + } // Store the scales exp(lse - lse_logsum) in shared memory. #pragma unroll for (int l = 0; l < kNLsePerThread; ++l) { diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index fa6a6f6b2..eb8bceab4 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -26,18 +26,18 @@ template \ __global__ void kernelName(KERNEL_PARAM_MODIFIER const Flash_fwd_params params) -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Return_softmax) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_kernel, bool Is_dropout, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Return_softmax) { #if defined(ARCH_SUPPORTS_FLASH) static_assert(!(Is_causal && Is_local)); // Enforce constraints - flash::compute_attn(params); + flash::compute_attn(params); #else FLASH_UNSUPPORTED_ARCH #endif } -DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV) { +DEFINE_FLASH_FORWARD_KERNEL(flash_fwd_splitkv_kernel, bool Is_causal, bool Is_local, bool Has_alibi, bool Is_even_MN, bool Is_even_K, bool Is_softcap, bool Split, bool Append_KV) { #if defined(ARCH_SUPPORTS_FLASH) - flash::compute_attn_splitkv(params); + flash::compute_attn_splitkv(params); #else FLASH_UNSUPPORTED_ARCH #endif @@ -67,25 +67,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { - // Will only return softmax if dropout, to reduce compilation time. - // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. 
- // If return_softmax, set IsEvenMNConst to false to reduce number of templates - // If head dim > 128, set IsEvenMNConst to false to reduce number of templates - // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - // int ctas_per_sm; - // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { + // Will only return softmax if dropout, to reduce compilation time. + // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. + // If return_softmax, set IsEvenMNConst to false to reduce number of templates + // If head dim > 128, set IsEvenMNConst to false to reduce number of templates + // If Is_local, set Is_causal to false + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + // printf("IsEvenMNConst = %d, IsEvenKConst = %d, Is_local = %d, Is_causal = %d, ReturnSoftmaxConst = %d, Is_dropout = %d\n", int(IsEvenMNConst), int(IsEvenKConst), int(Is_local), int(Is_causal), int(ReturnSoftmaxConst), int(Is_dropout)); + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + // int ctas_per_sm; + // cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + // &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); @@ -93,7 +95,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); } -template +template void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { static_assert(!Kernel_traits::Is_Q_in_regs, "SplitKV implementation does not support Is_Q_in_regs"); static_assert(!Kernel_traits::Share_Q_K_smem, "SplitKV implementation does not support Share_Q_K_smem"); @@ -102,17 +104,17 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { dim3 grid(num_m_block, params.num_splits > 1 ? params.num_splits : params.b, params.num_splits > 1 ? 
params.b * params.h : params.h); const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_k % Kernel_traits::kBlockN == 0 && params.seqlen_q % Kernel_traits::kBlockM == 0; const bool is_even_K = params.d == Kernel_traits::kHeadDim; - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { - EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { - LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { - BOOL_SWITCH(params.num_splits > 1, Split, [&] { - BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { - ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + EVENK_SWITCH(is_even_K, IsEvenKConst, [&] { + LOCAL_SWITCH((params.window_size_left >= 0 || params.window_size_right >= 0) && !Is_causal, Is_local, [&] { + BOOL_SWITCH(params.num_splits > 1, Split, [&] { + BOOL_SWITCH(params.knew_ptr != nullptr, Append_KV, [&] { + ALIBI_SWITCH(params.alibi_slopes_ptr != nullptr, Has_alibi, [&] { + SOFTCAP_SWITCH(params.softcap > 0.0, Is_softcap, [&] { // If Append_KV, then we must have seqlen_offsets, which means cu_seqlens_k != nullptr. // If not IsEvenKConst, we also set IsEvenMNConst to false to reduce number of templates. // If Is_local, set Is_causal to false - auto kernel = &flash_fwd_splitkv_kernel; + auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; // auto kernel = &flash_fwd_splitkv_kernel; if (smem_size >= 48 * 1024) { @@ -155,161 +157,149 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { } } -template +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int kBlockM = 64; // Fixed for all head dimensions // TD [2023-08-28]: nvcc segfaults for headdim 96 with block size 64 x 256, // and for headdim 192 with block size 64 x 128. // Also for headdim 160 with block size 64 x 128 after the rotary addition. constexpr static int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 
128 : 64); - run_flash_splitkv_fwd>(params, stream); + run_flash_splitkv_fwd, Is_causal>(params, stream); } -template +template void run_mha_fwd_hdim32(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 32; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } -template +template void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 64; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower - // Using block size (64 x 256) is 27% slower for seqlen=2k - // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - }); + if constexpr(!Is_dropout) { + // Using 8 warps is 18% slower for seqlen=2k, 2 warps is 5% slower + // Using block size (64 x 256) is 27% slower for seqlen=2k + // Using block size (256 x 64) is 85% slower for seqlen=2k, because of register spilling + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } }); } -template +template void run_mha_fwd_hdim96(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 96; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - } else { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + if (is_sm8x) { + if constexpr(!Is_causal) { run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // These two are always slower - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // These two are always slower + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } -template +template void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { 
constexpr static int Headdim = 128; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } + if constexpr(!Is_dropout) { + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 32 (48 KB smem) is the fastest for non-causal since we get 2 CTAs per SM. + if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 1st ones are good for H100, A100 - // 2nd one is good for A6000 bc we get slightly better occupancy } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - }); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // Using 8 warps (128 x 128 and 256 x 64) is 28% slower for seqlen=2k + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 1st ones are good for H100, A100 + // 2nd one is good for A6000 bc we get slightly better occupancy + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } }); } -template +template void run_mha_fwd_hdim160(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 160; auto dprops = at::cuda::getCurrentDeviceProperties(); bool is_sm8x = dprops->major == 8 && dprops->minor > 0; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, H100, 128 x 32 is the fastest. - // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), - // and 128 x 64 with 8 warps is the fastest for non-causal. - if (is_sm8x) { - if constexpr(!Is_causal) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } + // For A100, H100, 128 x 32 is the fastest. + // For sm86 or sm89, 64 x 64 is the fastest for causal (because it's square), + // and 128 x 64 with 8 warps is the fastest for non-causal. 
+ if (is_sm8x) { + if constexpr(!Is_causal) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); + run_flash_fwd, Is_dropout, Is_causal>(params, stream); } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } -template +template void run_mha_fwd_hdim192(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 192; DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if constexpr(!Is_dropout) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - // run_flash_fwd>(params, stream); - }); + if constexpr(!Is_dropout) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); + // run_flash_fwd>(params, stream); }); } -template +template void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 224; int device; @@ -322,23 +312,21 @@ void run_mha_fwd_hdim224(Flash_fwd_params ¶ms, cudaStream_t stream) { } // printf("max_smem_per_block = %d\n", max_smem_per_block); DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. - // If we have N = 32, there are only 1024 elements to load at once, where each load - // is 8 elements. This means we can only use 128 threads and not 256 threads. - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64)) { // 112 KB + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // We can't do 128 x 32 with 8 warps because with headdim 224, kBlockKSmem = 32. + // If we have N = 32, there are only 1024 elements to load at once, where each load + // is 8 elements. This means we can only use 128 threads and not 256 threads. 
+ // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } -template +template void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { constexpr static int Headdim = 256; int device; @@ -353,18 +341,16 @@ void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { } // printf("max_smem_per_sm = %d, max_smem_per_block = %d\n", max_smem_per_sm, max_smem_per_block); DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] { - BOOL_SWITCH(params.is_causal, Is_causal, [&] { - // For A100, we want to run with 128 x 64 (128KB smem). - // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. - if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } else { - run_flash_fwd, Is_dropout, Is_causal>(params, stream); - } - // 64 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - // 96 KB - // run_flash_fwd, Is_dropout, Is_causal>(params, stream); - }); + // For A100, we want to run with 128 x 64 (128KB smem). + // For H100 we want to run with 64 x 64 (96KB smem) since then we can get 2 CTAs per SM. + if (max_smem_per_block >= 2 * Headdim * (128 + 2 * 64) && max_smem_per_sm < 4 * Headdim * (64 + 2 * 64)) { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } else { + run_flash_fwd, Is_dropout, Is_causal>(params, stream); + } + // 64 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); + // 96 KB + // run_flash_fwd, Is_dropout, Is_causal>(params, stream); }); } diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu new file mode 100644 index 000000000..9500ff73f --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu index 477c560a7..f868234de 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu new file mode 100644 index 000000000..1b59e9ea8 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu index 914cd23bb..b8863f1ba 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu new file mode 100644 index 000000000..cf0d6539c --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu index d753d59d5..ab5a6d002 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu new file mode 100644 index 000000000..7a018205a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu index 552c25d02..0f7a7b322 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu new file mode 100644 index 000000000..0ebad0649 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu index e6b350a7c..738d8dcd3 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu new file mode 100644 index 000000000..8cafda539 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu index b9c193501..27b2d65a0 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu new file mode 100644 index 000000000..ea024d9ab --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu index b6bf081f2..b06ae5ace 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu new file mode 100644 index 000000000..b217f3789 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu index 0d09606fb..8cf2eabed 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu new file mode 100644 index 000000000..4f9a31ab1 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu index 06a9524ac..ff31b9e6f 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu new file mode 100644 index 000000000..bcd842d70 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu index 54fd3b87f..8140f3662 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu new file mode 100644 index 000000000..93e7553bf --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu index beff74ce8..9c424491a 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu new file mode 100644 index 000000000..075567cb5 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu index d97c9eaa0..7d891c60e 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu new file mode 100644 index 000000000..1b52f32e5 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu index aed05fadc..aad4cf5af 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu new file mode 100644 index 000000000..e42fa9250 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu index 3b905f62c..73cb99bc7 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu new file mode 100644 index 000000000..c8cfe93f9 --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu index 00a5972bf..26ff113ed 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu new file mode 100644 index 000000000..b8173261a --- /dev/null +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu @@ -0,0 +1,7 @@ +// Copyright (c) 2023, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. +// This file is auto-generated. 
See "generate_kernels.py" + +#include "flash_fwd_launch_template.h" + +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu index 95a76967c..fd91232e6 100644 --- a/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu +++ b/csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_sm80.cu @@ -4,4 +4,4 @@ #include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, cudaStream_t stream); diff --git a/csrc/flash_attn/src/generate_kernels.py b/csrc/flash_attn/src/generate_kernels.py index 0f71002d3..45fc3d9f1 100644 --- a/csrc/flash_attn/src/generate_kernels.py +++ b/csrc/flash_attn/src/generate_kernels.py @@ -16,17 +16,18 @@ SM = [80] # Sm80 kernels support up to HEAD_DIMENSIONS = [32, 64, 96, 128, 160, 192, 224, 256] +IS_CAUSAL = ["false", "true"] KERNEL_IMPL_TEMPLATE_FWD = """#include "flash_fwd_launch_template.h" template<> -void run_mha_fwd_<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ - run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}>(params, stream); +void run_mha_fwd_<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream) {{ + run_mha_fwd_hdim{HEAD_DIM}<{DTYPE}, {IS_CAUSAL}>(params, stream); }} """ KERNEL_IMPL_TEMPLATE_FWD_SPLIT = """#include "flash_fwd_launch_template.h" -template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}>(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_fwd_splitkv_dispatch<{DTYPE}, {HEAD_DIM}, {IS_CAUSAL}>(Flash_fwd_params ¶ms, cudaStream_t stream); """ KERNEL_IMPL_TEMPLATE_BWD = """#include "flash_bwd_launch_template.h" @@ -43,13 +44,14 @@ class Kernel: sm: int dtype: str head_dim: int + is_causal: bool direction: str @property def template(self) -> str: if self.direction == "fwd": return KERNEL_IMPL_TEMPLATE_FWD.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal ) elif self.direction == "bwd": return KERNEL_IMPL_TEMPLATE_BWD.format( @@ -57,18 +59,21 @@ def template(self) -> str: ) else: return KERNEL_IMPL_TEMPLATE_FWD_SPLIT.format( - DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim + DTYPE=DTYPE_MAP[self.dtype], HEAD_DIM=self.head_dim, IS_CAUSAL=self.is_causal ) @property def filename(self) -> str: - return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_sm{self.sm}.cu" + return f"flash_{self.direction}_hdim{self.head_dim}_{self.dtype}_{'causal_' if self.is_causal == 'true' else ''}sm{self.sm}.cu" def get_all_kernels() -> List[Kernel]: - for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM): - for direction in ["fwd", "bwd", "fwd_split"]: - yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, direction=direction) + for direction in ["fwd", "fwd_split"]: + for dtype, head_dim, is_causal, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, IS_CAUSAL, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal=is_causal, direction=direction) + for direction in ["bwd"]: + for dtype, head_dim, sm in itertools.product(DTYPE_MAP.keys(), HEAD_DIMENSIONS, SM): + yield Kernel(sm=sm, dtype=dtype, head_dim=head_dim, is_causal="false", direction=direction) def write_kernel(kernel: Kernel, autogen_dir: Path) -> None: diff --git a/csrc/flash_attn/src/mask.h b/csrc/flash_attn/src/mask.h index 
3d9b42985..7ba435a37 100644 --- a/csrc/flash_attn/src/mask.h +++ b/csrc/flash_attn/src/mask.h @@ -13,7 +13,7 @@ using namespace cute; template __forceinline__ __device__ void apply_mask(Tensor &tensor, const int max_seqlen_k, const int col_idx_offset_ = 0) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -39,7 +39,7 @@ __forceinline__ __device__ void apply_mask_local(Tensor &tensor, const int max_seqlen_k, const int row_idx_offset, const int max_seqlen_q, const int warp_row_stride, const int window_size_left, const int window_size_right) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout::rank == 2, "Only support 2D Tensor"); const int lane_id = threadIdx.x % 32; const int col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2; @@ -85,7 +85,7 @@ __forceinline__ __device__ void apply_mask_causal_w_idx( Tensor &tensor, Tensor const &idx_rowcol, const int col_idx_offset_, const int max_seqlen_k, const int row_idx_offset) { - // tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N)) + // tensor has shape (nrow=(2, MMA_M), ncol=(2, MMA_N)) static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 2, "Only support 2D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(tensor) == size<0>(idx_rowcol)); diff --git a/csrc/flash_attn/src/static_switch.h b/csrc/flash_attn/src/static_switch.h index ca12fa171..20c2afd6c 100644 --- a/csrc/flash_attn/src/static_switch.h +++ b/csrc/flash_attn/src/static_switch.h @@ -56,6 +56,16 @@ #define EVENK_SWITCH BOOL_SWITCH #endif +#ifdef FLASHATTENTION_DISABLE_SOFTCAP + #define SOFTCAP_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + }() +#else + #define SOFTCAP_SWITCH BOOL_SWITCH +#endif + #ifdef FLASHATTENTION_DISABLE_LOCAL #define LOCAL_SWITCH(COND, CONST_NAME, ...) 
\ [&] { \ diff --git a/csrc/flash_attn/src/utils.h b/csrc/flash_attn/src/utils.h index 708aeddfa..b7408ec44 100644 --- a/csrc/flash_attn/src/utils.h +++ b/csrc/flash_attn/src/utils.h @@ -390,4 +390,22 @@ __forceinline__ __device__ void copy_w_min_idx(Tensor const &S //////////////////////////////////////////////////////////////////////////////////////////////////// +template +__forceinline__ __device__ void apply_softcap(Tensor &tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = cutlass::fast_tanh(tensor(i) * softcap); + } +} + +template +__forceinline__ __device__ void calculate_dtanh(Tensor &src_tensor, Tensor &dst_tensor, const float softcap){ + #pragma unroll + for (int i = 0; i < size(src_tensor); ++i) { + dst_tensor(i) = (1.f - (src_tensor(i) * src_tensor(i))) * softcap; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace flash diff --git a/flash_attn/__init__.py b/flash_attn/__init__.py index 242022d6a..aa3214ebd 100644 --- a/flash_attn/__init__.py +++ b/flash_attn/__init__.py @@ -1,4 +1,4 @@ -__version__ = "2.5.9.post1" +__version__ = "2.6.1" from flash_attn.flash_attn_interface import ( flash_attn_func, diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index a7f15beee..ecb3515c0 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -11,6 +11,8 @@ # isort: on +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x def _get_block_size_n(device, head_dim, is_dropout, is_causal): # This should match the block sizes in the CUDA kernel @@ -44,9 +46,8 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal): def _flash_attn_forward( - q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax + q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax ): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( q, @@ -59,6 +60,7 @@ def _flash_attn_forward( causal, window_size[0], window_size[1], + softcap, return_softmax, None, ) @@ -76,12 +78,14 @@ def _flash_attn_varlen_forward( dropout_p, softmax_scale, causal, - window_size, - alibi_slopes, - return_softmax, - block_table, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + block_table=None, + leftpad_k=None, + seqused_k=None, ): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( q, @@ -90,7 +94,8 @@ def _flash_attn_varlen_forward( None, cu_seqlens_q, cu_seqlens_k, - None, + seqused_k, + leftpad_k, block_table, alibi_slopes, max_seqlen_q, @@ -101,6 +106,7 @@ def _flash_attn_varlen_forward( causal, window_size[0], window_size[1], + softcap, return_softmax, None, ) @@ -123,14 +129,19 @@ def _flash_attn_backward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, ): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, = flash_attn_cuda.bwd( + ( + dq, + dk, + dv, + 
softmax_d, + ) = flash_attn_cuda.bwd( dout, q, k, @@ -146,6 +157,7 @@ def _flash_attn_backward( causal, window_size[0], window_size[1], + softcap, deterministic, None, rng_state, @@ -171,14 +183,19 @@ def _flash_attn_varlen_backward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, rng_state=None, ): - maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x # dq, dk, dv are allocated by us so they should already be contiguous dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] - dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd( + ( + dq, + dk, + dv, + softmax_d, + ) = flash_attn_cuda.varlen_bwd( dout, q, k, @@ -199,6 +216,7 @@ def _flash_attn_varlen_backward( causal, window_size[0], window_size[1], + softcap, deterministic, None, rng_state, @@ -217,6 +235,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -231,6 +250,7 @@ def forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, ) @@ -239,6 +259,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @@ -262,12 +283,13 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): @@ -281,6 +303,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -299,6 +322,7 @@ def forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=None, @@ -309,6 +333,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @@ -336,12 +361,13 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, ) dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension - return dqkv, None, None, None, None, None, None, None, None, None + return dqkv, None, None, None, None, None, None, None, None, None, None class FlashAttnKVPackedFunc(torch.autograd.Function): @@ -354,6 +380,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -368,6 +395,7 @@ def forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, ) @@ -376,6 +404,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @@ -400,13 +429,14 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, 
ctx.window_size, + ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): @@ -423,6 +453,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -441,6 +472,7 @@ def forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=None, @@ -454,6 +486,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @@ -482,13 +515,14 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, ) dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dkv = dkv[..., : dout.shape[-1]] - return dq, dkv, None, None, None, None, None, None, None, None, None, None, None + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None class FlashAttnFunc(torch.autograd.Function): @@ -502,6 +536,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -516,6 +551,7 @@ def forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, ) @@ -524,6 +560,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @@ -546,6 +583,7 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, @@ -553,7 +591,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., : dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None class FlashAttnVarlenFunc(torch.autograd.Function): @@ -571,6 +609,7 @@ def forward( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_softmax, @@ -590,6 +629,7 @@ def forward( softmax_scale, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, return_softmax=return_softmax and dropout_p > 0, block_table=block_table, @@ -603,6 +643,7 @@ def forward( ctx.softmax_scale = softmax_scale ctx.causal = causal ctx.window_size = window_size + ctx.softcap = softcap ctx.alibi_slopes = alibi_slopes ctx.deterministic = deterministic return out if not return_softmax else (out, softmax_lse, S_dmask) @@ -629,6 +670,7 @@ def backward(ctx, dout, *args): ctx.softmax_scale, ctx.causal, ctx.window_size, + ctx.softcap, ctx.alibi_slopes, ctx.deterministic, rng_state=rng_state, @@ -636,7 +678,7 @@ def backward(ctx, dout, *args): dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension dk = dk[..., : dout.shape[-1]] dv = dv[..., 
: dout.shape[-1]] - return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_qkvpacked_func( @@ -645,6 +687,7 @@ def flash_attn_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # <=0.0 means deactivate alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -666,6 +709,7 @@ def flash_attn_qkvpacked_func( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, @@ -688,6 +732,7 @@ def flash_attn_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -701,6 +746,7 @@ def flash_attn_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -738,6 +784,7 @@ def flash_attn_kvpacked_func( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. @@ -762,6 +809,7 @@ def flash_attn_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -776,6 +824,7 @@ def flash_attn_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -836,6 +885,7 @@ def flash_attn_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -850,6 +900,7 @@ def flash_attn_varlen_qkvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -874,6 +925,7 @@ def flash_attn_varlen_qkvpacked_func( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to the attention score of query i and key j. deterministic: bool. Whether to use the deterministic implementation of the backward pass, @@ -883,7 +935,7 @@ def flash_attn_varlen_qkvpacked_func( (they might not have the right scaling). Return: out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). 
The + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). @@ -898,6 +950,7 @@ def flash_attn_varlen_qkvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -915,6 +968,7 @@ def flash_attn_varlen_kvpacked_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -958,6 +1012,7 @@ def flash_attn_varlen_kvpacked_func( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. @@ -968,7 +1023,7 @@ def flash_attn_varlen_kvpacked_func( (they might not have the right scaling). Return: out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). @@ -986,6 +1041,7 @@ def flash_attn_varlen_kvpacked_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -1004,6 +1060,7 @@ def flash_attn_varlen_func( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated alibi_slopes=None, deterministic=False, return_attn_probs=False, @@ -1046,6 +1103,7 @@ def flash_attn_varlen_func( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i + seqlen_k - seqlen_q - j|) is added to the attention score of query i and key j. @@ -1056,7 +1114,7 @@ def flash_attn_varlen_func( (they might not have the right scaling). Return: out: (total, nheads, headdim). - softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). 
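Reviewer note (not part of the patch): the hunks above and below thread the new `softcap` argument through the public functions and their autograd wrappers. As a quick orientation, here is a minimal usage sketch; the shapes and the `softcap=30.0` value are arbitrary assumptions, and softcapping is commonly implemented as `softcap * tanh(score / softcap)` applied to the pre-softmax attention scores.

```python
# Illustrative sketch only: calling the public API with the new softcap argument.
# Shapes, dtype, and the cap value are assumptions for demonstration.
import torch
from flash_attn import flash_attn_func

q = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2, 1024, 8, 128, dtype=torch.bfloat16, device="cuda")

# softcap=0.0 (the default) keeps the previous behavior; any value > 0 enables
# tanh capping of the attention scores before the softmax.
out = flash_attn_func(q, k, v, causal=True, softcap=30.0)
```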
@@ -1075,6 +1133,7 @@ def flash_attn_varlen_func( softmax_scale, causal, window_size, + softcap, alibi_slopes, deterministic, return_attn_probs, @@ -1092,13 +1151,16 @@ def flash_attn_with_kvcache( rotary_sin=None, cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, block_table: Optional[torch.Tensor] = None, softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, alibi_slopes=None, num_splits=0, + return_softmax_lse=False, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -1157,15 +1219,17 @@ def flash_attn_with_kvcache( rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the KV cache. - block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. If the indices are not distinct, and k and v are provided, the values updated in the cache might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 @@ -1177,13 +1241,16 @@ def flash_attn_with_kvcache( If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic to automatically determine the number of splits. Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. Return: out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
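Reviewer note (not part of the patch): this hunk also extends `flash_attn_with_kvcache` with `cache_leftpad`, `softcap`, and a `return_softmax_lse` flag. A minimal decoding sketch follows; batch size, cache length, and dtypes are assumptions, and per the docstring above the returned `softmax_lse` has shape `(batch_size, nheads, seqlen)`.

```python
# Illustrative sketch only: single-step decoding with the updated kvcache API.
import torch
from flash_attn import flash_attn_with_kvcache

batch, nheads, headdim, cache_len = 2, 8, 128, 4096
q = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(batch, cache_len, nheads, headdim, dtype=torch.float16, device="cuda")
v_cache = torch.zeros_like(k_cache)
k_new = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
v_new = torch.randn_like(k_new)
# Current lengths already stored in the cache (updated in place with k_new/v_new).
cache_seqlens = torch.full((batch,), 17, dtype=torch.int32, device="cuda")

out, lse = flash_attn_with_kvcache(
    q, k_cache, v_cache, k=k_new, v=v_new,
    cache_seqlens=cache_seqlens,
    causal=True,
    return_softmax_lse=True,  # new flag; omit cache_leftpad to keep the default of 0
)
```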
""" assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -1204,6 +1271,7 @@ def flash_attn_with_kvcache( rotary_cos, rotary_sin, cache_batch_idx, + cache_leftpad, block_table, alibi_slopes, None, @@ -1211,7 +1279,8 @@ def flash_attn_with_kvcache( causal, window_size[0], window_size[1], + softcap, rotary_interleaved, num_splits, ) - return out + return (out, softmax_lse) if return_softmax_lse else out diff --git a/flash_attn/models/gpt.py b/flash_attn/models/gpt.py index 71540da95..3539f8f90 100644 --- a/flash_attn/models/gpt.py +++ b/flash_attn/models/gpt.py @@ -966,7 +966,7 @@ def key_mapping_mlp(key): # Attention for d in range(config.num_hidden_layers): - state_dict.pop(f"h.{d}.attn.bias") # We don't store this bias + state_dict.pop(f"h.{d}.attn.bias", None) # We don't store this bias Wqkv = state_dict.pop(f"h.{d}.attn.c_attn.weight") state_dict[f"transformer.layers.{d}.mixer.Wqkv.weight"] = Wqkv.t() Wout = state_dict.pop(f"h.{d}.attn.c_proj.weight") diff --git a/hopper/__init__.py b/hopper/__init__.py new file mode 100644 index 000000000..2e33087c5 --- /dev/null +++ b/hopper/__init__.py @@ -0,0 +1 @@ +__version__ = "3.0.0.b1" diff --git a/hopper/block_info.h b/hopper/block_info.h new file mode 100644 index 000000000..3a23a1e1f --- /dev/null +++ b/hopper/block_info.h @@ -0,0 +1,46 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockInfo { + + template + __device__ BlockInfo(const Params ¶ms, const int bidb) + : sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]) + , sum_s_k(!Varlen || params.cu_seqlens_k == nullptr || !params.is_seqlens_k_cumulative ? -1 : params.cu_seqlens_k[bidb]) + , actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : params.cu_seqlens_q[bidb + 1] - sum_s_q) + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + , seqlen_k_cache(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : (params.is_seqlens_k_cumulative ? params.cu_seqlens_k[bidb + 1] - sum_s_k : params.cu_seqlens_k[bidb])) + , actual_seqlen_k(params.seqused_k ? params.seqused_k[bidb] : seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)) + { + } + + template + __forceinline__ __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride; + } + + template + __forceinline__ __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const { + return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride; + } + + const int sum_s_q; + const int sum_s_k; + const int actual_seqlen_q; + // We have to have seqlen_k_cache declared before actual_seqlen_k, otherwise actual_seqlen_k is set to 0. 
+ const int seqlen_k_cache; + const int actual_seqlen_k; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/hopper/epilogue_fwd_sm90_tma.hpp b/hopper/epilogue_fwd_sm90_tma.hpp new file mode 100644 index 000000000..2d5c33eb4 --- /dev/null +++ b/hopper/epilogue_fwd_sm90_tma.hpp @@ -0,0 +1,215 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +using namespace cute; + +// template +template +struct CollectiveEpilogueFwd { + + using Element = typename Ktraits::Element; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + static constexpr int kHeadDim = Ktraits::kHeadDim; + // using Element = Element_; + // static constexpr int kBlockM = kBlockM_; + // static constexpr int kBlockN = kBlockN_; + // static constexpr int kHeadDim = kHeadDim_; + using TileShape_MNK = Shape, Int, Int>; + + // static constexpr int kNWarps = kNWarps_; + static constexpr int kNWarps = Ktraits::kNWarps; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr bool Is_WS = kNWarps >= 12; + + static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup; + static constexpr int NumMmaThreads = kNThreads - NumCopyThreads; + + using GmemTiledCopyOTMA = cute::SM90_TMA_STORE; + + // These are for storing the output tensor without TMA (e.g., for setting output to zero) + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + static constexpr int kGmemThreadsPerRow = kHeadDim / kGmemElemsPerLoad; + static_assert(NumMmaThreads % kGmemThreadsPerRow == 0, "NumMmaThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopyO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>>{})); // Val layout, 8 or 16 vals per store + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + + using SmemCopyAtomO = Copy_Atom; + using SharedStorage = cute::array_aligned>; + + using ShapeO = cute::Shape; // (seqlen_q, d, head, batch) + using StrideO = cute::Stride; + using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen_q, head, batch) + + using TMA_O = decltype(make_tma_copy( + GmemTiledCopyOTMA{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideO{}, int32_t(0)), StrideO{}), + SmemLayoutO{}, + select<0, 2>(TileShape_MNK{}), + _1{})); // no mcast for O + + // Host side kernel arguments + struct Arguments { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + }; + + // Device side kernel params + struct Params { + Element* ptr_O; + ShapeO const shape_O; + StrideO const stride_O; + float* ptr_LSE; + StrideLSE const stride_LSE; + TMA_O tma_store_O; 
+ }; + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mO = make_tensor(make_gmem_ptr(args.ptr_O), args.shape_O, args.stride_O); + TMA_O tma_store_O = make_tma_copy( + GmemTiledCopyOTMA{}, + mO, + SmemLayoutO{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for O + return {args.ptr_O, args.shape_O, args.stride_O, args.ptr_LSE, args.stride_LSE, tma_store_O}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& epilogue_params) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_O.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + store(Params const& epilogue_params, + FrgTensorO const& tOrO, + FrgTensorLSE const& lse, + SharedStorage& shared_storage, + TiledMma tiled_mma, + int thread_idx, + cute::tuple const& block_coord + ) { + + auto [m_block, bidh, bidb] = block_coord; + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto smem_tiled_copy_O = make_tiled_copy_C(SmemCopyAtomO{}, tiled_mma); + auto smem_thr_copy_O = smem_tiled_copy_O.get_thread_slice(thread_idx); + + Tensor tOrO_out = flash::convert_type(tOrO); + Tensor taccOrO = smem_thr_copy_O.retile_S(tOrO_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccOsO = smem_thr_copy_O.partition_D(sO); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Make sure all WGs have finished reading V + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::ValueEmpty) /*id*/); + cute::copy(smem_tiled_copy_O, taccOrO, taccOsO); + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + + Tensor mO = epilogue_params.tma_store_O.get_tma_tensor(epilogue_params.shape_O); + Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_O = epilogue_params.tma_store_O.get_slice(_0{}); + Tensor tOgO = block_tma_O.partition_D(gO); // (TMA, TMA_M, TMA_K) + Tensor tOsO = block_tma_O.partition_S(sO); // (TMA, TMA_M, TMA_K) + + auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); + Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape>{}, make_coord(m_block)); + + Tensor caccO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor taccOcO = thread_mma.partition_C(caccO); // (MMA,MMA_M,MMA_K) + static_assert(decltype(size<0, 0>(taccOcO))::value == 2); + static_assert(decltype(size<0, 1>(taccOcO))::value == 2); + // taccOcO has shape ((2, 2, V), MMA_M, MMA_K), we only take only the row indices. 
+ Tensor taccOcO_row = taccOcO(make_coord(_0{}, _, _0{}), _, _0{}); + CUTE_STATIC_ASSERT_V(size(lse) == size(taccOcO_row)); // MMA_M + if (get<1>(taccOcO_row(_0{})) == 0) { + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccOcO_row(mi)); + if (row < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(row) = lse(mi); } + } + } + + if (cutlass::canonical_warp_idx_sync() == kNWarps - 1) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + int const lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + cute::copy(epilogue_params.tma_store_O, tOsO, tOgO); + tma_store_arrive(); + } + } + } + + CUTLASS_DEVICE void + store_tail() { + tma_store_wait<0>(); + } + + // Write 0 to output and -inf to LSE + CUTLASS_DEVICE void + store_zero( + Params const& epilogue_params, + int thread_idx, + cute::tuple const& block_coord + ) { + auto [m_block, bidh, bidb] = block_coord; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_params.ptr_O), epilogue_params.shape_O, epilogue_params.stride_O); + Tensor gO = local_tile(mO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto shape_LSE = select<0, 2, 3>(epilogue_params.shape_O); + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_params.ptr_LSE), shape_LSE, epilogue_params.stride_LSE); + Tensor gLSE = local_tile(mLSE(_, bidh, bidb), Shape>{}, make_coord(m_block)); + + GmemTiledCopyO gmem_tiled_copy_O; + auto gmem_thr_copy_O = gmem_tiled_copy_O.get_thread_slice(thread_idx); + Tensor tOgO = gmem_thr_copy_O.partition_D(gO); + Tensor tOrO = make_fragment_like(tOgO); + clear(tOrO); + // Construct identity layout for sO + Tensor cO = cute::make_identity_tensor(select<0, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) + // Repeat the partitioning with identity layouts + Tensor tOcO = gmem_thr_copy_O.partition_D(cO); + Tensor tOpO = make_tensor(make_shape(size<2>(tOgO))); + #pragma unroll + for (int k = 0; k < size(tOpO); ++k) { tOpO(k) = get<1>(tOcO(_0{}, _0{}, k)) < get<1>(epilogue_params.shape_O); } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_O, tOrO, tOgO, tOcO, tOpO, get<0>(epilogue_params.shape_O) - m_block * kBlockM + ); + static_assert(kBlockM <= NumMmaThreads); + if (thread_idx < get<0>(shape_LSE) - m_block * kBlockM) { gLSE(thread_idx) = INFINITY; } + } + +}; + +} // namespace flash diff --git a/hopper/flash.h b/hopper/flash.h new file mode 100644 index 000000000..0418f0be8 --- /dev/null +++ b/hopper/flash.h @@ -0,0 +1,192 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include +#include + +#ifdef OLD_GENERATOR_PATH +#include +#else +#include +#endif + +#include // For at::cuda::philox::unpack + +#include "cutlass/fast_math.h" // For cutlass::FastDivmod + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Qkv_params { + using index_t = int64_t; + // The QKV matrices. + void *__restrict__ q_ptr; + void *__restrict__ k_ptr; + void *__restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. 
+ index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + + // The number of heads. + int h, h_k; + // In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be + // different from nheads (query). + int h_h_k_ratio; // precompute h / h_k, +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_fwd_params : public Qkv_params { + + // The O matrix (output). + void * __restrict__ o_ptr; + void * __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the P matrix. + void * __restrict__ p_ptr; + + // The pointer to the softmax sum. + void * __restrict__ softmax_lse_ptr; + void * __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + + // The scaling factors for the kernel. + float scale_softmax; + float scale_softmax_log2; + uint32_t scale_softmax_log2_half2; + + // array of length b+1 holding starting offset of each sequence. + int * __restrict__ cu_seqlens_q; + int * __restrict__ cu_seqlens_k; + + // If provided, the actual length of each k sequence. + int * __restrict__ seqused_k; + + int *__restrict__ blockmask; + + // The K_new and V_new matrices. + void * __restrict__ knew_ptr; + void * __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + // The cos and sin matrices for rotary embedding. + void * __restrict__ rotary_cos_ptr; + void * __restrict__ rotary_sin_ptr; + + // The indices to index into the KV cache. + int * __restrict__ cache_batch_idx; + + // Paged KV cache + int * __restrict__ block_table; + index_t block_table_batch_stride; + int page_block_size; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + float scale_softmax_rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Random state. + at::PhiloxCudaState philox_args; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t * rng_state; + + bool is_bf16; + bool is_e4m3; + bool is_causal; + + // If is_seqlens_k_cumulative, then seqlen_k is cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb]. + // Otherwise it's cu_seqlens_k[bidb], i.e., we use cu_seqlens_k to store the sequence lengths of K. + bool is_seqlens_k_cumulative; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + + void * __restrict__ alibi_slopes_ptr; + index_t alibi_slopes_batch_stride; + + int * __restrict__ tile_count_semaphore; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Flash_bwd_params : public Flash_fwd_params { + + // The dO and dQKV matrices. 
+ void *__restrict__ do_ptr; + void *__restrict__ dq_ptr; + void *__restrict__ dk_ptr; + void *__restrict__ dv_ptr; + + // To accumulate dQ + void *__restrict__ dq_accum_ptr; + void *__restrict__ dk_accum_ptr; + void *__restrict__ dv_accum_ptr; + + // // To accumulate dK and dV in case we're splitting the bwd along seqlen_q + // dimension void *__restrict__ dk_accum_ptr; void *__restrict__ + // dv_accum_ptr; + + // The stride between rows of the dO, dQ, dK and dV matrices. + // TD [2022-04-16]: We're using 32-bit indexing to save registers. + // The code probably won't work for arrays larger than 2GB. + index_t do_batch_stride; + index_t do_row_stride; + index_t do_head_stride; + index_t dq_batch_stride; + index_t dk_batch_stride; + index_t dv_batch_stride; + index_t dq_row_stride; + index_t dk_row_stride; + index_t dv_row_stride; + index_t dq_head_stride; + index_t dk_head_stride; + index_t dv_head_stride; + + // The pointer to the softmax d sum. + void *__restrict__ dsoftmax_sum; + + int *__restrict__ dq_semaphore; + + bool deterministic; + index_t dq_accum_split_stride; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream); +template void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream); diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp new file mode 100644 index 000000000..f21d2d12e --- /dev/null +++ b/hopper/flash_api.cpp @@ -0,0 +1,580 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +// Include these 2 headers instead of torch/extension.h since we don't need all of the torch headers. +#include +#include +#include +#include + +#include + +#include "flash.h" +#include "static_switch.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.is_cuda(), #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + + +void set_params_fprop(Flash_fwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + at::Tensor out, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *seqused_k, + void *p_d, + void *softmax_lse_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + bool seqlenq_ngroups_swapped=false) { + + // Reset the parameters + params = {}; + + params.is_bf16 = q.dtype() == torch::kBFloat16; + params.is_e4m3 = q.dtype() == torch::kFloat8_e4m3fn; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. 
+ params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.q_batch_stride = q.stride(0); + params.k_batch_stride = k.stride(0); + params.v_batch_stride = v.stride(0); + params.o_batch_stride = out.stride(0); + if (seqlenq_ngroups_swapped) { + params.q_batch_stride *= seqlen_q; + params.o_batch_stride *= seqlen_q; + } + } + + params.cu_seqlens_q = static_cast(cu_seqlens_q_d); + params.cu_seqlens_k = static_cast(cu_seqlens_k_d); + params.seqused_k = static_cast(seqused_k); + + // P = softmax(QK^T) + params.p_ptr = p_d; + + // Softmax sum + params.softmax_lse_ptr = softmax_lse_d; + + // Set the dimensions. + params.b = b; + params.h = h; + params.h_k = h_k; + params.h_h_k_ratio = h / h_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = d; + params.d_rounded = d_rounded; + + // Set the different scale values. + params.scale_softmax = softmax_scale; + params.scale_softmax_log2 = softmax_scale * M_LOG2E; + __half scale_softmax_log2_half = __float2half(params.scale_softmax_log2); + __half2 scale_softmax_log2_half2 = __half2(scale_softmax_log2_half, scale_softmax_log2_half); + params.scale_softmax_log2_half2 = reinterpret_cast(scale_softmax_log2_half2); + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f - p_dropout; + // Convert p from float to int so we don't have to convert the random uint to float to compare. + // [Minor] We want to round down since when we do the comparison we use <= instead of < + // params.p_dropout_in_uint = uint32_t(std::floor(params.p_dropout * 4294967295.0)); + // params.p_dropout_in_uint16_t = uint16_t(std::floor(params.p_dropout * 65535.0)); + params.p_dropout_in_uint8_t = uint8_t(std::floor(params.p_dropout * 255.0)); + params.rp_dropout = 1.f / params.p_dropout; + params.scale_softmax_rp_dropout = params.rp_dropout * params.scale_softmax; + TORCH_CHECK(p_dropout < 1.f); + #ifdef FLASHATTENTION_DISABLE_DROPOUT + TORCH_CHECK(p_dropout == 0.0f, "This flash attention build does not support dropout."); + #endif + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
+ params.is_causal = window_size_left < 0 && window_size_right == 0; + + if (window_size_left < 0 && window_size_right >= 0) { window_size_left = seqlen_k; } + if (window_size_left >= 0 && window_size_right < 0) { window_size_right = seqlen_k; } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + + #ifdef FLASHATTENTION_DISABLE_LOCAL + TORCH_CHECK(params.is_causal || (window_size_left < 0 && window_size_right < 0), + "This flash attention build does not support local attention."); + #endif + + params.is_seqlens_k_cumulative = true; + + #ifdef FLASHATTENTION_DISABLE_UNEVEN_K + TORCH_CHECK(d == d_rounded, "This flash attention build does not support headdim not being a multiple of 32."); + #endif +} + +void set_params_dgrad(Flash_bwd_params ¶ms, + // sizes + const size_t b, + const size_t seqlen_q, + const size_t seqlen_k, + const size_t seqlen_q_rounded, + const size_t seqlen_k_rounded, + const size_t h, + const size_t h_k, + const size_t d, + const size_t d_rounded, + // device pointers + const at::Tensor q, + const at::Tensor k, + const at::Tensor v, + const at::Tensor out, + const at::Tensor dout, + at::Tensor dq, + at::Tensor dk, + at::Tensor dv, + void *cu_seqlens_q_d, + void *cu_seqlens_k_d, + void *dq_accum_d, + void *dk_accum_d, + void *dv_accum_d, + void *softmax_lse_d, + void *dsoftmax_sum_d, + float p_dropout, + float softmax_scale, + int window_size_left, + int window_size_right, + bool deterministic) { + + set_params_fprop(params, + b, seqlen_q, seqlen_k, seqlen_q_rounded, seqlen_k_rounded, h, h_k, d, d_rounded, + q, k, v, out, + cu_seqlens_q_d, + cu_seqlens_k_d, + nullptr, + nullptr, + softmax_lse_d, + p_dropout, + softmax_scale, + window_size_left, + window_size_right); + + // Set the pointers and strides. 
+ params.do_ptr = dout.data_ptr(); + params.do_row_stride = dout.stride(-3); + params.do_head_stride = dout.stride(-2); + params.dq_ptr = dq.data_ptr(); + params.dk_ptr = dk.data_ptr(); + params.dv_ptr = dv.data_ptr(); + params.dq_row_stride = dq.stride(-3); + params.dk_row_stride = dk.stride(-3); + params.dv_row_stride = dv.stride(-3); + params.dq_head_stride = dq.stride(-2); + params.dk_head_stride = dk.stride(-2); + params.dv_head_stride = dv.stride(-2); + + if (cu_seqlens_q_d == nullptr) { + params.do_batch_stride = dout.stride(0); + params.dq_batch_stride = dq.stride(0); + params.dk_batch_stride = dk.stride(0); + params.dv_batch_stride = dv.stride(0); + } + + params.dq_accum_ptr = dq_accum_d; + params.dk_accum_ptr = dk_accum_d; + params.dv_accum_ptr = dv_accum_d; + + // Softmax sum + params.dsoftmax_sum = dsoftmax_sum_d; + + params.deterministic = deterministic; +} + +void run_mha_fwd(Flash_fwd_params ¶ms, cudaStream_t stream, bool force_split_kernel=false) { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_fwd_(params, stream); + // }); + if (!params.is_e4m3) { + if (params.is_bf16) { + if (params.d == 64) { + run_mha_fwd_(params, stream); + } else if (params.d == 128) { + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_(params, stream); + } + } else { + if (params.d == 64) { + run_mha_fwd_(params, stream); + } else if (params.d == 128) { + run_mha_fwd_(params, stream); + } else { + run_mha_fwd_(params, stream); + } + } + } else { + // run_mha_fwd_(params, stream); + } +} + +std::vector +mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &out_, // batch_size x seqlen_q x num_heads x head_size + const float softmax_scale, + bool is_causal) { + + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm90 = dprops->major == 9 && dprops->minor == 0; + TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer."); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16, + "FlashAttention only support fp16 and bf16 data type for now"); + // TODO: will add e4m3 later + // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn, + // "FlashAttention only support fp16 and bf16 data type"); + // "FlashAttention only support fp16 and fp8 (e4m3) data type for now"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + int seqlen_q = sizes[1]; + int num_heads = sizes[2]; + const int head_size_og = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 
256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og); + + at::Tensor q_padded, k_padded, v_padded; + if (head_size_og % 8 != 0) { + q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + q_padded = q; + k_padded = k; + v_padded = v; + } + + at::Tensor out; + if (out_.has_value()) { + out = out_.value(); + TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs"); + CHECK_DEVICE(out); + TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension"); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og); + if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); } + } else { + out = torch::empty_like(q_padded); + } + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size = round_multiple(head_size_og, 8); + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + + auto softmax_lse = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat)); + at::Tensor p; + + Flash_fwd_params params; + set_params_fprop(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, num_heads_k, + head_size, head_size_rounded, + q_padded, k_padded, v_padded, out, + /*cu_seqlens_q_d=*/nullptr, + /*cu_seqlens_k_d=*/nullptr, + /*seqused_k=*/nullptr, + nullptr, + softmax_lse.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + /*window_size_left=*/-1, + /*window_size_right=*/is_causal ? 0 : -1); + + auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32)); + params.tile_count_semaphore = tile_count_semaphore.data_ptr(); + + if (seqlen_k > 0) { + auto stream = at::cuda::getCurrentCUDAStream().stream(); + run_mha_fwd(params, stream); + } else { + // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. 
+ out.zero_(); + softmax_lse.fill_(std::numeric_limits::infinity()); + } + + at::Tensor out_padded = out; + if (head_size_og % 8 != 0) { + out = out.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + if (out_.has_value()) { out_.value().copy_(out); } + } + + return {out, q_padded, k_padded, v_padded, out_padded, softmax_lse, p}; +} + +void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + // FP16_SWITCH(!params.is_bf16, [&] { + // HEADDIM_SWITCH(params.d, [&] { + // run_mha_bwd_(params, stream); + // }); + // }); + if (params.d == 64) { + run_mha_bwd_(params, stream); + } else if (params.d == 128) { + run_mha_bwd_(params, stream); + } else { + run_mha_bwd_(params, stream); + } +} + +std::vector +mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_size_og + const at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &k, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &v, // batch_size x seqlen_k x num_heads_k x head_size + const at::Tensor &out, // batch_size x seqlen_q x num_heads x head_size + const at::Tensor &softmax_lse, // b x h x seqlen_q + c10::optional &dq_, // batch_size x seqlen_q x num_heads x head_size + c10::optional &dk_, // batch_size x seqlen_k x num_heads_k x head_size + c10::optional &dv_, // batch_size x seqlen_k x num_heads_k x head_size + const float softmax_scale, + const bool is_causal) { + + #ifdef FLASHATTENTION_DISABLE_BACKWARD + TORCH_CHECK(false, "This flash attention build does not support backward."); + #endif + auto dprops = at::cuda::getCurrentDeviceProperties(); + bool is_sm9x = dprops->major == 9 && dprops->minor >= 0; + TORCH_CHECK(is_sm9x, "FlashAttentionHopper only supports Hopper GPUs or newer."); + + auto stream = at::cuda::getCurrentCUDAStream().stream(); + + auto q_dtype = q.dtype(); + TORCH_CHECK(q_dtype == torch::kFloat16, + // "FlashAttention only support fp16 and bf16 data type"); + "FlashAttention only support fp16 data type for now"); + TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype"); + TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype"); + TORCH_CHECK(out.dtype() == q_dtype, "query and out must have the same dtype"); + TORCH_CHECK(dout.dtype() == q_dtype, "query and dout must have the same dtype"); + + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); + CHECK_DEVICE(out); CHECK_DEVICE(dout); CHECK_DEVICE(softmax_lse); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(out.stride(-1) == 1, "out tensor must have contiguous last dimension"); + TORCH_CHECK(dout.stride(-1) == 1, "dout tensor must have contiguous last dimension"); + + TORCH_CHECK(q.is_contiguous(), "Input tensor must be contiguous"); + TORCH_CHECK(k.is_contiguous(), "Input tensor must be contiguous"); + TORCH_CHECK(v.is_contiguous(), "Input tensor must be contiguous"); + + const auto sizes = q.sizes(); + + const int batch_size = sizes[0]; + const int seqlen_q = sizes[1]; + const int num_heads = sizes[2]; + const int head_size_og = dout.size(3); + const int head_size = sizes[3]; + const int seqlen_k = k.size(1); + const int num_heads_k = k.size(2); + TORCH_CHECK(batch_size > 0, "batch size must be positive"); + TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8"); + 
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256"); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + TORCH_CHECK(head_size_og == 64 || head_size_og == 128, "Only support head size 64 and 128 for now"); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + const int head_size_rounded = round_multiple(head_size, 32); + const int seqlen_q_rounded = round_multiple(seqlen_q, 128); + const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + + TORCH_CHECK(head_size == round_multiple(head_size_og, 8), "head_size must be head_size_og rounded to a multiple of 8"); + + CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size); + CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size); + CHECK_SHAPE(dout, batch_size, seqlen_q, num_heads, head_size_og); + + at::Tensor dq, dk, dv; + if (dq_.has_value()) { + dq = dq_.value(); + TORCH_CHECK(dq.dtype() == q_dtype, "dq must have the same dtype as q"); + CHECK_DEVICE(dq); + TORCH_CHECK(dq.stride(-1) == 1, "dq must have contiguous last dimension"); + CHECK_SHAPE(dq, batch_size, seqlen_q, num_heads, head_size); + } else { + dq = torch::empty_like(q); + } + if (dk_.has_value()) { + dk = dk_.value(); + TORCH_CHECK(dk.dtype() == q_dtype, "dk must have the same dtype as q"); + CHECK_DEVICE(dk); + TORCH_CHECK(dk.stride(-1) == 1, "dk must have contiguous last dimension"); + CHECK_SHAPE(dk, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dk = torch::empty_like(k); + } + if (dv_.has_value()) { + dv = dv_.value(); + TORCH_CHECK(dv.dtype() == q_dtype, "dv must have the same dtype as q"); + CHECK_DEVICE(dv); + TORCH_CHECK(dv.stride(-1) == 1, "dv must have contiguous last dimension"); + CHECK_SHAPE(dv, batch_size, seqlen_k, num_heads_k, head_size); + } else { + dv = torch::empty_like(v); + } + + at::Tensor dout_padded; + if (head_size_og % 8 != 0) { + dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8})); + } else { + dout_padded = dout; + } + + // bool loop = seqlen_k > blocksize_c; + // TODO: change later, for now set to true for simplicity + bool loop = true; + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + auto opts = q.options(); + auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat)); + at::Tensor dq_accum; + at::Tensor dk_accum, dv_accum; + if (loop) { + dq_accum = torch::empty({batch_size, seqlen_q_rounded, num_heads, head_size_rounded}, opts.dtype(at::kFloat)); + // dk_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat)); + // dv_accum = torch::zeros({batch_size, seqlen_k_rounded, num_heads_k, head_size_rounded}, opts.dtype(at::kFloat)); + } + + at::Tensor dk_expanded, dv_expanded; + if (num_heads_k != num_heads) { // MQA / GQA + dk_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + dv_expanded = torch::empty({batch_size, seqlen_k, num_heads, head_size}, opts); + } else { + dk_expanded = dk; + dv_expanded = dv; + } + + Flash_bwd_params params; + + set_params_dgrad(params, + batch_size, + seqlen_q, seqlen_k, + seqlen_q_rounded, seqlen_k_rounded, + num_heads, 
num_heads_k, + head_size, head_size_rounded, + q, k, v, out, + dout_padded, dq, dk_expanded, dv_expanded, + nullptr, + nullptr, + loop ? dq_accum.data_ptr() : nullptr, + // loop ? dk_accum.data_ptr() : nullptr, + // loop ? dv_accum.data_ptr() : nullptr, + nullptr, + nullptr, + softmax_lse.data_ptr(), + softmax_d.data_ptr(), + /*p_dropout=*/0.f, + softmax_scale, + /*window_size_left=*/-1, + /*window_size_right=*/-1, + /*deterministic=*/false); + + at::Tensor dq_semaphore = torch::zeros({(seqlen_q + 64 - 1) / 64, batch_size, num_heads}, opts.dtype(torch::kInt32)); + params.dq_semaphore = dq_semaphore.data_ptr(); + // printf("dq_semaphore: %p, [%d, %d, %d]\n", params.dq_semaphore, (seqlen_q + 64 - 1) / 64, batch_size, num_heads); + + auto launch = &run_mha_bwd; + + if (seqlen_q > 0) { + launch(params, stream); + } else { + // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. + dk_expanded.zero_(); + dv_expanded.zero_(); + softmax_d.zero_(); + } + + if (head_size_og % 8 != 0) { + dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)}); + } + + return { dq, dk, dv, softmax_d }; +} + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "FlashAttention"; + m.def("fwd", &mha_fwd, "Forward pass"); + m.def("bwd", &mha_bwd, "Backward pass"); +} diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py new file mode 100644 index 000000000..c09342826 --- /dev/null +++ b/hopper/flash_attn_interface.py @@ -0,0 +1,169 @@ +# Copyright (c) 2023, Tri Dao. + +from typing import Optional, Union + +import torch +import torch.nn as nn + +# isort: off +# We need to import the CUDA kernels after importing torch +import flashattn_hopper_cuda + +# isort: on + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +def _flash_attn_forward(q, k, v, softmax_scale, causal): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd( + q, + k, + v, + None, + softmax_scale, + causal, + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask + + +def _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + causal +): + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + softmax_scale, + causal, + ) + return dq, dk, dv, softmax_d + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + softmax_scale, + causal, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_forward( + q, + k, + v, + softmax_scale, + causal + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse) + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return out, softmax_lse + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.softmax_scale, + ctx.causal, + ) + dq = dq[..., : 
dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None + + +def flash_attn_func( + q, + k, + v, + softmax_scale=None, + causal=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply( + q, + k, + v, + softmax_scale, + causal, + ) diff --git a/hopper/flash_bwd_hdim128_fp16_sm90.cu b/hopper/flash_bwd_hdim128_fp16_sm90.cu new file mode 100644 index 000000000..01a1d469f --- /dev/null +++ b/hopper/flash_bwd_hdim128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
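Reviewer note (not part of the patch): before the per-head-dimension backward launchers below, a minimal usage sketch of the FlashAttention-3 interface defined above in `hopper/flash_attn_interface.py`. It assumes the `flashattn_hopper_cuda` extension has been built, an H100/H800 GPU is available, and `PYTHONPATH` includes the `hopper` directory; shapes are arbitrary, and fp16 with head dimension 128 is chosen so both the forward and backward paths in this patch apply.

```python
# Illustrative sketch only: exercising the hopper (FA3) interface added above.
import torch
from flash_attn_interface import flash_attn_func  # hopper/flash_attn_interface.py

q, k, v = [
    torch.randn(2, 2048, 16, 128, dtype=torch.float16, device="cuda", requires_grad=True)
    for _ in range(3)
]
# Unlike the FA2 interface, the forward here returns both the output and softmax_lse.
out, softmax_lse = flash_attn_func(q, k, v, causal=True)
out.sum().backward()  # dispatches to flashattn_hopper_cuda.bwd
```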
+ +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim128(params, stream); +} diff --git a/hopper/flash_bwd_hdim256_fp16_sm90.cu b/hopper/flash_bwd_hdim256_fp16_sm90.cu new file mode 100644 index 000000000..ee139bd2e --- /dev/null +++ b/hopper/flash_bwd_hdim256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim256(params, stream); +} diff --git a/hopper/flash_bwd_hdim64_fp16_sm90.cu b/hopper/flash_bwd_hdim64_fp16_sm90.cu new file mode 100644 index 000000000..f486e870f --- /dev/null +++ b/hopper/flash_bwd_hdim64_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_bwd_launch_template.h" + +template<> +void run_mha_bwd_(Flash_bwd_params ¶ms, cudaStream_t stream) { + run_mha_bwd_hdim64(params, stream); +} diff --git a/hopper/flash_bwd_kernel.h b/hopper/flash_bwd_kernel.h new file mode 100644 index 000000000..b510ecc5f --- /dev/null +++ b/hopper/flash_bwd_kernel.h @@ -0,0 +1,2042 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "flash.h" +#include "utils.h" +#include "softmax.h" + +namespace flash { + +using namespace cute; + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_dqkv(CUTE_GRID_CONSTANT Flash_bwd_params const params, + CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q, + CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO, + CUTE_GRID_CONSTANT TiledCopyK const tma_load_K, + CUTE_GRID_CONSTANT TiledCopyV const tma_load_V, + CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK, + CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) { + + using Element = typename Ktraits::Element; + using ElementAccum = typename Ktraits::ElementAccum; + using SoftType = ElementAccum; + using index_t = typename Ktraits::index_t; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int kNThreads = Ktraits::kNThreads; + // static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{}); + static constexpr int NumMmaThreads = Ktraits::kNThreads; + static constexpr int kBlockM = Ktraits::kBlockM; + // static constexpr int kBlockN = Ktraits::kBlockN; + // constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr int kStages = Ktraits::kStages; + + static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB; + static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB; + static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB; + + static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS; + if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); } + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char 
shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const n_block = blockIdx.x; + int const bidb = blockIdx.z; // The block index for the batch. + int const bidh = blockIdx.y; // The block index for the head. + + int lane_predicate = cute::elect_one_sync(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor()); + } + + Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{})); + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr)), + make_shape(params.seqlen_q, params.d, params.h, params.b), + make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded)); + + + Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + // if (cute::thread0()) { print(tma_load_K); printf("\n"); } + // if (cute::thread0()) { print(mK); printf("\n"); print(gK); printf("\n"); } + + typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + + // Construct SMEM tensors. 
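+    // Note: the transposed tensors below (sQt, sdOt, sKt, sPt, sdSt) are just alternative layout views
+    // over the same SMEM buffers as sQ, sdO, sK, sP, sdS; they serve as WGMMA operands when the dV, dK
+    // and dQ GEMMs need the transposed operand. P and dS get their own staging buffers so that the
+    // probabilities and dS values can be re-read from SMEM by the subsequent GEMMs.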
+ Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{}); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{}); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{}); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{}); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y); + auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y); + auto block_tma_K = tma_load_K.get_slice(_0{}); + auto block_tma_V = tma_load_V.get_slice(_0{}); + + Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k) + Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE) + Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k) + Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE) + Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K) + Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K) + Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K) + Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K) + // if (cute::thread0()) { print(tQgQ); printf("\n"); print(tQsQ); printf("\n"); } + // if (cute::thread0()) { print(tKgK); printf("\n"); print(tKsK); printf("\n"); } + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr uint32_t TmaTransactionBytesQ = static_cast(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v / 8); + constexpr uint32_t TmaTransactionBytesdO = static_cast(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v / 8); + static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO); + constexpr uint32_t TmaTransactionBytesK = static_cast(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v / 8); + constexpr uint32_t TmaTransactionBytesV = static_cast(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v / 8); + static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + + // Obtain warp index + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + // int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = TmaTransactionBytesQ; + pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; 
+ pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_K.init(1 /*numThreads*/); + shared_storage.barrier_V.init(1 /*numThreads*/); + } + // cutlass::arch::fence_barrier_init(); + // We're counting on pipeline_q to call fence_barrier_init(); + MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{}); + MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{}); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read_q, smem_pipe_read_do; + PipelineState smem_pipe_release_q, smem_pipe_release_do; + PipelineState smem_pipe_write_q = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_do = cutlass::make_producer_start_state(); + + // Copy K tile and V tile from GMEM to SMEM. + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK); + copy(tma_load_K.with(reinterpret_cast(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK); + shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV); + copy(tma_load_V.with(reinterpret_cast(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV); + } + // if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads(); + + int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1; + + uint16_t mcast_mask_qdo = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{})); + } + } + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0 && lane_predicate) { + // Issue the prologue loads + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kStages && stage <= m_block; ++stage) { + pipeline_q.producer_acquire(smem_pipe_write_q); + copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - stage), tQsQ(_, _, _, stage)); + ++smem_pipe_write_q; + pipeline_do.producer_acquire(smem_pipe_write_do); + copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - stage), tdOsdO(_, _, _, stage)); + ++smem_pipe_write_do; + } + } + + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + // Initialize matmul objects. 
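+    // Three tiled MMAs are used per m_block iteration:
+    //   TiledMmaSdP  computes S = Q K^T and dP = dO V^T,
+    //   TiledMmadKV  accumulates dV += P^T dO and dK += dS^T Q in registers across the whole loop,
+    //   TiledMmadQ   computes the per-tile dQ = dS K contribution.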
+ typename Ktraits::TiledMmaSdP tiledMmaSdP; + auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x); + typename Ktraits::TiledMmadKV tiledMmadKV; + auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x); + typename Ktraits::TiledMmadQ tiledMmadQ; + auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x); + + // Allocate accumulator + Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); + + auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x); + + if constexpr (!SdP_swapAB) { + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Allocate "fragments/descriptors" + Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ); + Tensor tSrK = threadMmaSdP.partition_fragment_B(sK); + Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO); + Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV); + + Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0, 0>(taccScS))::value == 2); + static_assert(decltype(size<0, 1>(taccScS))::value == 2); + // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{}); + Tensor lse = make_tensor(Shape>{}); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + } + // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. 
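+    // Main loop structure (m_block iterated from last to first): recompute the probabilities P for the
+    // current tile from S and the saved LSE, form dS = P * (dP - dP_sum), then issue the dV, dQ and dK
+    // WGMMAs; LSE and dP_sum for the next m_block are loaded inside the iteration. dK and dV stay in
+    // registers until the epilogue, while each tile's dQ contribution is atomicAdd-ed into the fp32
+    // dQ accumulator in global memory.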
+ + clear(tdKrdK); + clear(tdVrdV); + + shared_storage.barrier_K.wait(0); + shared_storage.barrier_V.wait(0); + __syncthreads(); + + // #pragma unroll 2 + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block >= 0; --m_block) { + Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); + pipeline_q.consumer_wait(smem_pipe_read_q); + __syncwarp(); + flash::gemm(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS); + Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); + pipeline_do.consumer_wait(smem_pipe_read_do); + __syncwarp(); + flash::gemm(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP); + + warpgroup_wait<1>(); + // Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + // if (cute::thread0()) { print_tensor(scores); printf("\n"); } + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(tSrS); + Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + int const warp_group_idx = cutlass::canonical_warp_group_idx(); + cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/); + + warpgroup_wait<0>(); + // Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + // if (cute::thread0()) { print_tensor(dS); printf("\n"); } + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } + } + Tensor rdS = flash::convert_type(tdPrdP); + + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + // cutlass::arch::NamedBarrier::arrive(kNThreads, 1 /*id*/); + cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/); + // if (cute::thread0()) { print_tensor(dS); printf("\n"); } + + if (m_block > 0) { + gLSE.data() = gLSE.data() + (-int(kBlockM)); + gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = gLSE(row); + dP_sum(mi) = gdPsum(row); + } + } + + Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + if constexpr (Mma_dQ_is_RS) { + static_assert(!dQ_swapAB); + Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + // if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); } + } + + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + // if (cute::thread0()) { print_tensor(sK); printf("\n"); } + // if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads(); + + // __syncthreads(); // Without this I'm getting race condition, I thought the barrier would be enough + // SMEM fence to make sure sP is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/); + if 
constexpr (!dKV_swapAB) { + Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt); + Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); + flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV); + } else { + Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt); + Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt); + flash::gemm(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV); + } + ++smem_pipe_read_do; + + // warpgroup_wait<0>(); + // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); + // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); } + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/); + if constexpr (!Mma_dQ_is_RS) { + if constexpr (!dQ_swapAB) { + Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + } else { + Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); + } + } + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } + + if constexpr (!dKV_swapAB) { + Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt); + Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); + flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK); + } else { + Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt); + Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt); + flash::gemm(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK); + } + ++smem_pipe_read_q; + // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); + // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); } + + warpgroup_wait(); + // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } + Tensor tdQrdQ_atomic = recast(tdQrdQ); + Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + // for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); } + + warpgroup_wait<0>(); + + pipeline_do.consumer_release(smem_pipe_release_do); // release V + ++smem_pipe_release_do; + pipeline_q.consumer_release(smem_pipe_release_q); // release V + ++smem_pipe_release_q; + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0 && lane_predicate && m_block >= kStages) { + pipeline_q.producer_acquire(smem_pipe_write_q); + copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index())); + ++smem_pipe_write_q; + pipeline_do.producer_acquire(smem_pipe_write_do); + copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index())); + ++smem_pipe_write_do; + } + } + + } else { // SdP_swapAB + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // 
((Atom,AtomNum),PIPE_M,PIPE_N) + + // Allocate "fragments/descriptors" + Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ); + Tensor tSrK = threadMmaSdP.partition_fragment_A(sK); + Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO); + Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV); + + Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0, 0>(taccScS))::value == 2); + static_assert(decltype(size<0, 1>(taccScS))::value == 2); + // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _); + Tensor lse = make_tensor(Shape>{}); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<1>(taccScS_row(mi)); + lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + } + // cute::fill(lse, 1); + // cute::fill(dP_sum, 1); + // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. + + clear(tdKrdK); + clear(tdVrdV); + + shared_storage.barrier_K.wait(0); + shared_storage.barrier_V.wait(0); + + // #pragma unroll 2 + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block >= 0; --m_block) { + Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); + pipeline_q.consumer_wait(smem_pipe_read_q); + flash::gemm(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS); + Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); + pipeline_do.consumer_wait(smem_pipe_read_do); + flash::gemm(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP); + + warpgroup_wait<1>(); + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + // if (cute::thread0()) { print_tensor(scores); printf("\n"); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(tSrS); + + static_assert(!dKV_swapAB); + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); + flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV); + ++smem_pipe_read_do; + // warpgroup_wait<0>(); + // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); + // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); } + + warpgroup_wait<1>(); + // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } + } + // if (cute::thread0()) { print_tensor(dS); printf("\n"); } + Tensor rdS = 
flash::convert_type(tdPrdP); + + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + + if (m_block > 0) { + gLSE.data() = gLSE.data() + (-int(kBlockM)); + gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<1>(taccScS_row(mi)); + lse(mi) = gLSE(row); + dP_sum(mi) = gdPsum(row); + } + } + + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); + flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK); + ++smem_pipe_read_q; + // warpgroup_wait<0>(); + // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); + // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); } + + // SMEM fence to make sure sP is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + // cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/); + __syncthreads(); + static_assert(!Mma_dQ_is_RS); + Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + if constexpr (!dQ_swapAB) { + Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + } else { + Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); + } + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + + + warpgroup_wait<0>(); + // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } + Tensor tdQrdQ_atomic = recast(tdQrdQ); + Tensor tdQgdQaccum_atomic = recast(tdQgdQaccum(_, _, _, m_block)); + #pragma unroll + for (int i = 0; i < size(tdQrdQ_atomic); ++i) { atomicAdd(&tdQgdQaccum_atomic(i), tdQrdQ_atomic(i)); } + // for (int i = 0; i < size(tdQrdQ_atomic); ++i) { tdQgdQaccum_atomic(i) = tdQrdQ_atomic(i); } + + pipeline_do.consumer_release(smem_pipe_release_do); // release V + ++smem_pipe_release_do; + pipeline_q.consumer_release(smem_pipe_release_q); // release V + ++smem_pipe_release_q; + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0 && lane_predicate && m_block >= kStages) { + pipeline_q.producer_acquire(smem_pipe_write_q); + copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block - kStages), tQsQ(_, _, _, smem_pipe_write_q.index())); + ++smem_pipe_write_q; + pipeline_do.producer_acquire(smem_pipe_write_do); + copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block - kStages), tdOsdO(_, _, _, smem_pipe_write_do.index())); + ++smem_pipe_write_do; + } + } + } + + // Epilogue + + #pragma unroll + for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; } + + Tensor tdKrdK_out = convert_type(tdKrdK); + Tensor tdVrdV_out = convert_type(tdVrdV); + + Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{}); + Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{}); + Tensor sdKt = 
make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{}); + Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{}); + + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) + + __syncthreads(); + if constexpr (!dKV_swapAB) { + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + } else { + Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt); + } + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + + Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + auto block_tma_dK = tma_store_dK.get_slice(_0{}); + auto block_tma_dV = tma_store_dV.get_slice(_0{}); + Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) + Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) + Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K) + Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + + __syncthreads(); // ensure all threads have issued their async fence + + lane_predicate = cute::elect_one_sync(); + warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0 && lane_predicate) { + cute::copy(tma_store_dV, tdVsdV, tdVgdV); + cute::copy(tma_store_dK, tdKsdK, tdKgdK); + tma_store_arrive(); + } + tma_store_wait<0>(); + + // To make sure remote SMEM doesn't get destroyed + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive(); + cute::cluster_wait(); + } + +} + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_dqkv_seqqpar(CUTE_GRID_CONSTANT Flash_bwd_params const params, + CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q, + CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO, + CUTE_GRID_CONSTANT TiledCopyK const tma_load_K, + CUTE_GRID_CONSTANT TiledCopyV const tma_load_V, + CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ, + CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK, + CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV) { + + using Element = typename Ktraits::Element; + using ElementAccum = typename Ktraits::ElementAccum; + using SoftType = ElementAccum; + using index_t = typename Ktraits::index_t; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int kNThreads = Ktraits::kNThreads; + static 
constexpr int NumMmaThreads = Ktraits::kNThreads; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kBlockN = Ktraits::kBlockN; + // constexpr int kHeadDim = Ktraits::kHeadDim; + static constexpr int kStages = Ktraits::kStages; + + static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB; + static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB; + static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB; + + static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS; + if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); } + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const m_block = blockIdx.x; + int const bidb = blockIdx.z; // The block index for the batch. + int const bidh = blockIdx.y; // The block index for the head. + + int lane_predicate = cute::elect_one_sync(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor()); + } + + Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{})); + Tensor mdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr)), + make_shape(params.seqlen_k, params.d, params.h, params.b), + make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k)); + Tensor mdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr)), + make_shape(params.seqlen_k, params.d, params.h, params.b), + make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_k)); + + + Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gdKaccum = local_tile(mdKaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gdVaccum = local_tile(mdVaccum(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // 
(N, K, _) + + typename Ktraits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(threadIdx.x); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); + + // Construct SMEM tensors. + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{}); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{}); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{}); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{}); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_Q = tma_load_Q.get_slice(_0{}); + auto block_tma_dO = tma_load_dO.get_slice(_0{}); + auto block_tma_K = tma_load_K.get_slice(cluster_local_block_id.x); + auto block_tma_V = tma_load_V.get_slice(cluster_local_block_id.x); + + Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K) + Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K) + Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K) + Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K) + Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K, k) + Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K, PIPE) + Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K, k) + Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K, PIPE) + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr uint32_t TmaTransactionBytesQ = static_cast(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v / 8); + constexpr uint32_t TmaTransactionBytesdO = static_cast(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v / 8); + static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO); + constexpr uint32_t TmaTransactionBytesK = static_cast(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v / 8); + constexpr uint32_t TmaTransactionBytesV = static_cast(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v / 8); + static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + + // Obtain warp index + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + // int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + 
pipeline_params.transaction_bytes = TmaTransactionBytesK; + pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(1 /*numThreads*/); + shared_storage.barrier_dO.init(1 /*numThreads*/); + } + // cutlass::arch::fence_barrier_init(); + // We're counting on pipeline_k to call fence_barrier_init(); + MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); + MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read_k, smem_pipe_read_v; + PipelineState smem_pipe_release_k, smem_pipe_release_v; + PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); + + // Copy K tile and V tile from GMEM to SMEM. + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); + shared_storage.barrier_dO.arrive_and_expect_tx(TmaTransactionBytesdO); + copy(tma_load_dO.with(reinterpret_cast(shared_storage.barrier_dO), 0 /*mcast_mask*/), tdOgdO, tdOsdO); + } + // if (cute::thread0()) { print_tensor(sQ); printf("\n"); } __syncthreads(); + + int n_block = cute::ceil_div(params.seqlen_k, kBlockN) - 1; + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0 && lane_predicate) { + // Issue the prologue loads + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kStages && stage <= n_block; ++stage) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - stage), tKsK(_, _, _, stage)); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - stage), tVsV(_, _, _, stage)); + ++smem_pipe_write_v; + } + } + + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + // Initialize matmul objects. 
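+    // compute_dqkv_seqqpar parallelizes over blocks of queries instead: this CTA owns one m_block and
+    // iterates over all n_blocks of K/V. dQ is accumulated in registers and written out with TMA in the
+    // epilogue, while each tile's dK and dV contributions are atomicAdd-ed into the fp32 accumulators
+    // in global memory (dk_accum_ptr / dv_accum_ptr).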
+ typename Ktraits::TiledMmaSdP tiledMmaSdP; + auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x); + typename Ktraits::TiledMmadKV tiledMmadKV; + auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x); + typename Ktraits::TiledMmadQ tiledMmadQ; + auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x); + + // Allocate accumulator + Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + clear(tdQrdQ); + + auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x); + + if constexpr (!SdP_swapAB) { + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Allocate "fragments/descriptors" + Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ); + Tensor tSrK = threadMmaSdP.partition_fragment_B(sK); + Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO); + Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV); + + Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0, 0>(taccScS))::value == 2); + static_assert(decltype(size<0, 1>(taccScS))::value == 2); + // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{}); + Tensor lse = make_tensor(Shape>{}); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + } + // if (cute::thread0()) { print_tensor(lse); printf("\n"); } + // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. + + shared_storage.barrier_Q.wait(0); + shared_storage.barrier_dO.wait(0); + + // #pragma unroll 2 + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block >= 0; --n_block) { + // Otherwise we might have WG0 still wating on NamedBarrier but WG1 already + // started the next iteration and start flipping the same NamedBarrier. 
+ __syncthreads(); + Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); + pipeline_k.consumer_wait(smem_pipe_read_k); + flash::gemm(tiledMmaSdP, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); + pipeline_v.consumer_wait(smem_pipe_read_v); + flash::gemm(tiledMmaSdP, tdPrdO, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdP); + ++smem_pipe_read_v; + + warpgroup_wait<1>(); + // Reshape tSrS from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + // if (cute::thread0()) { print_tensor(scores); printf("\n"); } + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(tSrS); + Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + int const warp_group_idx = cutlass::canonical_warp_group_idx(); + cutlass::arch::NamedBarrier::arrive(kNThreads, warp_group_idx /*id*/); + + warpgroup_wait<0>(); + // Reshape tdPrdP from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + // if (cute::thread0()) { print_tensor(dS); printf("\n"); } + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } + } + Tensor rdS = flash::convert_type(tdPrdP); + + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + cutlass::arch::NamedBarrier::arrive(kNThreads, 2 + warp_group_idx /*id*/); + // if (cute::thread0()) { print_tensor(dS); printf("\n"); } + + if constexpr (Mma_dQ_is_RS) { + static_assert(!dQ_swapAB); + Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ); + // if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); } + } + + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + // if (cute::thread0()) { print_tensor(sK); printf("\n"); } + // if (cute::thread0()) { print_tensor(sKt); printf("\n"); } __syncthreads(); + + // if (cute::thread0()) { printf("before barrier sync 0\n"); } + // SMEM fence to make sure sP is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(kNThreads, 1 - warp_group_idx /*id*/); + // if (cute::thread0()) { printf("after barrier sync 0\n"); } + Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); + + if constexpr (!dKV_swapAB) { + Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt); + Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); + flash::gemm(tiledMmadKV, tdVrP, tdVrdO, tdVrdV); + } else { + Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt); + Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt); + flash::gemm(tiledMmadKV, tdVrdO, tdVrP, tdVrdV); + } + + // warpgroup_wait<0>(); + // Tensor dV_tmp = make_tensor(tdVrdV.data(), 
flash::convert_layout_acc_rowcol(tdVrdV.layout())); + // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); } + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::sync(kNThreads, 2 + 1 - warp_group_idx /*id*/); + if constexpr (!Mma_dQ_is_RS) { + if constexpr (!dQ_swapAB) { + Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ); + } else { + Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + flash::gemm(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ); + } + } + ++smem_pipe_read_k; + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } + + Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); + if constexpr (!dKV_swapAB) { + Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt); + Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); + flash::gemm(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK); + } else { + Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt); + Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt); + flash::gemm(tiledMmadKV, tdKrQ, tdKrdS, tdKrdK); + } + // warpgroup_wait<0>(); + // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); + // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); } + + warpgroup_wait(); + // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } + Tensor tdVrdV_atomic = recast(tdVrdV); + Tensor tdVgdVaccum_atomic = recast(tdVgdVaccum(_, _, _, n_block)); + #pragma unroll + for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); } + // for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); } + + warpgroup_wait<0>(); + Tensor tdKrdK_atomic = recast(tdKrdK); + Tensor tdKgdKaccum_atomic = recast(tdKgdKaccum(_, _, _, n_block)); + #pragma unroll + for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); } + + pipeline_v.consumer_release(smem_pipe_release_v); // release V + ++smem_pipe_release_v; + pipeline_k.consumer_release(smem_pipe_release_k); // release V + ++smem_pipe_release_k; + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0 && lane_predicate && n_block >= kStages) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + + } else { // SdP_swapAB + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Allocate "fragments/descriptors" + Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ); + Tensor tSrK = threadMmaSdP.partition_fragment_A(sK); + Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO); + 
Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV); + + Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0, 0>(taccScS))::value == 2); + static_assert(decltype(size<0, 1>(taccScS))::value == 2); + // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _); + Tensor lse = make_tensor(Shape>{}); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<1>(taccScS_row(mi)); + lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + } + // if (cute::thread0()) { print_tensor(taccScS_row); printf("\n"); } + // cute::fill(lse, 1); + // cute::fill(dP_sum, 1); + // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. + + clear(tdQrdQ); + + shared_storage.barrier_Q.wait(0); + shared_storage.barrier_dO.wait(0); + + // #pragma unroll 2 + CUTLASS_PRAGMA_NO_UNROLL + for (; n_block >= 0; --n_block) { + Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); + pipeline_k.consumer_wait(smem_pipe_read_k); + flash::gemm(tiledMmaSdP, tSrK(_, _, _, smem_pipe_read_k.index()), tSrQ, tSrS); + Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); + pipeline_v.consumer_wait(smem_pipe_read_v); + flash::gemm(tiledMmaSdP, tdPrV(_, _, _, smem_pipe_read_v.index()), tdPrdO, tdPrdP); + ++smem_pipe_read_v; + + warpgroup_wait<1>(); + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); + // if (cute::thread0()) { print_tensor(lse); printf("\n"); } + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + // if (cute::thread0()) { print_tensor(scores); printf("\n"); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(tSrS); + + static_assert(!dKV_swapAB); + Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{})); + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); + flash::gemm(tiledMmadKV, tdVrP, tdVrdO, tdVrdV); + // warpgroup_wait<0>(); + // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); + // if (cute::thread0()) { print_tensor(dV_tmp); printf("\n"); } + + warpgroup_wait<1>(); + // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } + } + // if (cute::thread0()) { print_tensor(dS); printf("\n"); } + Tensor rdS = flash::convert_type(tdPrdP); + + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); 
// ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + + Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select<1, 2>(TileShape_MNK{})); + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); + flash::gemm(tiledMmadKV, tdKrdS, tdKrQ, tdKrdK); + // warpgroup_wait<0>(); + // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); + // if (cute::thread0()) { print_tensor(dK_tmp); printf("\n"); } + + warpgroup_wait<1>(); + // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } + Tensor tdVrdV_atomic = recast(tdVrdV); + Tensor tdVgdVaccum_atomic = recast(tdVgdVaccum(_, _, _, n_block)); + #pragma unroll + for (int i = 0; i < size(tdVrdV_atomic); ++i) { atomicAdd(&tdVgdVaccum_atomic(i), tdVrdV_atomic(i)); } + // for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); } + + // SMEM fence to make sure sP is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + // cutlass::arch::NamedBarrier::sync(kNThreads, 0 /*id*/); + __syncthreads(); + static_assert(!Mma_dQ_is_RS); + if constexpr (!dQ_swapAB) { + Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdQ); + } else { + Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + flash::gemm(tiledMmadQ, tdQrK(_, _, _, smem_pipe_read_k.index()), tdQrdS, tdQrdQ); + } + ++smem_pipe_read_k; + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + + warpgroup_wait<1>(); + // if (cute::thread0()) { print(tdQrdQ); printf("\n"); print(tdQgdQaccum); printf("\n"); } + Tensor tdKrdK_atomic = recast(tdKrdK); + Tensor tdKgdKaccum_atomic = recast(tdKgdKaccum(_, _, _, n_block)); + #pragma unroll + for (int i = 0; i < size(tdKrdK_atomic); ++i) { atomicAdd(&tdKgdKaccum_atomic(i), tdKrdK_atomic(i)); } + // for (int i = 0; i < size(tdVrdV_atomic); ++i) { tdVgdVaccum_atomic(i) = tdVrdV_atomic(i); } + + warpgroup_wait<0>(); + + pipeline_v.consumer_release(smem_pipe_release_v); // release V + ++smem_pipe_release_v; + pipeline_k.consumer_release(smem_pipe_release_k); // release V + ++smem_pipe_release_k; + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0 && lane_predicate && n_block >= kStages) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), tKgK(_, _, _, n_block - kStages), tKsK(_, _, _, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), tVgV(_, _, _, n_block - kStages), tVsV(_, _, _, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + } + + // Epilogue + + #pragma unroll + for (int i = 0; i < size(tdQrdQ); ++i) { tdQrdQ(i) *= params.scale_softmax; } + // if (cute::thread0()) { print_tensor(tdQrdQ); } + + Tensor tdQrdQ_out = convert_type(tdQrdQ); + + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQ{}); + Tensor sdQt = 
make_tensor(make_smem_ptr(shared_storage.smem_dq.data()), typename Ktraits::SmemLayoutdQt{}); + + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ_out); // ((Atom,AtomNum), MMA_M, MMA_N) + + __syncthreads(); + if constexpr (!dQ_swapAB) { + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + } else { + Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt); + } + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + + Tensor mdQ = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor gdQ = local_tile(mdQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + auto block_tma_dQ = tma_store_dQ.get_slice(_0{}); + Tensor tdQgdQ = block_tma_dQ.partition_D(gdQ); // (TMA, TMA_M, TMA_K) + Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) + + __syncthreads(); // ensure all threads have issued their async fence + // if (cute::thread0()) { print_tensor(sdQ); } + + lane_predicate = cute::elect_one_sync(); + warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0 && lane_predicate) { + cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ); + tma_store_arrive(); + } + tma_store_wait<0>(); + + // To make sure remote SMEM doesn't get destroyed + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive(); + cute::cluster_wait(); + } + +} + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_dqkv_ws(CUTE_GRID_CONSTANT Flash_bwd_params const params, + CUTE_GRID_CONSTANT TiledCopyQ const tma_load_Q, + CUTE_GRID_CONSTANT TiledCopydO const tma_load_dO, + CUTE_GRID_CONSTANT TiledCopyK const tma_load_K, + CUTE_GRID_CONSTANT TiledCopyV const tma_load_V, + CUTE_GRID_CONSTANT TiledCopydK const tma_store_dK, + CUTE_GRID_CONSTANT TiledCopydV const tma_store_dV, + CUTE_GRID_CONSTANT TiledCopydQ const tma_store_dQ, + CUTE_GRID_CONSTANT TiledCopyAdddQ const tma_reduce_add_dQ) { + + using Element = typename Ktraits::Element; + using ElementAccum = typename Ktraits::ElementAccum; + using SoftType = ElementAccum; + using index_t = typename Ktraits::index_t; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static_assert(Ktraits::Is_WS); + + // static constexpr int kNThreads = Ktraits::kNThreads; + // static constexpr int NumMmaThreads = size(typename Ktraits::TiledMmaSdP{}); + static constexpr int NumMmaThreads = 256; + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = Ktraits::kBlockM; + static constexpr int kNThreadsdQ = Ktraits::kNThreadsdQ; + // static constexpr int kBlockN = Ktraits::kBlockN; + // constexpr int kHeadDim = Ktraits::kHeadDim; + // static constexpr int kStages = Ktraits::kStages; + + static constexpr bool SdP_swapAB = Ktraits::SdP_swapAB; + static constexpr bool dKV_swapAB = Ktraits::dKV_swapAB; + static constexpr bool dQ_swapAB = Ktraits::dQ_swapAB; + + if constexpr (SdP_swapAB) { static_assert(!dKV_swapAB); } + + static constexpr bool Mma_dQ_is_RS = Ktraits::Mma_dQ_is_RS; + if constexpr (dQ_swapAB) { static_assert(!Mma_dQ_is_RS); } + + using 
MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int lane_predicate = cute::elect_one_sync(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + cute::prefetch_tma_descriptor(tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_dO.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_load_V.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_store_dK.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_store_dV.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_store_dQ.get_tma_descriptor()); + cute::prefetch_tma_descriptor(tma_reduce_add_dQ.get_tma_descriptor()); + } + + // Construct SMEM tensors. + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQ{}); + Tensor sdO = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdO{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), typename Ktraits::SmemLayoutV{}); + Tensor sP = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutP{}); + Tensor sdS = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdS{}); + Tensor sQt = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), typename Ktraits::SmemLayoutQt{}); + Tensor sdOt = make_tensor(make_smem_ptr(shared_storage.smem_do.data()), typename Ktraits::SmemLayoutdOt{}); + Tensor sKt = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), typename Ktraits::SmemLayoutKt{}); + Tensor sPt = make_tensor(make_smem_ptr(shared_storage.smem_p.data()), typename Ktraits::SmemLayoutPt{}); + Tensor sdSt = make_tensor(make_smem_ptr(shared_storage.smem_ds.data()), typename Ktraits::SmemLayoutdSt{}); + Tensor sdQ = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc{}); + Tensor sdQ2 = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacc2{}); + Tensor sdQt = make_tensor(make_smem_ptr(shared_storage.smem_dqacc.data()), typename Ktraits::SmemLayoutdQacct{}); + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr uint32_t TmaTransactionBytesQ = static_cast(size<0>(sQ) * size<1>(sQ) * cutlass::sizeof_bits_v / 8); + constexpr uint32_t TmaTransactionBytesdO = static_cast(size<0>(sdO) * size<1>(sdO) * cutlass::sizeof_bits_v / 8); + static_assert(TmaTransactionBytesQ == TmaTransactionBytesdO); + constexpr uint32_t TmaTransactionBytesK = static_cast(size<0>(sK) * size<1>(sK) * cutlass::sizeof_bits_v / 8); + constexpr uint32_t TmaTransactionBytesV = static_cast(size<0>(sV) * size<1>(sV) * cutlass::sizeof_bits_v / 8); + static_assert(TmaTransactionBytesK == TmaTransactionBytesV); + + // Obtain warp index + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + // int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = 
TmaTransactionBytesQ; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + if (warp_group_idx == 0) { + pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } else { + pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_K.init(1 /*numThreads*/); + shared_storage.barrier_V.init(1 /*numThreads*/); + } + // cutlass::arch::fence_barrier_init(); + // We're counting on pipeline_q to call fence_barrier_init(); + MainloopPipeline pipeline_q(shared_storage.pipeline_q, pipeline_params, ClusterShape{}); + MainloopPipeline pipeline_do(shared_storage.pipeline_do, pipeline_params, ClusterShape{}); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + if (warp_group_idx == 0) { // Producer + // method in cutlass/arch/reg_reconfig.h + // calls setmaxnreg.dec.sync.aligned.u32 + cutlass::arch::warpgroup_reg_dealloc<24>(); + + int const n_block = blockIdx.x; + int const bidb = blockIdx.z; // The block index for the batch. + int const bidh = blockIdx.y; // The block index for the head. + + int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1; + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + int lane_predicate = cute::elect_one_sync(); + // if (warp_idx_in_warpgroup == 0 && lane_predicate) { + if (warp_idx_in_warpgroup == 0) { // Load K, and do TMA on Q and dO + Tensor mQ = tma_load_Q.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor mdO = tma_load_dO.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor mK = tma_load_K.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gdO = local_tile(mdO(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_Q = tma_load_Q.get_slice(cluster_local_block_id.y); + auto block_tma_dO = tma_load_dO.get_slice(cluster_local_block_id.y); + auto block_tma_K = tma_load_K.get_slice(_0{}); + Tensor tQgQ = block_tma_Q.partition_S(gQ); // (TMA, TMA_M, TMA_K, k) + Tensor tQsQ = block_tma_Q.partition_D(sQ); // (TMA, TMA_M, TMA_K, PIPE) + Tensor tdOgdO = block_tma_dO.partition_S(gdO); // (TMA, TMA_M, TMA_K, k) + Tensor tdOsdO = block_tma_dO.partition_D(sdO); // (TMA, TMA_M, TMA_K, PIPE) + Tensor tKgK = block_tma_K.partition_S(gK); // (TMA, TMA_N, TMA_K) + Tensor tKsK = block_tma_K.partition_D(sK); // (TMA, TMA_N, TMA_K) + + PipelineState smem_pipe_write_q = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_do = cutlass::make_producer_start_state(); + + uint16_t mcast_mask_qdo = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for 
(int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_qdo |= (uint16_t(1) << block_layout(n, cluster_local_block_id.x, _0{})); + } + } + + if (lane_predicate) { + // Copy K tile and V tile from GMEM to SMEM. + shared_storage.barrier_K.arrive_and_expect_tx(TmaTransactionBytesK); + copy(tma_load_K.with(reinterpret_cast(shared_storage.barrier_K), 0 /*mcast_mask*/), tKgK, tKsK); + + #pragma unroll 2 + for (; m_block >= 0; --m_block) { + pipeline_q.producer_acquire(smem_pipe_write_q); + copy(tma_load_Q.with(*pipeline_q.producer_get_barrier(smem_pipe_write_q), mcast_mask_qdo), tQgQ(_, _, _, m_block), tQsQ(_, _, _, smem_pipe_write_q.index())); + ++smem_pipe_write_q; + pipeline_do.producer_acquire(smem_pipe_write_do); + copy(tma_load_dO.with(*pipeline_do.producer_get_barrier(smem_pipe_write_do), mcast_mask_qdo), tdOgdO(_, _, _, m_block), tdOsdO(_, _, _, smem_pipe_write_do.index())); + ++smem_pipe_write_do; + } + + // Tail loop + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline_q.producer_tail(smem_pipe_write_q); + pipeline_do.producer_tail(smem_pipe_write_do); + } + } else if (warp_idx_in_warpgroup == 1) { // Load V, and do TMA_REDUCE_ADD on dQ + Tensor mV = tma_load_V.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (N, K) + auto block_tma_V = tma_load_V.get_slice(_0{}); + Tensor tVgV = block_tma_V.partition_S(gV); // (TMA, TMA_N, TMA_K) + Tensor tVsV = block_tma_V.partition_D(sV); // (TMA, TMA_N, TMA_K) + if (lane_predicate) { + shared_storage.barrier_V.arrive_and_expect_tx(TmaTransactionBytesV); + copy(tma_load_V.with(reinterpret_cast(shared_storage.barrier_V), 0 /*mcast_mask*/), tVgV, tVsV); + } + + Tensor mdQaccum = tma_store_dQ.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (M, K, _) + auto block_tma_dQ = tma_store_dQ.get_slice(_0{}); + Tensor tdQgdQ = block_tma_dQ.partition_D(gdQaccum); // (TMA, TMA_M, TMA_K) + Tensor tdQsdQ = block_tma_dQ.partition_S(sdQ); // (TMA, TMA_M, TMA_K) + Tensor tdQsdQ2 = block_tma_dQ.partition_S(sdQ2); // (TMA, TMA_M, TMA_K, 2) + int *lock_ptr = params.dq_semaphore + bidb * params.h + bidh; + using Barrier = cutlass::GenericBarrier; + // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 1 /*id*/); // sdQ empty, ready to be written to + cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to + // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + (m_block + 1) % 2 /*id*/); // sdQ empty, ready to be written to + // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to + // if (n_block == 0) { // Use TMA_STORE + if (false) { // Use TMA_STORE + #pragma unroll 2 + for (; m_block >= 0; --m_block) { + cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem + // cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem + if (lane_predicate) { + cute::copy(tma_store_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); + // cute::copy(tma_store_dQ, tdQsdQ2(_, _, _, m_block % 
2), tdQgdQ(_, _, _, m_block)); + tma_store_arrive(); + } + tma_store_wait<0>(); + Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h); + cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to + // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to + } + } else { // Use TMA_REDUCE_ADD + #pragma unroll 2 + for (; m_block >= 0; --m_block) { + // Barrier::wait_eq(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, n_block); + // Barrier::wait_lt(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h, 1); + cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 /*id*/); // sdQ full, to be written to gmem + // cutlass::arch::NamedBarrier::sync(kNThreadsdQ + 32, 2 + m_block % 2 /*id*/); // sdQ full, to be written to gmem + if (lane_predicate) { + cute::copy(tma_reduce_add_dQ, tdQsdQ, tdQgdQ(_, _, _, m_block)); + // cute::copy(tma_reduce_add_dQ, tdQsdQ2(_, _, _, m_block % 2), tdQgdQ(_, _, _, m_block)); + tma_store_arrive(); + } + tma_store_wait<0>(); + // Barrier::arrive_inc(lock_ptr, threadIdx.x % 32, m_block * params.b * params.h); + // cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 + m_block % 2 /*id*/); // sdQ empty, ready to be written to + cutlass::arch::NamedBarrier::arrive(kNThreadsdQ + 32, 0 /*id*/); // sdQ empty, ready to be written to + } + } + // } else if (warp_idx_in_warpgroup == 2) { // Load LSE and dPSum + // Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + // make_shape(params.b, params.h, params.seqlen_q), + // make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + // Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), + // make_shape(params.b, params.h, params.seqlen_q), + // make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + // Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(_)); // (M, _) + // Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(_)); // (M, _) + // Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape>{}); + // Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape>{}); + // #pragma unroll 2 + // for (; m_block >= 0; --m_block) { + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty + // #pragma unroll + // for (int i = 0; i < cute::ceil_div(kBlockM, 32); ++i) { + // int idx = threadIdx.x % 32 + i * 32; + // sLSE(idx) = idx < params.seqlen_q - m_block * kBlockM ? gLSE(idx, m_block) : INFINITY; + // sdPsum(idx) = idx < params.seqlen_q - m_block * kBlockM ? 
gdPsum(idx, m_block) : 0; + // } + // // sLSE and sdPsum are ready for WG 1 + // cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 1 /*id*/); + // // sLSE and sdPsum are ready for WG 2 + // cutlass::arch::NamedBarrier::arrive(128 + 32, 3 + 2 /*id*/); + // } + } + + + } else { // Consumers + // method in cutlass/arch/reg_reconfig.h + // calls setmaxnreg.inc.sync.aligned.u32 + cutlass::arch::warpgroup_reg_alloc<240>(); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read_q, smem_pipe_read_do; + PipelineState smem_pipe_release_q, smem_pipe_release_do; + + int m_block = cute::ceil_div(params.seqlen_q, kBlockM) - 1; + const int m_block_max = m_block; + + int bidb = blockIdx.z; // The block index for the batch. + int bidh = blockIdx.y; // The block index for the head. + + Tensor mLSE = make_tensor(make_gmem_ptr(reinterpret_cast(params.softmax_lse_ptr)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q, params.seqlen_q, _1{})); + Tensor mdPsum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum)), + make_shape(params.b, params.h, params.seqlen_q), + make_stride(params.h * params.seqlen_q_rounded, params.seqlen_q_rounded, _1{})); + Tensor gLSE = local_tile(mLSE(bidb, bidh, _), Shape>{}, make_coord(m_block)); + Tensor gdPsum = local_tile(mdPsum(bidb, bidh, _), Shape>{}, make_coord(m_block)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_storage.smem_lse.data()), Shape>{}); + Tensor sdPsum = make_tensor(make_smem_ptr(shared_storage.smem_dpsum.data()), Shape>{}); + + typename Ktraits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum; + // auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ); + auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x - NumCopyThreads); + Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_D(sdQ); + Tensor tdQsdQaccum2 = rmem_thr_copy_dQaccum.partition_D(sdQ2); + + // Initialize matmul objects. 
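For orientation, the softmax statistics gathered by the consumer warpgroups have a simple strided layout: LSE is indexed as (batch, head, query row) with the row dimension sized `seqlen_q`, while `dsoftmax_sum` is padded to `seqlen_q_rounded`. The host-side sketch below (not part of the kernel; sizes are made up) shows that indexing and the out-of-bounds check used a few lines further down, where rows past `seqlen_q` get LSE = infinity so their probabilities come out as zero:

```cpp
// Standalone sketch of the (batch, head, row) -> offset mapping implied by the
// mLSE / mdPsum tensors above. Sizes are illustrative only.
#include <cstdio>

struct StatsLayout {
    int h, seqlen_q, seqlen_q_rounded;
    // LSE is laid out row-major over (b, h, seqlen_q)
    long long lse_offset(int bidb, int bidh, int row) const {
        return (long long)bidb * h * seqlen_q + (long long)bidh * seqlen_q + row;
    }
    // dP_sum uses the same pattern, but its row dimension is padded/rounded
    long long dpsum_offset(int bidb, int bidh, int row) const {
        return (long long)bidb * h * seqlen_q_rounded + (long long)bidh * seqlen_q_rounded + row;
    }
};

int main() {
    StatsLayout layout{/*h=*/16, /*seqlen_q=*/1000, /*seqlen_q_rounded=*/1024};
    int kBlockM = 64, m_block = 15, row_in_block = 40;
    int row = m_block * kBlockM + row_in_block;      // global query index = 1000
    bool oob = row >= layout.seqlen_q;               // true: this row gets LSE = INFINITY
    std::printf("row=%d oob=%d lse_off=%lld dpsum_off=%lld\n",
                row, (int)oob, layout.lse_offset(1, 3, row), layout.dpsum_offset(1, 3, row));
    return 0;
}
```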
+ typename Ktraits::TiledMmaSdP tiledMmaSdP; + auto threadMmaSdP = tiledMmaSdP.get_thread_slice(threadIdx.x - NumCopyThreads); + typename Ktraits::TiledMmadKV tiledMmadKV; + auto threadMmadKV = tiledMmadKV.get_thread_slice(threadIdx.x - NumCopyThreads); + typename Ktraits::TiledMmadQ tiledMmadQ; + // auto threadMmadQ = tiledMmadQ.get_thread_slice((threadIdx.x - NumCopyThreads) % kNThreadsdQ); + auto threadMmadQ = tiledMmadQ.get_thread_slice(threadIdx.x - NumCopyThreads); + + // Allocate accumulator + Tensor tdKrdK = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); + Tensor tdVrdV = partition_fragment_C(tiledMmadKV, select(TileShape_MNK{})); + + auto smem_tiled_copy_PdS = make_tiled_copy_C(typename Ktraits::SmemCopyAtomPdS{}, tiledMmaSdP); + auto smem_thr_copy_PdS = smem_tiled_copy_PdS.get_thread_slice(threadIdx.x - NumCopyThreads); + // auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdQ{}, tiledMmadQ); + // auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom{}, tiledMmadQ); + // auto smem_tiled_copy_dQ = make_tiled_copy_C(Copy_Atom{}, tiledMmadQ); + // auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(threadIdx.x - NumCopyThreads); + + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdSt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + if constexpr (SdP_swapAB) { + + // Allocate "fragments/descriptors" + Tensor tSrQ = threadMmaSdP.partition_fragment_B(sQ); + Tensor tSrK = threadMmaSdP.partition_fragment_A(sK); + Tensor tdPrdO = threadMmaSdP.partition_fragment_B(sdO); + Tensor tdPrV = threadMmaSdP.partition_fragment_A(sV); + + Tensor caccS = make_identity_tensor(select<1, 0>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0, 0>(taccScS))::value == 2); + static_assert(decltype(size<0, 1>(taccScS))::value == 2); + // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. + Tensor taccScS_row = taccScS(make_coord(_, _0{}, _), _0{}, _); + static constexpr int kStatsPerThread = cute::ceil_div(decltype(size(taccScS_row))::value, 8); + static constexpr bool kStatsDivisibleBy8 = decltype(size(taccScS_row))::value % 8 == 0; + Tensor lse = make_tensor(Shape>{}); + // Tensor lse = make_tensor(Shape>{}); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<1>(taccScS_row(mi)); + lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + } + // #pragma unroll + // for (int mi = 0; mi < size(lse); ++mi) { + // const int row_idx = mi * 8 + (threadIdx.x % 32) / 4; + // const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0; + // lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + // dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + // } + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); } + // Trying to spread LSE and dPSum across threads in a warp but it's slower + // const int row_idx = mi * 8 + (threadIdx.x % 32) / 4; + // const int row = get<1>(taccScS_row(row_idx)); // TODO: what if row_idx is outside the range? + // cute::fill(lse, 1); + // cute::fill(dP_sum, 1); + // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } + // We want LSE = inf if the row is OOB. 
In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. + + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/); // sLSE and sdPsum are empty + + clear(tdKrdK); + clear(tdVrdV); + + shared_storage.barrier_K.wait(0); + shared_storage.barrier_V.wait(0); + + // #pragma unroll 2 + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block >= 0; --m_block) { + + // Putting this dQ block at the beginning of the loop gives an extra 10 TFLOPs + // It does make the code uglier, idk if it's worth it. + if (m_block < m_block_max) { + // SMEM fence to make sure sP is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + // dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD) + // int warp_group_idx = cutlass::canonical_warp_group_idx(); + // if (warp_group_idx == 1 + (m_block + 1) % 2) { + // // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4); + // } else { + // // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4); + // static_assert(!Mma_dQ_is_RS); + // Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + // if constexpr (!dQ_swapAB) { + // Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + // Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + // flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + // } else { + // Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + // Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + // flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); + // } + // Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) + // cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/); + // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2)); + // cutlass::arch::fence_view_async_shared(); + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem + // } + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); + cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/); + static_assert(!Mma_dQ_is_RS); + Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + // Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) + if constexpr (!dQ_swapAB) { + Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + } else { + Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); + } + // Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt); + // Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + // cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), 
MMA_M, MMA_N) + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to + cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); + // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2)); + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem + } + + Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); + pipeline_q.consumer_wait(smem_pipe_read_q); + flash::gemm(tiledMmaSdP, tSrK, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrS); + Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<1, 0>(TileShape_MNK{})); + pipeline_do.consumer_wait(smem_pipe_read_do); + flash::gemm(tiledMmaSdP, tdPrV, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrdP); + + // sLSE and sdPsum are done loading for WG 1 or 2 + // cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/); + // Tensor lse = make_tensor(Shape>{}); + // #pragma unroll + // for (int mi = 0; mi < size(lse); ++mi) { lse(mi) = sLSE(get<1>(taccScS_row(mi))); } + warpgroup_wait<1>(); + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_transposed_rowcol(tSrS.layout())); + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + // #pragma unroll + // for (int mi = 0; mi < size<0>(lse); ++mi) { lse(mi) *= float(M_LOG2E); } + // #pragma unroll + // for (int mi = 0; mi < size<0>(scores); ++mi) { + // // const float lse_scaled = lse(mi) * float(M_LOG2E); + // const float lse_scaled = __shfl_sync(0xffffffff, lse(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4)); + // // const float lse_scaled = __shfl_xor_sync(0xffffffff, lse(mi / 8), 1 << (mi % 4)) * float(M_LOG2E); + // // const float lse_scaled = lse(mi); + // #pragma unroll + // for (int ni = 0; ni < size<1>(scores); ++ni) { + // scores(mi, ni) = exp2f(scores(mi, ni) * params.scale_softmax_log2 - lse_scaled); + // } + // } + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); } + // Tensor dP_sum = make_fragment_like(lse); + // sLSE and sdPsum are done loading for WG 1 or 2 + // cutlass::arch::NamedBarrier::sync(128 + 32, 3 + cutlass::canonical_warp_group_idx() /*id*/); + // #pragma unroll + // for (int mi = 0; mi < size(dP_sum); ++mi) { dP_sum(mi) = sdPsum(get<1>(taccScS_row(mi))); } + + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(tSrS); + + warpgroup_wait<0>(); + // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); } + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + // const float dP_sum_cur = __shfl_sync(0xffffffff, dP_sum(mi / 8), (mi % 8) * 4 + (threadIdx.x % 4)); + // const float dP_sum_cur = __shfl_xor_sync(0xffffffff, dP_sum(mi / 8), 1 << (mi % 4)); + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } + // for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, 
ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum_cur); } + } + // sLSE and sdPsum are done processing, can load for the next iteration + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 3 /*id*/); + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } + Tensor rdS = flash::convert_type(tdPrdP); + + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + + if (m_block > 0) { + gLSE.data() = gLSE.data() + (-int(kBlockM)); + gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + } + // #pragma unroll + // for (int mi = 0; mi < size(lse); ++mi) { + // // const int row = get<1>(taccScS_row(mi)); + // const int row_idx = mi * 8 + (threadIdx.x % 32) / 4; + // const int row = kStatsDivisibleBy8 || row_idx < size(taccScS_row) ? get<1>(taccScS_row(row_idx)) : 0; + // lse(mi) = gLSE(row); + // dP_sum(mi) = gdPsum(row); + // } + Tensor lse_float2 = recast(lse); + Tensor dP_sum_float2 = recast(dP_sum); + #pragma unroll + for (int mi = 0; mi < size(lse) / 2; ++mi) { + const int row = get<1>(taccScS_row(mi * 2)); + lse_float2(mi) = *reinterpret_cast(&(gLSE(row))); + dP_sum_float2(mi) = *reinterpret_cast(&(gdPsum(row))); + } + + static_assert(!dKV_swapAB); + Tensor tdVrP = make_tensor(rP.data(), convert_layout_acc_Aregs(tSrS.layout())); + Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); + flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV); + ++smem_pipe_read_do; + // warpgroup_wait<0>(); + // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); } + + Tensor tdKrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); + flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK); + ++smem_pipe_read_q; + // warpgroup_wait<0>(); + // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); } + + pipeline_do.consumer_release(smem_pipe_release_do); // release V + ++smem_pipe_release_do; + pipeline_q.consumer_release(smem_pipe_release_q); // release V + ++smem_pipe_release_q; + + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + + } + + { + // SMEM fence to make sure sP is written before it's read by WGMMA + cutlass::arch::fence_view_async_shared(); + // dS is already written to smem, and the smem for dQ is empty (from warp 1 doing TMA_REDUCE_ADD) + // int warp_group_idx = cutlass::canonical_warp_group_idx(); + // if (warp_group_idx == 1 + (m_block + 1) % 2) { + // // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4); + // } else { + // // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + (m_block + 1) % 2 /*id*/); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4); + // static_assert(!Mma_dQ_is_RS); + // Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + // if constexpr (!dQ_swapAB) { + // Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + // Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + // flash::gemm(tiledMmadQ, 
tdQrdS, tdQrK, tdQrdQ); + // } else { + // Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + // Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + // flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); + // } + // Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) + // cutlass::arch::NamedBarrier::sync(NumMmaThreads / 2 + 32, 0 + (m_block + 1) % 2 /*id*/); + // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, (m_block + 1) % 2)); + // cutlass::arch::fence_view_async_shared(); + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads / 2 + 32, 2 + (m_block + 1) % 2 /*id*/); // sdQ ready to be written to gmem + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } + // // if (blockIdx.x == 0 && threadIdx.x == 128) { print(taccdQrdQ); printf("\n"); print(tdQsdQaccum2); printf("\n"); } + // } + cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 /*id*/); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 0 + 0 % 2 /*id*/); + // cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/); + static_assert(!Mma_dQ_is_RS); + Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + if constexpr (!dQ_swapAB) { + Tensor tdQrdS = threadMmadQ.partition_fragment_A(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + } else { + Tensor tdQrdS = threadMmadQ.partition_fragment_B(sdS); + Tensor tdQrK = threadMmadQ.partition_fragment_A(sKt); + flash::gemm(tiledMmadQ, tdQrK, tdQrdS, tdQrdQ); + } + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dQ_tmp); printf("\n"); } + Tensor taccdQrdQ = rmem_thr_copy_dQaccum.retile_S(tdQrdQ); // ((Atom,AtomNum), MMA_M, MMA_N) + // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 32, 1 /*id*/); // sdQ empty, ready to be written to + cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum); + // cute::copy(rmem_tiled_copy_dQaccum, taccdQrdQ, tdQsdQaccum2(_, _, _, 0 % 2)); + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 /*id*/); // sdQ ready to be written to gmem + // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 32, 2 + 0 % 2 /*id*/); // sdQ ready to be written to gmem + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(sdQ); printf("\n"); } + } + + } else { // !SdP_swapAB + Tensor tPsP = smem_thr_copy_PdS.partition_D(sP); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor tdSsdS = smem_thr_copy_PdS.partition_D(sdS); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Allocate "fragments/descriptors" + Tensor tSrQ = threadMmaSdP.partition_fragment_A(sQ); + Tensor tSrK = threadMmaSdP.partition_fragment_B(sK); + Tensor tdPrdO = threadMmaSdP.partition_fragment_A(sdO); + Tensor tdPrV = threadMmaSdP.partition_fragment_B(sV); + + Tensor caccS = make_identity_tensor(select<0, 1>(TileShape_MNK{})); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor taccScS = threadMmaSdP.partition_C(caccS); // (MMA,MMA_N,MMA_N) + static_assert(decltype(size<0, 0>(taccScS))::value == 2); + static_assert(decltype(size<0, 1>(taccScS))::value == 2); + // taccScS has shape ((2, 2, V), MMA_M, MMA_N), we only take only the row indices. 
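The reason only the row coordinates of `taccScS` are needed when gathering LSE and dP_sum into registers is the fixed per-lane ownership of MMA accumulators: within each 16x8 accumulator tile, a lane holds two rows and two adjacent columns. The toy sketch below mirrors the layout documented for the underlying mma/wgmma instructions; the MMA_M/MMA_N tiling on top of it is what CuTe's `partition_C` encodes, so this is an illustration rather than the kernel's own code:

```cpp
// Per-lane ownership pattern in an m16xn8 fp32 accumulator tile, which is why
// the ((2, 2, ...), MMA_M, MMA_N) fragment above only needs row indices to
// fetch per-row statistics.
#include <cstdio>

int main() {
    int lane = 13;                                     // lane id within the warp
    int rows[2] = {lane / 4, lane / 4 + 8};            // the two rows this lane holds
    int cols[2] = {(lane % 4) * 2, (lane % 4) * 2 + 1}; // the two adjacent columns
    std::printf("lane %d owns rows {%d, %d} and columns {%d, %d} of each 16x8 tile\n",
                lane, rows[0], rows[1], cols[0], cols[1]);
    return 0;
}
```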
+ Tensor taccScS_row = taccScS(make_coord(_0{}, _, _0{}), _, _0{}); + Tensor lse = make_tensor(Shape>{}); + Tensor dP_sum = make_fragment_like(lse); + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<0>(taccScS_row(mi)); + lse(mi) = row < params.seqlen_q - m_block * kBlockM ? gLSE(row) : INFINITY; + dP_sum(mi) = row < params.seqlen_q - m_block * kBlockM ? gdPsum(row) : 0; + } + // if (cute::thread0()) { print_tensor(dP_sum); printf("\n"); } + // We want LSE = inf if the row is OOB. In that case Q would be zero, K would be zero, + // and scores would be zero. With LSE = 0, probs will be all 1's, and when we multiply + // with V (which would be zero), we're fine. However, with ALiBi, we might modify these + // scores, and probs can become NaN. Instead if we set LSE = inf for OOB rows, probs are always 0. + + clear(tdKrdK); + clear(tdVrdV); + + shared_storage.barrier_K.wait(0); + shared_storage.barrier_V.wait(0); + + CUTLASS_PRAGMA_NO_UNROLL + for (; m_block >= 0; --m_block) { + Tensor tSrS = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); + pipeline_q.consumer_wait(smem_pipe_read_q); + flash::gemm(tiledMmaSdP, tSrQ(_, _, _, smem_pipe_read_q.index()), tSrK, tSrS); + Tensor tdPrdP = partition_fragment_C(tiledMmaSdP, select<0, 1>(TileShape_MNK{})); + pipeline_do.consumer_wait(smem_pipe_read_do); + // if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { printf("After dO wait\n"); } + flash::gemm(tiledMmaSdP, tdPrdO(_, _, _, smem_pipe_read_do.index()), tdPrV, tdPrdP); + + warpgroup_wait<1>(); + // Reshape tSrS from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor scores = make_tensor(tSrS.data(), flash::convert_layout_acc_rowcol(tSrS.layout())); + flash::scale_apply_exp2(scores, lse, params.scale_softmax_log2); + // if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(scores); printf("\n"); } + + // Convert scores from fp32 to fp16/bf16 + Tensor rP = flash::convert_type(tSrS); + Tensor tPaP = smem_thr_copy_PdS.retile_S(rP); // ((Atom,AtomNum), MMA_N, MMA_N) + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/); + cute::copy(smem_tiled_copy_PdS, tPaP, tPsP); + int const warp_group_idx = cutlass::canonical_warp_group_idx() - 1; + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 4 + warp_group_idx /*id*/); + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier arrive 4, tidx = %d\n", threadIdx.x); } + + warpgroup_wait<0>(); + // Reshape tdPrdP from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) + Tensor dS = make_tensor(tdPrdP.data(), scores.layout()); + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } + // if (blockIdx.x == 0 && blockIdx.z == 1 && threadIdx.x == 128) { print_tensor(dP_sum); printf("\n"); } + #pragma unroll + for (int mi = 0; mi < size<0>(dS); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(dS); ++ni) { dS(mi, ni) = scores(mi, ni) * (dS(mi, ni) - dP_sum(mi)); } + } + // if (blockIdx.x == 0 && blockIdx.z == 0 && threadIdx.x == 128) { print_tensor(dS); printf("\n"); } + Tensor rdS = flash::convert_type(tdPrdP); + + Tensor tdSadS = smem_thr_copy_PdS.retile_S(rdS); // ((Atom,AtomNum), MMA_N, MMA_N) + cute::copy(smem_tiled_copy_PdS, tdSadS, tdSsdS); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 6 + warp_group_idx /*id*/); + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { 
printf("After barrier arrive 6, tidx = %d\n", threadIdx.x); } + + if (m_block > 0) { + gLSE.data() = gLSE.data() + (-int(kBlockM)); + gdPsum.data() = gdPsum.data() + (-int(kBlockM)); + } + #pragma unroll + for (int mi = 0; mi < size(lse); ++mi) { + const int row = get<1>(taccScS_row(mi)); + lse(mi) = gLSE(row); + dP_sum(mi) = gdPsum(row); + } + + Tensor tdQrdQ = partition_fragment_C(tiledMmadQ, select(TileShape_MNK{})); + if constexpr (Mma_dQ_is_RS) { + static_assert(!dQ_swapAB); + Tensor tdQrdS = make_tensor(rdS.data(), convert_layout_acc_Aregs(tdPrdP.layout())); + Tensor tdQrK = threadMmadQ.partition_fragment_B(sKt); + flash::gemm(tiledMmadQ, tdQrdS, tdQrK, tdQrdQ); + // if (cute::thread0()) { print(tdQrdS); printf("\n"); print(tdQrK); printf("\n"); print(tdQrdQ); printf("\n"); } + } + + cutlass::arch::fence_view_async_shared(); + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 4, tidx = %d\n", threadIdx.x); } + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 4 + 1 - warp_group_idx /*id*/); + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 4, tidx = %d\n", threadIdx.x); } + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sPt); printf("\n"); } + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128)) { print_tensor(sdOt); printf("\n"); } + if constexpr (!dKV_swapAB) { + Tensor tdVrP = threadMmadKV.partition_fragment_A(sPt); + Tensor tdVrdO = threadMmadKV.partition_fragment_B(sdOt); + flash::gemm(tiledMmadKV, tdVrP, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrdV); + } else { + Tensor tdVrP = threadMmadKV.partition_fragment_B(sPt); + Tensor tdVrdO = threadMmadKV.partition_fragment_A(sdOt); + flash::gemm(tiledMmadKV, tdVrdO(_, _, _, smem_pipe_read_do.index()), tdVrP, tdVrdV); + } + ++smem_pipe_read_do; + // warpgroup_wait<0>(); + // Tensor dV_tmp = make_tensor(tdVrdV.data(), flash::convert_layout_acc_rowcol(tdVrdV.layout())); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dV_tmp); printf("\n"); } + + cutlass::arch::fence_view_async_shared(); + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("Before barrier sync 6, tidx = %d\n", threadIdx.x); } + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 6 + 1 - warp_group_idx /*id*/); + // if (blockIdx.x == 0 && blockIdx.z == 0 && (threadIdx.x == 128 || threadIdx.x == 256)) { printf("After barrier sync 6, tidx = %d\n", threadIdx.x); } + if constexpr (!dKV_swapAB) { + Tensor tdKrdS = threadMmadKV.partition_fragment_A(sdSt); + Tensor tdKrQ = threadMmadKV.partition_fragment_B(sQt); + flash::gemm(tiledMmadKV, tdKrdS, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdK); + } else { + Tensor tdKrdS = threadMmadKV.partition_fragment_B(sdSt); + Tensor tdKrQ = threadMmadKV.partition_fragment_A(sQt); + flash::gemm(tiledMmadKV, tdKrQ(_, _, _, smem_pipe_read_q.index()), tdKrdS, tdKrdK); + } + ++smem_pipe_read_q; + warpgroup_wait<0>(); + // Tensor dK_tmp = make_tensor(tdKrdK.data(), flash::convert_layout_acc_rowcol(tdKrdK.layout())); + // if (blockIdx.x == 0 && threadIdx.x == 128) { print_tensor(dK_tmp); printf("\n"); } + + pipeline_do.consumer_release(smem_pipe_release_do); // release V + ++smem_pipe_release_do; + pipeline_q.consumer_release(smem_pipe_release_q); // release V + ++smem_pipe_release_q; + + // warpgroup_wait<0>(); + // Tensor dQ_tmp = make_tensor(tdQrdQ.data(), 
flash::convert_layout_acc_rowcol(tdQrdQ.layout())); + // if (cute::thread0()) { print_tensor(dQ_tmp); printf("\n"); } + + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 8 /*id*/); + } + + } + + // Epilogue + + Tensor sdK = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdK{}); + Tensor sdV = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdV{}); + Tensor sdKt = make_tensor(make_smem_ptr(shared_storage.smem_dk.data()), typename Ktraits::SmemLayoutdKt{}); + Tensor sdVt = make_tensor(make_smem_ptr(shared_storage.smem_dv.data()), typename Ktraits::SmemLayoutdVt{}); + + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Ktraits::SmemCopyAtomdKV{}, tiledMmadKV); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(threadIdx.x - NumCopyThreads); + + int n_block = blockIdx.x; + bidb = blockIdx.z; // The block index for the batch. + bidh = blockIdx.y; // The block index for the head. + Tensor mdK = tma_store_dK.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor mdV = tma_store_dV.get_tma_tensor(make_shape(params.seqlen_k, params.d, params.h, params.b)); + Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) + auto block_tma_dK = tma_store_dK.get_slice(_0{}); + auto block_tma_dV = tma_store_dV.get_slice(_0{}); + Tensor tdKgdK = block_tma_dK.partition_D(gdK); // (TMA, TMA_M, TMA_K) + Tensor tdKsdK = block_tma_dK.partition_S(sdK); // (TMA, TMA_M, TMA_K) + Tensor tdVgdV = block_tma_dV.partition_D(gdV); // (TMA, TMA_M, TMA_K) + Tensor tdVsdV = block_tma_dV.partition_S(sdV); // (TMA, TMA_M, TMA_K) + + + // Very slightly faster to do the smem write and TMA write for dV first, then do the same for dK, + // Instead of doing both at the same time. 
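Before the epilogue, it may help to restate what the two accumulators hold. Writing $\tau$ for `params.scale_softmax`, $S = QK^\top$ (unscaled), and $D$ for the precomputed row-wise sum of $dO \circ O$, the consumer warpgroups compute the standard attention backward (no dropout or masking in this kernel):

$$
\begin{aligned}
P &= \exp(\tau S - \mathrm{LSE}), \qquad D_i = \sum_d dO_{i,d}\, O_{i,d},\\
dV &= P^\top dO, \qquad dP = dO\, V^\top,\\
dS &= P \circ (dP - D), \qquad dQ = \tau\, dS\, K, \qquad dK = \tau\, dS^\top Q,
\end{aligned}
$$

where LSE is the row-wise log-sum-exp of $\tau S$ saved by the forward pass (broadcast over columns, as is $D$). Rows that fall outside `seqlen_q` use LSE = infinity, so their $P$ entries are exactly zero. This is why `tdKrdK` is multiplied by `params.scale_softmax` just below before being written out, while dV needs no scaling.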
+ Tensor tdVrdV_out = convert_type(tdVrdV); + #pragma unroll + for (int i = 0; i < size(tdKrdK); ++i) { tdKrdK(i) *= params.scale_softmax; } + Tensor tdKrdK_out = convert_type(tdKrdK); + + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(tdVrdV_out); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(tdKrdK_out); // ((Atom,AtomNum), MMA_M, MMA_N) + + // Can't use __syncthreads() in WS code + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(NumMmaThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + synchronize(); + if constexpr (!dKV_swapAB) { + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + } else { + Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt); + } + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); + + lane_predicate = cute::elect_one_sync(); + warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) { + cute::copy(tma_store_dV, tdVsdV, tdVgdV); + tma_store_arrive(); + } + + if constexpr (!dKV_swapAB) { + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + } else { + Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt); + } + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); + if (warp_idx == NumCopyThreads / cutlass::NumThreadsPerWarp && lane_predicate) { + cute::copy(tma_store_dK, tdKsdK, tdKgdK); + tma_store_arrive(); + } + tma_store_wait<0>(); + } + +} + +} // namespace flash diff --git a/hopper/flash_bwd_launch_template.h b/hopper/flash_bwd_launch_template.h new file mode 100644 index 000000000..d0e3f24a4 --- /dev/null +++ b/hopper/flash_bwd_launch_template.h @@ -0,0 +1,246 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cluster_launch.hpp" + +#include "static_switch.h" +#include "flash.h" +#include "flash_bwd_preprocess_kernel.h" +#include "flash_bwd_kernel.h" +#include "kernel_traits.h" + +template +__global__ void flash_bwd_dot_do_o_kernel(const Flash_bwd_params params) { + flash::compute_dot_do_o(params); +} + +// template +// __global__ void flash_bwd_convert_dq_kernel(const Flash_bwd_params params, const int nsplits) { +// flash::convert_dQ(params, nsplits); +// } + +template +__global__ void flash_bwd_convert_dkv_kernel(const Flash_bwd_params params) { + flash::convert_dKV(params); +} + +template +void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { + int num_m_block = cute::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + dim3 grid_m(num_m_block, params.b, params.h); + flash_bwd_dot_do_o_kernel<<>>(params); + // If we use both TMA_STORE (for n_block=0) and TMA_REDUCE_ADD (for n_block>0), we don't need to clear dQaccum + // flash_bwd_dot_do_o_kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using TileShape_MNK = typename Kernel_traits::TileShape_MNK; + using ClusterShape = typename Kernel_traits::ClusterShape_MNK; + + Tensor mQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.q_ptr)), + make_shape(params.seqlen_q, params.d, params.h, params.b), + make_stride(params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride)); + auto tma_load_Q = make_tma_copy( + typename Kernel_traits::GmemTiledCopyQdO{}, + mQ, + typename Kernel_traits::SmemLayoutQ{}(_, _, _0{}), + // typename Kernel_traits::SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + Tensor mdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr)), + make_shape(params.seqlen_q, params.d, params.h, params.b), + make_stride(params.do_row_stride, _1{}, params.do_head_stride, params.do_batch_stride)); + auto tma_load_dO = make_tma_copy( + typename Kernel_traits::GmemTiledCopyQdO{}, + mdO, + typename Kernel_traits::SmemLayoutdO{}(_, _, _0{}), + // typename Kernel_traits::SmemLayoutdO{}, + select<0, 2>(TileShape_MNK{}), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + Tensor mK = make_tensor(make_gmem_ptr(reinterpret_cast(params.k_ptr)), + make_shape(params.seqlen_k, params.d, params.h, params.b), + make_stride(params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride)); + auto tma_load_K = make_tma_copy( + typename Kernel_traits::GmemTiledCopyKV{}, + mK, + typename Kernel_traits::SmemLayoutK{}, + // typename Kernel_traits::SmemLayoutK{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + _1{}); // no mcast for K + Tensor mV = make_tensor(make_gmem_ptr(reinterpret_cast(params.v_ptr)), + make_shape(params.seqlen_k, params.d, params.h, params.b), + make_stride(params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride)); + auto tma_load_V = make_tma_copy( + typename Kernel_traits::GmemTiledCopyKV{}, + mV, + typename Kernel_traits::SmemLayoutV{}, + // typename Kernel_traits::SmemLayoutV{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + _1{}); // no mcast for V + Tensor mdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr)), + make_shape(params.seqlen_k, params.d, params.h, params.b), + 
make_stride(params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride)); + auto tma_store_dK = make_tma_copy( + typename Kernel_traits::GmemTiledCopydKV{}, + mdK, + typename Kernel_traits::SmemLayoutdK{}, + select<1, 2>(TileShape_MNK{}), + _1{}); // no mcast for output + Tensor mdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr)), + make_shape(params.seqlen_k, params.d, params.h, params.b), + make_stride(params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride)); + auto tma_store_dV = make_tma_copy( + typename Kernel_traits::GmemTiledCopydKV{}, + mdV, + typename Kernel_traits::SmemLayoutdV{}, + select<1, 2>(TileShape_MNK{}), + _1{}); // no mcast for output + Tensor mdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr)), + make_shape(params.seqlen_q, params.d, params.h, params.b), + make_stride(params.dq_row_stride, _1{}, params.dq_head_stride, params.dq_batch_stride)); + Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr)), + make_shape(params.seqlen_q, params.d, params.h, params.b), + make_stride(params.d * params.h, _1{}, params.d, params.d * params.h * params.seqlen_q_rounded)); + auto tma_store_dQaccum = make_tma_copy( + // typename Kernel_traits::GmemTiledCopydKV{}, + typename cute::SM90_TMA_STORE{}, + // mdQ, + mdQaccum, + // typename Kernel_traits::SmemLayoutdQTMA{}, + typename Kernel_traits::SmemLayoutdQaccTMA{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for output + auto tma_reduce_add_dQaccum = make_tma_copy( + // typename Kernel_traits::GmemTiledCopydKV{}, + typename cute::SM90_TMA_REDUCE_ADD{}, + // mdQ, + mdQaccum, + // typename Kernel_traits::SmemLayoutdQTMA{}, + typename Kernel_traits::SmemLayoutdQaccTMA{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for output + // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); + + // print(typename Kernel_traits::TiledMmaSdP{}); printf("\n"); + // print(typename Kernel_traits::TiledMmadKV{}); printf("\n"); + // print(typename Kernel_traits::TiledMmadQ{}); printf("\n"); + // print(typename Kernel_traits::SmemLayoutAtomK{}); printf("\n"); + // print(typename Kernel_traits::SmemLayoutK{}); printf("\n"); + // print(typename Kernel_traits::SmemLayoutKt{}); printf("\n"); + // Get the ptr to kernel function. 
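The launch below goes through `cutlass::ClusterLaunchParams`, but the two pieces of bookkeeping it relies on are plain CUDA: opting the kernel in to more than 48 KB of dynamic shared memory, and rounding the grid up to a multiple of the cluster size. A minimal sketch with made-up values follows (the dummy kernel, 384-thread block, and 160 KB figure are illustrative, not the quantities computed here, and error handling is elided):

```cuda
// Minimal sketch of the launch bookkeeping: raise the dynamic smem limit and
// round the n-block grid dimension up to a multiple of the cluster size.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dummy_bwd_kernel() { extern __shared__ char smem[]; (void)smem; }

int main() {
    int smem_size = 160 * 1024;  // e.g. Q/dO/K/V/dS tiles; above the 48 KB default limit
    cudaFuncSetAttribute(dummy_bwd_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

    int seqlen_k = 4096, kBlockN = 128, cluster_n = 1;
    int num_blocks_n = (seqlen_k + kBlockN - 1) / kBlockN;                      // ceil_div
    num_blocks_n = (num_blocks_n + cluster_n - 1) / cluster_n * cluster_n;      // round to cluster
    dim3 grid(num_blocks_n, /*heads=*/16, /*batch=*/8), block(384);

    dummy_bwd_kernel<<<grid, block, smem_size>>>();
    cudaDeviceSynchronize();
    std::printf("launched %d x %d x %d blocks with %d bytes of dynamic smem\n",
                grid.x, grid.y, grid.z, smem_size);
    return 0;
}
```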
+ void *kernel; + if constexpr (!Kernel_traits::Is_WS) { + kernel = (void *)flash::compute_dqkv; + } else { + kernel = (void *)flash::compute_dqkv_ws; + } + // void *kernel = (void *)flash::compute_dqkv_seqqpar; + auto shared_storage = typename Kernel_traits::SharedStorage{}; + int smem_size = sizeof(typename Kernel_traits::SharedStorage); + int smem_size_q = sizeof(decltype(shared_storage.smem_q)); + int smem_size_do = sizeof(decltype(shared_storage.smem_do)); + int smem_size_k = sizeof(decltype(shared_storage.smem_k)); + int smem_size_v = sizeof(decltype(shared_storage.smem_v)); + // int smem_size_p = sizeof(decltype(shared_storage.smem_p)); + int smem_size_ds = sizeof(decltype(shared_storage.smem_ds)); + // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, p = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_p, smem_size_ds); + // printf("smem_size = %d, q = %d, do = %d, k = %d, v = %d, ds = %d\n", smem_size, smem_size_q, smem_size_do, smem_size_k, smem_size_v, smem_size_ds); + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + static constexpr int ctaSize = Kernel_traits::kNWarps * 32; + int num_blocks_n = cutlass::ceil_div(params.seqlen_k, Kernel_traits::kBlockN); + num_blocks_n = cutlass::ceil_div(num_blocks_n, size<1>(ClusterShape{})) * size<1>(ClusterShape{}); + dim3 grid_dims(num_blocks_n, params.h, params.b); + // int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + // num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); + // dim3 grid_dims(num_blocks_m, params.h, params.b); + dim3 block_dims(ctaSize); + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + if constexpr (!Kernel_traits::Is_WS) { + cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO, + tma_load_K, tma_load_V, tma_store_dK, tma_store_dV); + } else { + cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO, + tma_load_K, tma_load_V, tma_store_dK, tma_store_dV, tma_store_dQaccum, tma_reduce_add_dQaccum); + } + // cutlass::launch_kernel_on_cluster(launch_params, kernel, params, tma_load_Q, tma_load_dO, + // tma_load_K, tma_load_V, tma_store_dQaccum, tma_store_dK, tma_store_dV); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + + auto tma_load_dQaccum = make_tma_copy( + typename cute::SM90_TMA_LOAD{}, + mdQaccum, + typename Kernel_traits::SmemLayoutdQaccTMA{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for output + // auto kernel_dq = &flash_bwd_convert_dq_kernel; + auto kernel_dq = &flash::convert_dQ; + if (Kernel_traits::kSmemdQSize * 2 + 8 >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel_dq, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdQSize * 2 + 8)); + } + kernel_dq<<>>(params, tma_load_dQaccum); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + // auto kernel_dkv = &flash_bwd_convert_dkv_kernel; + // if (Kernel_traits::kSmemdKVSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel_dkv, cudaFuncAttributeMaxDynamicSharedMemorySize, Kernel_traits::kSmemdKVSize)); + // } + // int num_n_block = cute::ceil_div(params.seqlen_k, Kernel_traits::kBlockN); + // dim3 grid_n(num_n_block, params.b, params.h); + // kernel_dkv<<>>(params); + // 
C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + + +template +void run_mha_bwd_hdim64(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + // BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // run_flash_bwd(params, stream); + // }); + // run_flash_bwd, false>(params, stream); + run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); +} + +template +void run_mha_bwd_hdim128(Flash_bwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + // BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // run_flash_bwd(params, stream); + // }); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); + // run_flash_bwd, false>(params, stream); +} + +template +void run_mha_bwd_hdim256(Flash_bwd_params ¶ms, cudaStream_t stream) { + // constexpr static int Headdim = 256; + // BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // run_flash_bwd(params, stream); + // }); +} diff --git a/hopper/flash_bwd_preprocess_kernel.h b/hopper/flash_bwd_preprocess_kernel.h new file mode 100644 index 000000000..ecb51298f --- /dev/null +++ b/hopper/flash_bwd_preprocess_kernel.h @@ -0,0 +1,444 @@ +/*************************************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "flash.h" +#include "block_info.h" +#include "kernel_traits.h" +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void dot_do_o(Tensor const &do_, Tensor const &o, + Tensor &dP_sum, const int gdP_col_stride, const float scale) { + static_assert(Layout0::rank == 3, "Only support 3D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(do_.layout() == o.layout()); + // Reshape do_ and o from (8, kBlockM / 32, kHeadDim / 64) to (kBlockM / 32, 8 * kHeadDim / 64) + // The last coordinate is the "page". 
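Before the CuTe implementation continues below, here is a standalone reference for the quantity this preprocessing step produces: `dP_sum[i] = sum_d dO[i][d] * O[i][d]`, scaled by a caller-supplied factor, with one warp per row and a butterfly shuffle playing the role of `flash::Allreduce`. The kernel, fp32 inputs, and sizes are illustrative only, not the production code path:

```cuda
// Reference for dot(dO, O): one warp per row, strided loads over headdim,
// warp all-reduce, lane 0 writes the scaled row sum.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void dot_do_o_ref(const float* dO, const float* O, float* dP_sum,
                             int rows, int headdim, float scale) {
    int row  = blockIdx.x * (blockDim.x / 32) + threadIdx.x / 32;
    int lane = threadIdx.x % 32;
    if (row >= rows) return;
    float acc = 0.f;
    for (int d = lane; d < headdim; d += 32) {
        acc += dO[row * headdim + d] * O[row * headdim + d];
    }
    for (int offset = 16; offset > 0; offset /= 2) {   // butterfly all-reduce
        acc += __shfl_xor_sync(0xffffffff, acc, offset);
    }
    if (lane == 0) dP_sum[row] = acc * scale;
}

int main() {
    const int rows = 4, headdim = 64;
    float *dO, *O, *dP_sum;
    cudaMallocManaged(&dO, rows * headdim * sizeof(float));
    cudaMallocManaged(&O,  rows * headdim * sizeof(float));
    cudaMallocManaged(&dP_sum, rows * sizeof(float));
    for (int i = 0; i < rows * headdim; ++i) { dO[i] = 1.f; O[i] = 2.f; }
    dot_do_o_ref<<<1, 128>>>(dO, O, dP_sum, rows, headdim, /*scale=*/1.f);  // 4 warps = 4 rows
    cudaDeviceSynchronize();
    std::printf("dP_sum[0] = %.1f (expected %.1f)\n", dP_sum[0], 2.f * headdim);
    return 0;
}
```

The real kernel passes `params.p_dropout` as the scale so that dP_sum stays on the same scale as the unscaled dP; see the comment further down.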
+ Tensor do_reshaped = make_tensor(do_.data(), make_layout(get<1>(do_.layout()), + make_layout(get<0>(do_.layout()), + get<2>(do_.layout())))); + Tensor o_reshaped = make_tensor(o.data(), do_reshaped.layout()); + Tensor do_fp32 = flash::convert_type(do_reshaped); + Tensor o_fp32 = flash::convert_type(o_reshaped); + #pragma unroll + for (int mi = 0; mi < size<0>(do_reshaped); ++mi) { + float dP_sum_cur = do_fp32(mi, 0) * o_fp32(mi, 0); + #pragma unroll + for (int ni = 1; ni < size<1>(do_reshaped); ni++) { + dP_sum_cur += do_fp32(mi, ni) * o_fp32(mi, ni); + } + flash::SumOp sum_op; + dP_sum_cur = flash::Allreduce::run(dP_sum_cur, sum_op) * scale; + if (threadIdx.x % THREADS_PER_ROW == 0) { + dP_sum(mi * gdP_col_stride + threadIdx.x / THREADS_PER_ROW) = dP_sum_cur; + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel. +// This is used in the case where we want to parallelize the backward across seqlen_k. +template +inline __device__ void compute_dot_do_o(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_do = binfo.q_offset(params.do_batch_stride, params.do_row_stride, bidb) + + m_block * kBlockM * params.do_row_stride + bidh * params.do_head_stride; + const index_t row_offset_o = binfo.q_offset(params.o_batch_stride, params.o_row_stride, bidb) + + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + const index_t row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM; + + Tensor gdO = make_tensor(make_gmem_ptr(reinterpret_cast(params.do_ptr) + row_offset_do), + Shape, Int>{}, + make_stride(params.do_row_stride, _1{})); + Tensor gO = make_tensor(make_gmem_ptr(reinterpret_cast(params.o_ptr) + row_offset_o), + Shape, Int>{}, + make_stride(params.o_row_stride, _1{})); + Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + Shape, Int>{}, + make_stride(params.h * params.d_rounded, _1{})); + Tensor dP_sum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dsoftmax_sum) + row_offset_dpsum), + Shape>{}, Stride<_1>{}); + + typename Kernel_traits::GmemTiledCopydO gmem_tiled_copy_dO; + auto gmem_thr_copy_dO = gmem_tiled_copy_dO.get_thread_slice(tidx); + // TODO: careful, we're zeroing out dQaccum with type float4, but when + // we do atomicAdds, we use type float. The layouts are different. Check this. 
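The TODO just above concerns aliasing the fp32 dQ accumulator with two different element widths: it is zeroed with vectorized stores but later updated with scalar float atomicAdds (or TMA reduce-add), and only converted to fp16/bf16 at the end. The sketch below, with made-up sizes and names, walks through that lifecycle; zeroing through `float4` is safe because all-zero bytes are all-zero floats regardless of the vector width:

```cuda
// dQaccum lifecycle sketch: vectorized zeroing, scalar fp32 accumulation from
// several "k-blocks", then a final fp32 -> fp16 conversion pass.
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

__global__ void zero_accum_f4(float4* accum4, int n4) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n4) accum4[i] = make_float4(0.f, 0.f, 0.f, 0.f);
}

__global__ void accumulate_partial(float* accum, const float* partial, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) atomicAdd(&accum[i], partial[i]);
}

__global__ void convert_to_half(const float* accum, __half* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = __float2half(accum[i]);
}

int main() {
    const int n = 1024;
    float *accum, *partial; __half *out;
    cudaMallocManaged(&accum, n * sizeof(float));
    cudaMallocManaged(&partial, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(__half));
    for (int i = 0; i < n; ++i) partial[i] = 0.5f;
    zero_accum_f4<<<1, 256>>>(reinterpret_cast<float4*>(accum), n / 4);
    for (int nb = 0; nb < 3; ++nb) accumulate_partial<<<4, 256>>>(accum, partial, n);  // 3 "k-blocks"
    convert_to_half<<<4, 256>>>(accum, out, n);
    cudaDeviceSynchronize();
    std::printf("accum[0] = %.2f -> half %.2f\n", accum[0], __half2float(out[0]));
    return 0;
}
```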
+ typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; + auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + Tensor tdOgdO = gmem_thr_copy_dO.partition_S(gdO); + Tensor tdOgO = gmem_thr_copy_dO.partition_S(gO); + Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); + + Tensor cdO = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdOcdO = gmem_thr_copy_dO.partition_S(cdO); + + // Allocate predicate tensors for k + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOgdO))); + // Set predicates for k bounds + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) {tdOpdO(k) = get<1>(tdOcdO(0, 0, k)) < params.d;} + + Tensor tdOrdO = make_fragment_like(tdOgdO); + Tensor tdOrO = make_fragment_like(tdOgO); + flash::copy( + gmem_tiled_copy_dO, tdOgdO, tdOrdO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + flash::copy( + gmem_tiled_copy_dO, tdOgO, tdOrO, tdOcdO, tdOpdO, binfo.actual_seqlen_q - m_block * kBlockM + ); + // By right we need to scale dP up by 1/p_dropout, but instead we don't and only scale the final + // results (dQ and dK) by 1/p_dropout. So we need to keep dP_sum scaled down by p_dropout here, + // so that (dP - dP_sum) is on the same scale. + dot_do_o(tdOrdO, tdOrO, dP_sum, + // Kernel_traits::kNThreads / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + Kernel_traits::kNThreadsNonWS / (Kernel_traits::kGmemThreadsPerRow), params.p_dropout); + if (Clear_dQaccum) { + // We're actually not zero'ing out all of dQaccum, but only the part that we're going to + // do atomicAdds on. + Tensor zero = make_fragment_like(tdQgdQaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dQaccum, zero, tdQgdQaccum); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void clear_dKVaccum(const Params ¶ms) { + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + n_block * kBlockN) * params.d_rounded; + + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, Stride, _1>{}); + + typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_D(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_D(gdVaccum); + Tensor zero = make_fragment_like(tdKgdKaccum); + clear(zero); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdKgdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, zero, tdVgdVaccum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dQ from dQaccum (in float) to fp16/bf16. 
+// This is used in the case where we want to parallelize the backward across seqlen_k. +// template +template +// inline __device__ void convert_dQ(const Params ¶ms, +__global__ void convert_dQ(CUTE_GRID_CONSTANT Flash_bwd_params const params, + CUTE_GRID_CONSTANT TiledCopydQaccum const tma_load_dQaccum) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + int lane_predicate = cute::elect_one_sync(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + cute::prefetch_tma_descriptor(tma_load_dQaccum.get_tma_descriptor()); + } + + const int m_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. + const int tidx = threadIdx.x; + + constexpr int kBlockM = Kernel_traits::kBlockM; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + static constexpr bool dQ_swapAB = Kernel_traits::dQ_swapAB; + + Tensor mdQaccum = tma_load_dQaccum.get_tma_tensor(make_shape(params.seqlen_q, params.d, params.h, params.b)); + Tensor gdQaccum = local_tile(mdQaccum(_, _, bidh, bidb), Shape, Int>{}, make_coord(m_block, _0{})); // (M, K) + + const BlockInfo binfo(params, bidb); + if (m_block * kBlockM >= binfo.actual_seqlen_q) return; + + const index_t row_offset_dq = binfo.q_offset(params.dq_batch_stride, params.dq_row_stride, bidb) + + m_block * kBlockM * params.dq_row_stride + bidh * params.dq_head_stride; + const index_t row_offset_dq_accum = binfo.q_offset(params.seqlen_q_rounded * params.h * params.d_rounded, params.h * params.d_rounded, bidb) + + (m_block * kBlockM + (params.cu_seqlens_q == nullptr ? 
0 : 128 * bidb)) * params.h * params.d_rounded + bidh * params.d_rounded; + + Tensor gdQ = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_ptr) + row_offset_dq), + Shape, Int>{}, + make_stride(params.dq_row_stride, _1{})); + // Tensor gdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dq_accum_ptr) + row_offset_dq_accum), + // Shape, Int>{}, + // make_stride(params.h * params.d_rounded, _1{})); + + Tensor sdQTMA = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQaccTMA{}); + Tensor sdQaccum = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQacc{}); + Tensor sdQ = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQ{}); + Tensor sdQt = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdQt{}); + + auto &barrier_dQaccum = *reinterpret_cast(smem_ + sizeof(ElementAccum) * size(sdQTMA)); + + auto block_tma_dQ = tma_load_dQaccum.get_slice(_0{}); + Tensor tdQgdQaccumTMA = block_tma_dQ.partition_S(gdQaccum); // (TMA, TMA_M, TMA_K) + Tensor tdQsdQaccumTMA = block_tma_dQ.partition_D(sdQTMA); // (TMA, TMA_M, TMA_K) + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dQ; + auto gmem_thr_copy_dQ = gmem_tiled_copy_dQ.get_thread_slice(tidx); + // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dQaccum; + // typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dQaccum; + // auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadQ tiled_mma_dq; + auto smem_tiled_copy_dQ = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdQ{}, tiled_mma_dq); + auto smem_thr_copy_dQ = smem_tiled_copy_dQ.get_thread_slice(tidx); + + Tensor tdQsdQ = gmem_thr_copy_dQ.partition_S(sdQ); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdQgdQ = gmem_thr_copy_dQ.partition_D(gdQ); + // Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_S(gdQaccum); + + constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size<0>(sdQTMA) * size<1>(sdQTMA) * cutlass::sizeof_bits_v / 8); + if (warp_idx == 0 && lane_predicate) { + barrier_dQaccum.init(1 /*numThreads*/); + } + __syncthreads(); + if (warp_idx == 0 && lane_predicate) { + barrier_dQaccum.arrive_and_expect_tx(TmaTransactionBytesdQaccum); + copy(tma_load_dQaccum.with(reinterpret_cast(barrier_dQaccum), 0 /*mcast_mask*/), tdQgdQaccumTMA, tdQsdQaccumTMA); + } + barrier_dQaccum.wait(0); + // if (cute::thread0()) { print_tensor(sdQTMA); printf("\n"); } + + typename Kernel_traits::RmemTiledCopydQacc rmem_tiled_copy_dQaccum; + auto rmem_thr_copy_dQaccum = rmem_tiled_copy_dQaccum.get_thread_slice(threadIdx.x); + Tensor tdQsdQaccum = rmem_thr_copy_dQaccum.partition_S(sdQaccum); + + Tensor acc_dq = partition_fragment_C(tiled_mma_dq, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dq) == size(tdQsdQaccum)); + + Tensor tdQrdQaccum = rmem_thr_copy_dQaccum.retile_D(acc_dq); + cute::copy(rmem_tiled_copy_dQaccum, tdQsdQaccum, tdQrdQaccum); + // Tensor dQ_tmp = make_tensor(acc_dq.data(), flash::convert_layout_acc_rowcol(acc_dq.layout())); + // if (blockIdx.x == 0 && threadIdx.x == 0) { print_tensor(dQ_tmp); printf("\n"); } + #pragma unroll + for (int i = 0; i < size(acc_dq); ++i) { acc_dq(i) *= params.scale_softmax_rp_dropout; } + // Convert acc_dq from fp32 to fp16 + Tensor rdQ = flash::convert_type(acc_dq); + Tensor taccdQrdQ = smem_thr_copy_dQ.retile_S(rdQ); // ((Atom,AtomNum), MMA_N, MMA_N) + // dQacc and dQ uses 
the same shared memory, need to wait for all threads to finish reading smem first + __syncthreads(); + if constexpr (!dQ_swapAB) { + Tensor taccdQsdQ = smem_thr_copy_dQ.partition_D(sdQ); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQ); + } else { + Tensor taccdQsdQt = smem_thr_copy_dQ.partition_D(sdQt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dQ, taccdQrdQ, taccdQsdQt); + } + __syncthreads(); + Tensor tdQrdQ = make_tensor(shape(tdQgdQ)); + cute::copy(gmem_tiled_copy_dQ, tdQsdQ, tdQrdQ); + + Tensor cdQ = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ); + Tensor tdQpdQ = make_tensor(make_shape(size<2>(tdQgdQ))); + #pragma unroll + for (int k = 0; k < size(tdQpdQ); ++k) { tdQpdQ(k) = get<1>(tdQcdQ(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dQ, tdQrdQ, tdQgdQ, tdQcdQ, tdQpdQ, binfo.actual_seqlen_q - m_block * kBlockM + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Convert dK and dV from dKaccum and dVaccum (in float) to fp16/bf16. +// This is used in the case where we want to parallelize the backward across seqlen_q. +template +inline __device__ void convert_dKV(const Params ¶ms) { + using Element = typename Kernel_traits::Element; + using ElementAccum = typename Kernel_traits::ElementAccum; + using index_t = typename Kernel_traits::index_t; + + // Shared memory. + extern __shared__ char smem_[]; + + const int n_block = blockIdx.x; + // The block index for the batch. + const int bidb = blockIdx.y; + // The block index for the head. + const int bidh = blockIdx.z; + // The thread index. 
+ const int tidx = threadIdx.x; + + constexpr int kBlockN = Kernel_traits::kBlockN; + constexpr int kHeadDim = Kernel_traits::kHeadDim; + + static constexpr bool dKV_swapAB = Kernel_traits::dKV_swapAB; + + const BlockInfo binfo(params, bidb); + if (n_block * kBlockN >= binfo.actual_seqlen_k) return; + + const index_t row_offset_dk = binfo.k_offset(params.dk_batch_stride, params.dk_row_stride, bidb) + + n_block * kBlockN * params.dk_row_stride + bidh * params.dk_head_stride; + const index_t row_offset_dv = binfo.k_offset(params.dv_batch_stride, params.dv_row_stride, bidb) + + n_block * kBlockN * params.dv_row_stride + bidh * params.dv_head_stride; + const index_t row_offset_dkv_accum = ((bidb * params.h_k + bidh) * params.seqlen_k_rounded + + n_block * kBlockN) * params.d_rounded; + + Tensor gdK = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_ptr) + row_offset_dk), + Shape, Int>{}, + make_stride(params.dk_row_stride, _1{})); + Tensor gdV = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_ptr) + row_offset_dv), + Shape, Int>{}, + make_stride(params.dv_row_stride, _1{})); + Tensor gdKaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dk_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + Tensor gdVaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.dv_accum_ptr) + row_offset_dkv_accum), + Shape, Int>{}, + Stride, _1>{}); + + Tensor sdK = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdKV{}); + Tensor sdKt = make_tensor(make_smem_ptr(reinterpret_cast(smem_)), + typename Kernel_traits::SmemLayoutdKVt{}); + Tensor sdV = make_tensor(sdK.data() + size(sdK), + typename Kernel_traits::SmemLayoutdKV{}); // (SMEM_N, SMEM_K) + Tensor sdVt = make_tensor(make_smem_ptr(sdK.data() + size(sdK)), + typename Kernel_traits::SmemLayoutdKVt{}); + + typename Kernel_traits::GmemTiledCopydQ gmem_tiled_copy_dKV; + auto gmem_thr_copy_dKV = gmem_tiled_copy_dKV.get_thread_slice(tidx); + // typename Kernel_traits::GmemTiledCopydQaccumAtomicAdd gmem_tiled_copy_dKVaccum; + typename Kernel_traits::GmemTiledCopydQaccum gmem_tiled_copy_dKVaccum; + auto gmem_thr_copy_dKVaccum = gmem_tiled_copy_dKVaccum.get_thread_slice(tidx); + + typename Kernel_traits::TiledMmadKV tiled_mma_dkv; + auto smem_tiled_copy_dKV = make_tiled_copy_C(typename Kernel_traits::SmemCopyAtomdKV{}, tiled_mma_dkv); + auto smem_thr_copy_dKV = smem_tiled_copy_dKV.get_thread_slice(tidx); + + Tensor tdKsdK = gmem_thr_copy_dKV.partition_S(sdK); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdKgdK = gmem_thr_copy_dKV.partition_D(gdK); + Tensor tdVsdV = gmem_thr_copy_dKV.partition_S(sdV); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tdVgdV = gmem_thr_copy_dKV.partition_D(gdV); + Tensor tdKgdKaccum = gmem_thr_copy_dKVaccum.partition_S(gdKaccum); + Tensor tdVgdVaccum = gmem_thr_copy_dKVaccum.partition_S(gdVaccum); + + Tensor acc_dk = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + Tensor acc_dv = partition_fragment_C(tiled_mma_dkv, Shape, Int>{}); // MMA, MMA_N, MMA_K + CUTE_STATIC_ASSERT_V(size(acc_dk) == size(tdKgdKaccum)); + CUTE_STATIC_ASSERT_V(size(acc_dv) == size(tdVgdVaccum)); + + Tensor tdKrdKaccum = make_fragment_like(tdKgdKaccum); + Tensor tdVrdVaccum = make_fragment_like(tdVgdVaccum); + cute::copy(gmem_tiled_copy_dKVaccum, tdKgdKaccum, tdKrdKaccum); + cute::copy(gmem_tiled_copy_dKVaccum, tdVgdVaccum, tdVrdVaccum); + #pragma unroll + for (int i = 0; i < size(acc_dk); ++i) { + acc_dk(i) = tdKrdKaccum(i) * params.scale_softmax_rp_dropout; + } + #pragma 
unroll + for (int i = 0; i < size(acc_dv); ++i) { + acc_dv(i) = tdVrdVaccum(i) * params.rp_dropout; + } + // Convert acc_dk from fp32 to fp16 + Tensor rdK = flash::convert_type(acc_dk); + Tensor rdV = flash::convert_type(acc_dv); + Tensor taccdKrdK = smem_thr_copy_dKV.retile_S(rdK); // ((Atom,AtomNum), MMA_N, MMA_N) + Tensor taccdVrdV = smem_thr_copy_dKV.retile_S(rdV); // ((Atom,AtomNum), MMA_N, MMA_N) + if constexpr (!dKV_swapAB) { + Tensor taccdKsdK = smem_thr_copy_dKV.partition_D(sdK); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdV = smem_thr_copy_dKV.partition_D(sdV); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdK); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdV); + } else { + Tensor taccdKsdKt = smem_thr_copy_dKV.partition_D(sdKt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + Tensor taccdVsdVt = smem_thr_copy_dKV.partition_D(sdVt); // ((Atom,AtomNum),PIPE_M,PIPE_N) + cute::copy(smem_tiled_copy_dKV, taccdKrdK, taccdKsdKt); + cute::copy(smem_tiled_copy_dKV, taccdVrdV, taccdVsdVt); + } + __syncthreads(); + Tensor tdKrdK = make_tensor(shape(tdKgdK)); + Tensor tdVrdV = make_tensor(shape(tdVgdV)); + cute::copy(gmem_tiled_copy_dKV, tdKsdK, tdKrdK); + cute::copy(gmem_tiled_copy_dKV, tdVsdV, tdVrdV); + // if (cute::thread0()) { print_tensor(tdKrdK); printf("\n"); } + // if (cute::thread0()) { print_tensor(tdVrdV); printf("\n"); } + + Tensor cdKV = make_identity_tensor(Shape, Int>{}); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); + Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKgdK))); + #pragma unroll + for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; } + // Clear_OOB_K must be false since we don't want to write zeros to gmem + flash::copy( + gmem_tiled_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); + flash::copy( + gmem_tiled_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN + ); +} + +} // namespace flash diff --git a/hopper/flash_fwd_hdim128_bf16_sm90.cu b/hopper/flash_fwd_hdim128_bf16_sm90.cu new file mode 100644 index 000000000..11bb9ddec --- /dev/null +++ b/hopper/flash_fwd_hdim128_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/hopper/flash_fwd_hdim128_fp16_sm90.cu b/hopper/flash_fwd_hdim128_fp16_sm90.cu new file mode 100644 index 000000000..176c38edd --- /dev/null +++ b/hopper/flash_fwd_hdim128_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim128(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_bf16_sm90.cu b/hopper/flash_fwd_hdim256_bf16_sm90.cu new file mode 100644 index 000000000..06d0df617 --- /dev/null +++ b/hopper/flash_fwd_hdim256_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. 
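+// In generic form, the pattern in these per-head-dim files is one explicit template
+// specialization per translation unit, so the different (dtype, head dim) instantiations can be
+// compiled in parallel. Roughly, with illustrative names only:
+//     template<> void run_foo<cutlass::bfloat16_t, 256>(Params &params, cudaStream_t stream) {
+//         run_foo_hdim256<cutlass::bfloat16_t>(params, stream);
+//     }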
+ +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/hopper/flash_fwd_hdim256_fp16_sm90.cu b/hopper/flash_fwd_hdim256_fp16_sm90.cu new file mode 100644 index 000000000..0cc26c791 --- /dev/null +++ b/hopper/flash_fwd_hdim256_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim256(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_bf16_sm90.cu b/hopper/flash_fwd_hdim64_bf16_sm90.cu new file mode 100644 index 000000000..d3839898f --- /dev/null +++ b/hopper/flash_fwd_hdim64_bf16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/hopper/flash_fwd_hdim64_fp16_sm90.cu b/hopper/flash_fwd_hdim64_fp16_sm90.cu new file mode 100644 index 000000000..c6eac5352 --- /dev/null +++ b/hopper/flash_fwd_hdim64_fp16_sm90.cu @@ -0,0 +1,9 @@ +// Copyright (c) 2024, Tri Dao. +// Splitting the different head dimensions to different files to speed up compilation. + +#include "flash_fwd_launch_template.h" + +template<> +void run_mha_fwd_(Flash_fwd_params ¶ms, cudaStream_t stream) { + run_mha_fwd_hdim64(params, stream); +} diff --git a/hopper/flash_fwd_kernel.h b/hopper/flash_fwd_kernel.h new file mode 100644 index 000000000..b97250d65 --- /dev/null +++ b/hopper/flash_fwd_kernel.h @@ -0,0 +1,176 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cute/tensor.hpp" + +#include +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "flash.h" +#include "utils.h" +#include "softmax.h" +#include "tile_scheduler.hpp" +#include "mainloop_fwd_sm90_tma_gmma_ws.hpp" +#include "epilogue_fwd_sm90_tma.hpp" + +namespace flash { + +using namespace cute; + +template +__global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1) + compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd::Params const mainloop_params, + CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd::Params const epilogue_params, + CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params + ) { + + using Element = typename Ktraits::Element; + using ElementAccum = typename Ktraits::ElementAccum; + using SoftType = ElementAccum; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static_assert(Ktraits::Is_WS); + static constexpr bool Is_WS = Ktraits::Is_WS; + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + static constexpr int NumCopyThreads = !Is_WS ? 
0 : cutlass::NumThreadsPerWarpGroup; + static constexpr int kBlockM = Ktraits::kBlockM; + // static constexpr int kBlockN = Ktraits::kBlockN; + // constexpr int kHeadDim = Ktraits::kHeadDim; + + using CollectiveMainloop = CollectiveMainloopFwd; + using CollectiveEpilogue = CollectiveEpilogueFwd; + + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + extern __shared__ char shared_memory[]; + auto &shared_storage = *reinterpret_cast(shared_memory); + + int const lane_predicate = cute::elect_one_sync(); + int const warp_idx = cutlass::canonical_warp_idx_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if (warp_idx == 0 && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(mainloop_params); + CollectiveEpilogue::prefetch_tma_descriptors(epilogue_params); + } + + // Obtain warp index + int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesK; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + pipeline_params.role = warp_group_idx == 0 + ? MainloopPipeline::ThreadCategory::Producer + : MainloopPipeline::ThreadCategory::Consumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = NumMmaThreads; + + if (warp_idx == 0 && lane_predicate) { + shared_storage.barrier_Q.init(1 /*numThreads*/); + shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/); + } + // We're counting on pipeline_k to call cutlass::arch::fence_barrier_init(); + MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{}); + MainloopPipeline pipeline_v(shared_storage.pipeline_v, pipeline_params, ClusterShape{}); + + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue; + + // We need this to guarantee that the Pipeline init is visible to all producers and consumer blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } else { + __syncthreads(); + } + + static_assert(Ktraits::kNWarps == 12 || Ktraits::kNWarps == 16); + if (warp_group_idx == 0) { // Producer + cutlass::arch::warpgroup_reg_dealloc(); + // cutlass::arch::warpgroup_reg_dealloc<56>(); + + int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0); + if (warp_idx_in_warpgroup == 0) { // Load Q, K, V + PipelineState smem_pipe_write_k = cutlass::make_producer_start_state(); + PipelineState smem_pipe_write_v = cutlass::make_producer_start_state(); + + int work_idx = 0; + + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, bidh, bidb] = block_coord; + + int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); + if (Is_causal && n_block_max <= 0) { + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + scheduler.broadcast_next_work(work_tile_info); + continue; + } + collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v, + shared_storage, scheduler, scheduler_params, work_tile_info, 
block_coord, work_idx); + ++work_idx; + } + collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v); + } + } else { // Consumer + cutlass::arch::warpgroup_reg_alloc(); + // cutlass::arch::warpgroup_reg_alloc(); + + TileScheduler scheduler(&shared_storage.tile_count_semaphore); + // Initialize matmul objects. + typename Ktraits::TiledMma1 tiled_mma1; + + PipelineState smem_pipe_read_k, smem_pipe_read_v; + // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v + // (like in Cutlass's gemm) because the read and release pipeline states are always the same. + + collective_mainloop.mma_init(); + scheduler.init_consumer(); + + int work_idx = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (auto work_tile_info = scheduler.get_initial_work(); + work_tile_info.is_valid(scheduler_params); + work_tile_info = scheduler.template get_next_work(scheduler_params, work_tile_info)) { + // Attention output (GEMM-II) accumulator. + Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{})); + flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax; + + auto block_coord = work_tile_info.get_block_coord(scheduler_params); + auto [m_block, bidh, bidb] = block_coord; + + int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block); + if (Is_causal && n_block_max <= 0) { // We exit early and write 0 to gO and -inf to gLSE. + collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord); + continue; + } + + collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v, + tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage); + // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage); + collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1, + threadIdx.x - NumCopyThreads, block_coord); + + ++work_idx; + } + collective_epilogue.store_tail(); + } + +} + +} // namespace flash diff --git a/hopper/flash_fwd_launch_template.h b/hopper/flash_fwd_launch_template.h new file mode 100644 index 000000000..e0b40ceeb --- /dev/null +++ b/hopper/flash_fwd_launch_template.h @@ -0,0 +1,117 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/cluster_launch.hpp" + +#include "static_switch.h" +#include "flash.h" +#include "tile_scheduler.hpp" +#include "flash_fwd_kernel.h" +#include "kernel_traits.h" + + +template +void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { + using Element = typename Kernel_traits::Element; + using TileShape_MNK = typename Kernel_traits::TileShape_MNK; + using ClusterShape = typename Kernel_traits::ClusterShape_MNK; + + // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{}); + using CollectiveMainloop = flash::CollectiveMainloopFwd; + using CollectiveEpilogue = flash::CollectiveEpilogueFwd; + using Scheduler = std::conditional_t>; + // flash::SingleTileScheduler>; + typename CollectiveMainloop::Params mainloop_params = + CollectiveMainloop::to_underlying_arguments({ + static_cast(params.q_ptr), + {params.seqlen_q, params.d, params.h, params.b}, // shape_Q + {params.q_row_stride, _1{}, params.q_head_stride, params.q_batch_stride}, // stride_Q + static_cast(params.k_ptr), + {params.seqlen_k, params.d, params.h_k, params.b}, // shape_K + {params.k_row_stride, _1{}, params.k_head_stride, params.k_batch_stride}, // stride_K + static_cast(params.v_ptr), + {params.v_row_stride, _1{}, params.v_head_stride, params.v_batch_stride}, // stride_V + params.scale_softmax_log2 + }); + typename CollectiveEpilogue::Params epilogue_params = + CollectiveEpilogue::to_underlying_arguments({ + static_cast(params.o_ptr), + {params.seqlen_q, params.d, params.h, params.b}, // shape_O + {params.o_row_stride, _1{}, params.o_head_stride, params.o_batch_stride}, // stride_O + static_cast(params.softmax_lse_ptr), + {_1{}, params.seqlen_q, params.h * params.seqlen_q}, // stride_LSE + }); + + int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM); + num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{}); + typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore}; + typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args); + + // Get the ptr to kernel function. 
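+    // The launch below follows the usual CUDA pattern for kernels that need more than 48 KB of
+    // dynamic shared memory: take the kernel pointer, opt in explicitly, then launch. In plain
+    // CUDA (illustrative sizes, not the exact values used here):
+    //     int smem_bytes = 96 * 1024;
+    //     cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
+    //     kernel<<<grid, block, smem_bytes, stream>>>(...);
+    // Here the same opt-in happens through C10_CUDA_CHECK(cudaFuncSetAttribute(...)) and the
+    // cluster launch helper instead of a raw <<<...>>> launch.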
+ void *kernel; + kernel = (void *)flash::compute_attn_ws; + int smem_size = sizeof(typename Kernel_traits::SharedStorage); + // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q)); + // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k)); + // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v)); + // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v); + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + + int device; + cudaGetDevice(&device); + int multiprocessor_count; + cudaError status_ = cudaDeviceGetAttribute( + &multiprocessor_count, cudaDevAttrMultiProcessorCount, device); + if (status_ != cudaSuccess) { + C10_CUDA_CHECK(status_); + } + dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count); + static constexpr int ctaSize = Kernel_traits::kNWarps * 32; + dim3 block_dims(ctaSize); + dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); + cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream}; + cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void run_mha_fwd_hdim64(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 64; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); +} + +template +void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 128; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); + }); +} + +template +void run_mha_fwd_hdim256(Flash_fwd_params ¶ms, cudaStream_t stream) { + constexpr static int Headdim = 256; + BOOL_SWITCH(params.is_causal, Is_causal, [&] { + // Only use Cluster if number of tiles along seqlen_q is even + BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] { + run_flash_fwd, Is_causal>(params, stream); + }); + }); +} diff --git a/hopper/kernel_traits.h b/hopper/kernel_traits.h new file mode 100644 index 000000000..90ee3ccf9 --- /dev/null +++ b/hopper/kernel_traits.h @@ -0,0 +1,810 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cute/algorithm/copy.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" + +using namespace cute; + +template +struct SharedStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + union { + cute::array_aligned> smem_v; + cute::array_aligned> smem_o; + }; + struct { + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterBarrier barrier_O; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + int tile_count_semaphore; + }; +}; + +// If Share_Q_K_smem is true, that forces Is_Q_in_regs to be true +template +struct Flash_fwd_kernel_traits { + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + + static constexpr bool Is_Q_in_regs = Is_Q_in_regs_; + static_assert(kNWarps_ == 4 || kNWarps_ == 8 || kNWarps_ == 12 || kNWarps_ == 16); + static constexpr bool Is_WS = kNWarps_ >= 12; + static_assert(!(Is_WS && Is_Q_in_regs), "Warp-specialization does not support Q in registers"); + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + using TileShape_MNK = Shape, Int, Int>; + + static constexpr int kClusterM = kClusterM_; + using ClusterShape_MNK = Shape, _1, _1>; + + static constexpr int kStages = kStages_; + + using AtomLayoutMNK = Layout, _1, _1>>; + using TiledMma0 = decltype(cute::make_tiled_mma( + std::conditional_t< + Is_Q_in_regs, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutMNK{})); + using TiledMma1 = decltype(cute::make_tiled_mma( + cute::GMMA::rs_op_selector(TileShape_MNK{})), + GMMA::Major::K, GMMA::Major::MN>(), + AtomLayoutMNK{})); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = + decltype(tile_to_shape(SmemLayoutAtomV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{}))); + + using SmemCopyAtomQ = Copy_Atom; + + using SharedStorage = SharedStorageQKVO; + + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + using PipelineState = typename cutlass::PipelineState; + // using BarrierType 
= typename MainloopPipeline::ProducerBarrierType; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SharedStorageQKVdOdKV; + +template +struct SharedStorageQKVdOdKV { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKV { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used. + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKVWS; + +template +struct SharedStorageQKVdOdKVWS { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + cute::array_aligned> smem_dqacc; + cute::array_aligned smem_lse; + cute::array_aligned smem_dpsum; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKVWS { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + union { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + struct { + cute::array_aligned> smem_dk; + cute::array_aligned> smem_dv; + }; + }; + union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used. + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + cute::array_aligned> smem_dqacc; + cute::array_aligned smem_lse; + cute::array_aligned smem_dpsum; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. 
+ cutlass::arch::ClusterTransactionBarrier barrier_K; + cutlass::arch::ClusterTransactionBarrier barrier_V; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_q; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_do; + }; +}; + +template +struct SharedStorageQKVdOdKVSeqqPar; + +template +struct SharedStorageQKVdOdKVSeqqPar { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + }; + struct { + cute::array_aligned> smem_dq; + }; + }; + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterTransactionBarrier barrier_dO; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + }; +}; + +template +struct SharedStorageQKVdOdKVSeqqPar { + struct { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + union { + struct { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + }; + struct { + cute::array_aligned> smem_dq; + }; + }; + union { // Put smem_p in a union just so we can still refer to it in the struct, even if it's not used. + cute::array_aligned> smem_p; + cute::array_aligned> smem_ds; + }; + }; + struct { + cute::uint64_t tma_load_mbar[8]; // 8 TMA barrier pre-allcoated for usage. + cutlass::arch::ClusterTransactionBarrier barrier_Q; + cutlass::arch::ClusterTransactionBarrier barrier_dO; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_k; + typename cutlass::PipelineTmaAsync::SharedStorage pipeline_v; + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Flash_bwd_kernel_traits { + using Element = elem_type; + using ElementAccum = float; + using index_t = int64_t; + + // The number of threads. 
+ static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + static constexpr int kNThreadsNonWS = 8 * cutlass::NumThreadsPerWarp; + // static constexpr int kNThreadsdQ = cutlass::NumThreadsPerWarpGroup; + static constexpr int kNThreadsdQ = 2 * cutlass::NumThreadsPerWarpGroup; + + static_assert(kNWarps_ == 8 || kNWarps_ == 12); + + static constexpr bool Is_WS = kNWarps_ >= 12; + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + using TileShape_MNK = Shape, Int, Int>; + + static constexpr int kClusterN = kClusterN_; + using ClusterShape_MNK = Shape<_1, Int, _1>; + + static constexpr int kStages = 2; + + static constexpr bool SdP_swapAB = SdP_swapAB_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool dQ_swapAB = dQ_swapAB_; + static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV + + static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + + using TileShapeAtomSdP = std::conditional_t< + !SdP_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutSdP = std::conditional_t< + !SdP_swapAB, + Layout, Int<2 / AtomLayoutMSdP>, _1>>, + Layout, Int, _1>> + >; + using TiledMmaSdP = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutSdP{})); + + using TileShapeAtomdKV = std::conditional_t< + !dKV_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdKV = std::conditional_t< + !dKV_swapAB, + Layout, Int<2 / AtomLayoutNdKV>, _1>>, + Layout, Int, _1>> + >; + using TiledMmadKV = decltype(cute::make_tiled_mma( + std::conditional_t< + !SdP_swapAB, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutdKV{})); + + using TileShapeAtomdQ = std::conditional_t< + !dQ_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + // Shape, Int, Int>, + // Shape, Int, Int> + >; + using AtomLayoutdQ = std::conditional_t< + !dQ_swapAB, + Layout, Int<2 / AtomLayoutMdQ>, _1>>, + Layout, Int, _1>> + // Layout, Int<1>, _1>>, + // Layout, Int<1>, _1>> + >; + static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN; + static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K; + using TiledMmadQ = decltype(cute::make_tiled_mma( + std::conditional_t< + !dQ_swapAB, + std::conditional_t< + Mma_dQ_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >, + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdQ{})); + + using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + using GmemTiledCopydKV = cute::SM90_TMA_STORE; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static constexpr bool Has_cp_async = true; +#else + static constexpr bool Has_cp_async = false; +#endif + // For the dot_do_o preprocessing kernel + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 
64 : 32; + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. + static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreadsNonWS % kGmemThreadsPerRow == 0, "kNThreadsNonWS must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemLayoutAtomdQ = Layout, Int>, + Stride, _1>>; + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQ{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, _8>, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, _16>, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = + decltype(tile_to_shape(SmemLayoutAtomQ{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutdO = SmemLayoutQ; + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, select<1, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{}))); + + // using SmemLayoutAtomdQacc = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{}))); + + // Note this is the transpose in terms of the view, not in terms of memory. 
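+    // "Transpose in terms of the view" means only the (shape, stride) mapping changes, never the
+    // bytes in shared memory. With plain indices (illustrative, ignoring the swizzle):
+    //     row-major view : element (m, k) -> base[m * K + k]   // shape (M, K), strides (K, 1)
+    //     transposed view: element (k, m) -> base[m * K + k]   // shape (K, M), strides (1, K)
+    // The cute::composition calls below build this kind of re-indexing on top of the smem layouts.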
+ using SmemLayoutQt = + decltype(cute::composition(SmemLayoutQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutdOt = + decltype(cute::composition(SmemLayoutdO{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutKt = + decltype(cute::composition(SmemLayoutK{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutPt = + decltype(cute::composition(SmemLayoutP{}, + make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutdSt = + decltype(cute::composition(SmemLayoutdS{}, + make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + + // using SmemLayoutdQacct = + // decltype(cute::composition(SmemLayoutdQacc{}, + // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + // make_stride(Int{}, _1{})))); + + using SmemLayoutdK = SmemLayoutK; + using SmemLayoutdV = SmemLayoutV; + using SmemLayoutdKt = SmemLayoutKt; + using SmemLayoutdVt = SmemLayoutKt; + + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + using SmemLayoutAtomdQ = decltype( + // composition(Swizzle{}, + composition(Swizzle<3, 3, 3>{}, + Layout, Int<32>>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemLayoutdQt = + decltype(cute::composition(SmemLayoutdQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + + using SmemLayoutAtomdQaccTMA = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutdQaccTMA = decltype(tile_to_shape(SmemLayoutAtomdQaccTMA{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutdQacc = SmemLayoutdQ; + using SmemLayoutdQacct = SmemLayoutdQt; + using SmemLayoutdQacc2 = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}, _2{}))); + // using SmemLayoutdQacc = decltype(tile_to_shape(SmemLayoutAtomdQacc{}, select<0, 2>(TileShape_MNK{}))); + // using SmemLayoutdQacct = + // decltype(cute::composition(SmemLayoutdQacc{}, + // make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + // make_stride(Int{}, _1{})))); + using RmemTiledCopydQacc = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + // using SmemCopyAtomQ = Copy_Atom; + using SmemCopyAtomPdS = Copy_Atom< + std::conditional_t, + Element>; + using SmemCopyAtomdKV = Copy_Atom< + std::conditional_t, + Element>; + using SmemCopyAtomdQ = Copy_Atom< + std::conditional_t, + Element>; + + using SharedStorage = std::conditional_t< + !Is_WS, + SharedStorageQKVdOdKV, + SharedStorageQKVdOdKVWS + // SmemLayoutK, SmemLayoutV, SmemLayoutdS, SmemLayoutdQacc2, SmemLayoutdK, SmemLayoutdV> + >; + + // using MainloopPipeline = typename cutlass::PipelineTmaAsync; + // using PipelineState = typename cutlass::PipelineState; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Flash_bwd_seqqpar_kernel_traits { + using Element = elem_type; + 
using ElementAccum = float; + using index_t = int64_t; + + // The number of threads. + static constexpr int kNWarps = kNWarps_; + static constexpr int kNThreads = kNWarps * cutlass::NumThreadsPerWarp; + + static_assert(kNWarps_ == 8); + + static constexpr int kBlockM = kBlockM_; + static constexpr int kBlockN = kBlockN_; + static constexpr int kHeadDim = kHeadDim_; + static_assert(kHeadDim % 32 == 0); + using TileShape_MNK = Shape, Int, Int>; + + static constexpr int kClusterN = kClusterN_; + using ClusterShape_MNK = Shape<_1, Int, _1>; + + static constexpr int kStages = 2; + + static constexpr bool SdP_swapAB = SdP_swapAB_; + static constexpr bool dKV_swapAB = dKV_swapAB_; + static constexpr bool dQ_swapAB = dQ_swapAB_; + static_assert(!(SdP_swapAB && dKV_swapAB)); // If SdP_swapAB, then we don't swap for dKV + + static constexpr bool Mma_dQ_is_RS = AtomLayoutMSdP == 2 && AtomLayoutMdQ == 2 && !SdP_swapAB && !dQ_swapAB; // If dQ_swapAB we can't use RS + + using TileShapeAtomSdP = std::conditional_t< + !SdP_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutSdP = std::conditional_t< + !SdP_swapAB, + Layout, Int<2 / AtomLayoutMSdP>, _1>>, + Layout, Int, _1>> + >; + using TiledMmaSdP = decltype(cute::make_tiled_mma( + cute::GMMA::ss_op_selector(), + AtomLayoutSdP{})); + + using TileShapeAtomdKV = std::conditional_t< + !dKV_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdKV = std::conditional_t< + !dKV_swapAB, + Layout, Int<2 / AtomLayoutNdKV>, _1>>, + Layout, Int, _1>> + >; + using TiledMmadKV = decltype(cute::make_tiled_mma( + std::conditional_t< + !SdP_swapAB, + decltype(cute::GMMA::ss_op_selector()), + decltype(cute::GMMA::rs_op_selector()) + >{}, + AtomLayoutdKV{})); + + using TileShapeAtomdQ = std::conditional_t< + !dQ_swapAB, + Shape, Int, Int>, + Shape, Int, Int> + >; + using AtomLayoutdQ = std::conditional_t< + !dQ_swapAB, + Layout, Int<2 / AtomLayoutMdQ>, _1>>, + Layout, Int, _1>> + >; + static constexpr GMMA::Major MmadQMajorA = !dQ_swapAB ? GMMA::Major::K : GMMA::Major::MN; + static constexpr GMMA::Major MmadQMajorB = !dQ_swapAB ? GMMA::Major::MN : GMMA::Major::K; + using TiledMmadQ = decltype(cute::make_tiled_mma( + std::conditional_t< + !dQ_swapAB, + std::conditional_t< + Mma_dQ_is_RS, + decltype(cute::GMMA::rs_op_selector()), + decltype(cute::GMMA::ss_op_selector()) + >, + decltype(cute::GMMA::ss_op_selector()) + >{}, + AtomLayoutdQ{})); + + using GmemTiledCopyQdO = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyKV = cute::SM90_TMA_LOAD; + using GmemTiledCopydKV = cute::SM90_TMA_STORE; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + static constexpr bool Has_cp_async = true; +#else + static constexpr bool Has_cp_async = false; +#endif + // For the dot_do_o preprocessing kernel + using Gmem_copy_struct = std::conditional_t< + Has_cp_async, + SM80_CP_ASYNC_CACHEGLOBAL, + DefaultCopy + >; + static constexpr int kBlockKSmem = kHeadDim % 64 == 0 ? 64 : 32; + static constexpr int kGmemElemsPerLoad = sizeof(cute::uint128_t) / sizeof(Element); + static_assert(kHeadDim % kGmemElemsPerLoad == 0, "kHeadDim must be a multiple of kGmemElemsPerLoad"); + // Using kBlockKSmem instead of kHeadDim here to avoid bank conflicts, but doesn't seem + // to affect speed in practice. 
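+    // Worked example of the gmem-copy constants around here for one assumed configuration
+    // (fp16 with kHeadDim = 128; other dtypes and head dims change the numbers):
+    //     kGmemElemsPerLoad  = 16 / sizeof(half) = 8    // one 128-bit load per thread
+    //     kBlockKSmem        = 64                       // since 128 % 64 == 0
+    //     kGmemThreadsPerRow = 64 / 8 = 8               // threads cooperating on one row chunk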
+ static constexpr int kGmemThreadsPerRow = kBlockKSmem / kGmemElemsPerLoad; + static_assert(kNThreads % kGmemThreadsPerRow == 0, "kNThreads must be a multiple of kGmemThreadsPerRow"); + using GmemLayoutAtom = Layout, Int>, + Stride, _1>>; + using GmemTiledCopydO = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemTiledCopydQ = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtom{}, + Layout>{})); // Val layout, 8 vals per store + using GmemLayoutAtomdQaccum = std::conditional_t< + kBlockKSmem == 32, + Layout, // Thread layout, 8 threads per row + Stride< _8, _1>>, + Layout, // Thread layout, 16 threads per row + Stride< _16, _1>> + >; + using GmemTiledCopydQaccum = decltype( + make_tiled_copy(Copy_Atom{}, + GmemLayoutAtomdQaccum{}, + Layout>{})); // Val layout, 4 vals per store + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + using SmemLayoutdO = SmemLayoutQ; + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using SmemLayoutAtomP = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, select<0, 1>(TileShape_MNK{}))); + using SmemLayoutAtomdS = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + using SmemLayoutdS = decltype(tile_to_shape(SmemLayoutAtomdS{}, select<0, 1>(TileShape_MNK{}))); + + // Note this is the transpose in terms of the view, not in terms of memory. 
+ using SmemLayoutQt = + decltype(cute::composition(SmemLayoutQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutdOt = + decltype(cute::composition(SmemLayoutdO{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutKt = + decltype(cute::composition(SmemLayoutK{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), + make_stride(Int{}, _1{}, Int{})))); + using SmemLayoutPt = + decltype(cute::composition(SmemLayoutP{}, + make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + using SmemLayoutdSt = + decltype(cute::composition(SmemLayoutdS{}, + make_layout(make_shape(get<1>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + + using SmemLayoutdK = decltype(tile_to_shape(SmemLayoutAtomK{}, select<1, 2>(TileShape_MNK{}))); + using SmemLayoutdV = SmemLayoutdK; + using SmemLayoutdKt = SmemLayoutKt; + using SmemLayoutdVt = SmemLayoutKt; + using SmemLayoutdQTMA = decltype(tile_to_shape(SmemLayoutAtomK{}, select<0, 2>(TileShape_MNK{}))); + + static constexpr int kSwizzle = kBlockKSmem == 32 ? 2 : 3; + using SmemLayoutAtomdQ = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdQ = decltype(tile_to_shape( + SmemLayoutAtomdQ{}, + make_shape(Int{}, Int{}))); + using SmemLayoutdQt = + decltype(cute::composition(SmemLayoutdQ{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<0>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + static constexpr int kSmemdQSize = size(SmemLayoutdQ{}) * sizeof(Element); + + using SmemLayoutAtomdKV = decltype( + composition(Swizzle{}, + Layout>, + Stride, _1>>{})); + using SmemLayoutdKV = decltype(tile_to_shape( + SmemLayoutAtomdKV{}, + make_shape(Int{}, Int{}))); + using SmemLayoutdKVt = + decltype(cute::composition(SmemLayoutdKV{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{})), + make_stride(Int{}, _1{})))); + static constexpr int kSmemdKVSize = size(SmemLayoutdKV{}) * sizeof(Element) * 2; + + // using SmemCopyAtomQ = Copy_Atom; + using SmemCopyAtomPdS = Copy_Atom< + std::conditional_t, + Element>; + using SmemCopyAtomdKV = Copy_Atom< + std::conditional_t, + Element>; + using SmemCopyAtomdQ = Copy_Atom< + std::conditional_t, + Element>; + + using SharedStorage = SharedStorageQKVdOdKVSeqqPar; + + // using MainloopPipeline = typename cutlass::PipelineTmaAsync; + // using PipelineState = typename cutlass::PipelineState; + using MainloopPipeline = typename cutlass::PipelineTmaAsync; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp new file mode 100644 index 000000000..f9dc94a23 --- /dev/null +++ b/hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp @@ -0,0 +1,486 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/pipeline/pipeline.hpp" + +#include "cute/tensor.hpp" + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "named_barrier.hpp" +#include "utils.h" + +namespace flash { + +using namespace cute; + +template +struct CollectiveMainloopFwd { + + using Element = typename Ktraits::Element; + using TileShape_MNK = typename Ktraits::TileShape_MNK; + using ClusterShape = typename Ktraits::ClusterShape_MNK; + + static constexpr int kStages = Ktraits::kStages; + static constexpr int kHeadDim = Ktraits::kHeadDim; + + using GmemTiledCopyQ = cute::SM90_TMA_LOAD; + using GmemTiledCopyKV = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape{}))); + + using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutQ = decltype(tile_to_shape(SmemLayoutAtomQ{}, select<0, 2>(TileShape_MNK{}))); + + using SmemLayoutAtomK = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutK = + decltype(tile_to_shape(SmemLayoutAtomK{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + using SmemLayoutV = SmemLayoutK; + // Note this is the transpose in terms of the view, not in terms of memory. + using SmemLayoutVt = + decltype(cute::composition(SmemLayoutV{}, + make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int{}), + make_stride(get<1>(TileShape_MNK{}), _1{}, Int{})))); + // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom; + // using SmemLayoutVt = + // decltype(tile_to_shape(SmemLayoutAtomVt{}, + // make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}), + // Step<_2, _1, _3>{})); // This gives correct results, without Step it's wrong + // using SmemLayoutAtomVt = decltype(cutlass::gemm::collective::detail::ss_smem_selector(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); + // using SmemLayoutVt = + // decltype(tile_to_shape(SmemLayoutAtomVt{}, + // make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int{}))); + // using SmemLayoutAtomVTMA = cute::GMMA::Layout_K_SW128_Atom; + // using SmemLayoutVTMA = + // decltype(tile_to_shape(SmemLayoutAtomVTMA{}, + // make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); + + using ShapeQKV = cute::Shape; // (seqlen, d, head, batch) + using StrideQKV = cute::Stride; + + using TMA_Q = decltype(make_tma_copy( + GmemTiledCopyQ{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + _1{})); // no mcast for Q + + using TMA_KV = decltype(make_tma_copy( + GmemTiledCopyKV{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}), + take<0, 2>(SmemLayoutK{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{}); + using MainloopPipeline = typename Ktraits::MainloopPipeline; + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename MainloopPipeline::PipelineState; + + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + 
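The two constants defined next are just "elements in the tile times bytes per element"; the barrier's expected-transaction count has to match exactly what the TMA engine deposits. A rough back-of-the-envelope in Python (the 128 x 128 fp16 tile here is only an example; the real sizes come from `SmemLayoutQ` and one stage of `SmemLayoutK`):

```python
def tma_transaction_bytes(rows, cols, elem_bits=16):
    # Bytes the mbarrier must "expect" for one tile delivered by TMA.
    return rows * cols * elem_bits // 8

kBlockM, kBlockN, kHeadDim = 128, 128, 128       # illustrative tile shape
print(tma_transaction_bytes(kBlockM, kHeadDim))  # Q tile: 32768 bytes
print(tma_transaction_bytes(kBlockN, kHeadDim))  # one K/V pipeline stage: 32768 bytes
```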
static constexpr uint32_t TmaTransactionBytesQ = static_cast(size(SmemLayoutQ{}) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesK = static_cast(size(take<0, 2>(SmemLayoutK{})) * cutlass::sizeof_bits_v / 8); + + static constexpr bool UseSchedulerBarrier = kHeadDim <= 128; + + // Host side kernel arguments + struct Arguments { + Element const* ptr_Q; + ShapeQKV const shape_Q; + StrideQKV const stride_Q; + Element const* ptr_K; + ShapeQKV const shape_K; + StrideQKV const stride_K; + Element const* ptr_V; + StrideQKV const stride_V; + float const softmax_scale_log2; + }; + + // Device side kernel params + struct Params { + ShapeQKV const shape_Q; + ShapeQKV const shape_K; + cutlass::FastDivmod qhead_per_khead_divmod; + TMA_Q tma_load_Q; + TMA_KV tma_load_K, tma_load_V; + float const softmax_scale_log2; + }; + + + static Params + to_underlying_arguments(Arguments const& args) { + Tensor mQ = make_tensor(make_gmem_ptr(args.ptr_Q), args.shape_Q, args.stride_Q); + TMA_Q tma_load_Q = make_tma_copy( + GmemTiledCopyQ{}, + mQ, + SmemLayoutQ{}, + select<0, 2>(TileShape_MNK{}), + _1{}); // no mcast for Q + Tensor mK = make_tensor(make_gmem_ptr(args.ptr_K), args.shape_K, args.stride_K); + TMA_KV tma_load_K = make_tma_copy( + GmemTiledCopyKV{}, + mK, + SmemLayoutK{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + TMA_KV tma_load_V = make_tma_copy( + GmemTiledCopyKV{}, + mV, + SmemLayoutV{}(_, _, _0{}), + select<1, 2>(TileShape_MNK{}), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return {args.shape_Q, args.shape_K, + cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), + tma_load_Q, tma_load_K, tma_load_V, + args.softmax_scale_log2}; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_K.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_V.get_tma_descriptor()); + } + + CUTLASS_DEVICE + int get_n_block_max(Params const& mainloop_params, int m_block) { + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + int const seqlen_q = get<0>(mainloop_params.shape_Q); + int const seqlen_k = get<0>(mainloop_params.shape_K); + int n_block_max = cute::ceil_div(seqlen_k, kBlockN); + if constexpr (Is_causal) { + n_block_max = std::min(n_block_max, + cute::ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN)); + } + return n_block_max; + } + + template + CUTLASS_DEVICE void + load(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, + PipelineState& smem_pipe_write_v, + SharedStorage &shared_storage, + Scheduler& scheduler, + typename Scheduler::Params const& scheduler_params, + typename Scheduler::WorkTileInfo& work_tile_info, + cute::tuple block_coord, + int work_idx + ) { + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{}); + + 
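Stepping back from the load path for a moment: `get_n_block_max` above is the only place the causal structure enters the producer. The last K/V tile a query tile can touch is bounded by the position of its last row, shifted by `seqlen_k - seqlen_q` for the bottom-right-aligned causal mask. A small Python sketch of that bound (block sizes are illustrative):

```python
def ceil_div(a, b):
    return -(-a // b)

def n_block_max(seqlen_q, seqlen_k, m_block, kBlockM=128, kBlockN=128, causal=True):
    n_max = ceil_div(seqlen_k, kBlockN)            # all K/V tiles
    if causal:
        # Row i attends to keys <= i + seqlen_k - seqlen_q, so the last row of
        # this M tile determines how many K/V tiles the producer must load.
        n_max = min(n_max, ceil_div((m_block + 1) * kBlockM + seqlen_k - seqlen_q, kBlockN))
    return n_max

print(n_block_max(1024, 1024, m_block=0))   # 1 -> only the diagonal tile
print(n_block_max(1024, 1024, m_block=7))   # 8 -> the last M tile needs every K/V tile
```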
Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.shape_Q); + Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K); + Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K); + + auto [m_block, bidh, bidb] = block_coord; + int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh); + + // Prepare the TMA loads + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{})); // (M, K) + Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{})); // (N, K, _) + + Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{})); + Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{})); + auto [tQgQ, tQsQ] = tma_partition(mainloop_params.tma_load_Q, _0{}, Layout<_1>{}, + group_modes<0, 2>(sQ_x), group_modes<0, 2>(gQ_x)); // (TMA), (TMA) + auto [tKgK, tKsK] = tma_partition(mainloop_params.tma_load_K, block_rank_in_cluster, Layout{}, + group_modes<0, 2>(sK), group_modes<0, 2>(gK)); // (TMA, k), (TMA, PIPE) + auto [tVgV, tVsV] = tma_partition(mainloop_params.tma_load_V, block_rank_in_cluster, Layout{}, + group_modes<0, 2>(sV), group_modes<0, 2>(gV)); // (TMA, k), (TMA, PIPE) + + uint16_t mcast_mask_kv = 0; + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_kv |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, _0{})); + } + } + + int n_block_max = get_n_block_max(mainloop_params, m_block); + int n_block = n_block_max - 1; + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tKgK(_, n_block), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + } + + // Wait for the MMA warpgroups to say that smem_q is ready + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + + if (lane_predicate) { + shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ); + copy(mainloop_params.tma_load_Q.with(reinterpret_cast(shared_storage.barrier_Q), 0 /*mcast_mask*/), tQgQ, tQsQ); + } + + // Wait for warp 1 to signal that smem_v are ready and V can be copied from gmem + // Need ClusterBarrier, not just NamedBarrier. Otherwise we might have CTA 0 finishing the + // TMA store on O first, call TMA multicast load on V, before CTA 1 can finishing TMA store on O. 
+ shared_storage.barrier_O.wait((work_idx + 1) % 2); + + if (lane_predicate) { + // CUTLASS_PRAGMA_NO_UNROLL + #pragma unroll 2 + for (; n_block > 0; --n_block) { + pipeline_k.producer_acquire(smem_pipe_write_k); + copy(mainloop_params.tma_load_K.with(*pipeline_k.producer_get_barrier(smem_pipe_write_k), mcast_mask_kv), + tKgK(_, n_block - 1), tKsK(_, smem_pipe_write_k.index())); + ++smem_pipe_write_k; + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), + tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + } + scheduler.prefetch_next_work(scheduler_params, work_tile_info); + if (lane_predicate) { + pipeline_v.producer_acquire(smem_pipe_write_v); + copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv), + tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index())); + ++smem_pipe_write_v; + } + scheduler.broadcast_next_work(work_tile_info); + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline_k, MainloopPipeline pipeline_v, + PipelineState& smem_pipe_write_k, PipelineState& smem_pipe_write_v) { + int lane_predicate = cute::elect_one_sync(); + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was still inverted from make_producer_start_state + */ + pipeline_k.producer_tail(smem_pipe_write_k); + pipeline_v.producer_tail(smem_pipe_write_v); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_sync() { + if constexpr (UseSchedulerBarrier) { + cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/); + } + } + + CUTLASS_DEVICE void + warp_scheduler_barrier_arrive() { + if constexpr (!UseSchedulerBarrier) { return; } + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/); + } else { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/); + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? 
cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/); + } + } + + CUTLASS_DEVICE void + mma_init() { + // Tell producer (warp 0) that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + if constexpr (!UseSchedulerBarrier) { return; } + static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup); + if (cutlass::canonical_warp_group_idx() > 1) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/); + } + if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) { + if (cutlass::canonical_warp_group_idx() > 2) { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/); + } + } + + } + + template + CUTLASS_DEVICE void + mma(Params const& mainloop_params, + MainloopPipeline pipeline_k, + MainloopPipeline pipeline_v, + PipelineState& smem_pipe_read_k, + PipelineState& smem_pipe_read_v, + FrgTensorO& tOrO, + Softmax& softmax, + int n_block_count, + int thread_idx, + int work_idx, + int m_block, + SharedStorage& shared_storage + ) { + static_assert(is_rmem::value, "O tensor must be rmem resident."); + + static constexpr int kBlockM = get<0>(TileShape_MNK{}); + static constexpr int kBlockN = get<1>(TileShape_MNK{}); + + Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{}); + Tensor sVt = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutVt{}); + + typename Ktraits::TiledMma0 tiled_mma0; + typename Ktraits::TiledMma1 tiled_mma1; + auto threadMma0 = tiled_mma0.get_thread_slice(thread_idx); + auto threadMma1 = tiled_mma1.get_thread_slice(thread_idx); + + // Allocate "fragments/descriptors" for first matmul. + Tensor tSrQ = threadMma0.partition_fragment_A(sQ); + Tensor tSrK = threadMma0.partition_fragment_B(sK); + // Allocate "fragments/descriptors" for second matmul. + // Note: S becomes P. 
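The "S becomes P" note above is the core dataflow of the mainloop: GEMM-0 computes the score tile S = Q Kᵀ, the online softmax turns S into probabilities P in place (re-using the accumulator registers after down-conversion), and P immediately becomes the A operand of GEMM-1 against V. A tiny PyTorch reference of one tile's worth of that flow, ignoring tiling, masking and rescaling:

```python
import torch

d = 64
q = torch.randn(128, d)       # one kBlockM x d query tile
k = torch.randn(128, d)       # one kBlockN x d key tile
v = torch.randn(128, d)       # matching value tile

s = (q @ k.T) / d ** 0.5      # GEMM-0: scores S
p = torch.softmax(s, dim=-1)  # softmax: S -> P (done online, block by block, in the kernel)
o = p @ v                     # GEMM-1: output tile O
print(o.shape)                # torch.Size([128, 64])
```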
+ Tensor tOrV = threadMma1.partition_fragment_B(sVt); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + tiled_mma1.accumulate_ = GMMA::ScaleOut::Zero; + int const seqlen_q = get<0>(mainloop_params.shape_Q); + int const seqlen_k = get<0>(mainloop_params.shape_K); + int n_block = n_block_count - 1; + + cutlass::ConsumerToken barrier_token = static_cast(shared_storage.barrier_Q.try_wait(work_idx % 2)); + if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); } + + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + warp_scheduler_barrier_arrive(); + if (work_idx != 0) { + int lane_predicate = cute::elect_one_sync(); + if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) { + tma_store_wait<0>(); + #pragma unroll + for (uint32_t cta_id = 0; cta_id < size(ClusterShape{}); ++cta_id) { + shared_storage.barrier_O.arrive(cta_id, lane_predicate); + } + } + } + warpgroup_wait<0>(); + pipeline_k.consumer_release(smem_pipe_read_k); + ++smem_pipe_read_k; + + auto col_limit_causal = [&](int row, int n_block) { + return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM; + }; + { + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if constexpr (!Is_causal) { // Just masking based on col + if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; } + } else { // mask based on both row and col + // using std::min is faster than doing col >= limit0 or col >= limit1 + // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the + // right hand side can be negative and might be converted to a very large unsigned integer. + if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN, + col_limit_causal(int(get<0>(tScS(i))), n_block))) { + tSrS(i) = -INFINITY; + } + } + } + } + + softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); + Tensor tOrP = make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())); + Tensor scores_scale = make_fragment_like(softmax.row_max); + clear(scores_scale); + + constexpr int n_masking_steps = !Is_causal ? 
1 : cute::ceil_div(kBlockM, kBlockN) + 1; + // Only go through these if Is_causal, since n_masking_steps = 1 when !Is_causal + #pragma unroll + for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); } + consumer_wait(pipeline_v, smem_pipe_read_v); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{})); + Tensor tScS = threadMma0.partition_C(cS); + #pragma unroll + for (int i = 0; i < size(tSrS); ++i) { + if (int(get<1>(tScS(i))) >= col_limit_causal(int(get<0>(tScS(i))), n_block - 1)) { + tSrS(i) = -INFINITY; + } + } + cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); + softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); + } + + #pragma unroll 1 + for (; n_block > 0; --n_block) { + Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{})); + consumer_wait(pipeline_k, smem_pipe_read_k); + warp_scheduler_barrier_sync(); + flash::gemm(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS); + softmax.rescale_o(tOrO, scores_scale); + consumer_wait(pipeline_v, smem_pipe_read_v); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + warp_scheduler_barrier_arrive(); + warpgroup_wait<1>(); + pipeline_k.consumer_release(smem_pipe_read_k); // release K + // auto scores_scale = softmax.template max(tSrS); + cute::copy(softmax.template max(tSrS, mainloop_params.softmax_scale_log2), scores_scale); + softmax.template online_softmax(tSrS, mainloop_params.softmax_scale_log2); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V + ++smem_pipe_read_k; + ++smem_pipe_read_v; + // softmax.rescale_o(tOrO, scores_scale); + cute::copy(make_tensor(convert_type(tSrS).data(), convert_layout_acc_Aregs(tSrS.layout())), tOrP); + } + // Tell warp 0 that smem_q is ready + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::QueryEmpty) /*id*/); + softmax.rescale_o(tOrO, scores_scale); + consumer_wait(pipeline_v, smem_pipe_read_v); + flash::gemm(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO); + cute::copy(softmax.template finalize(tSrS, mainloop_params.softmax_scale_log2), scores_scale); + warpgroup_wait<0>(); + pipeline_v.consumer_release(smem_pipe_read_v); // release V, otherwise producers will hang + ++smem_pipe_read_v; + + softmax.rescale_o(tOrO, scores_scale); + return; + } + +}; + +} // namespace flash + diff --git a/hopper/named_barrier.hpp b/hopper/named_barrier.hpp new file mode 100644 index 000000000..ac5f616f2 --- /dev/null +++ b/hopper/named_barrier.hpp @@ -0,0 +1,23 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh 
Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include "cutlass/arch/barrier.h" + +namespace flash { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Enumerates the reserved named barriers to avoid potential conflicts +enum class FwdNamedBarriers { + QueryEmpty = 0, + ValueEmpty = 1, + TileCountSmemEmpty = 2, + TileCountSmemFull = 3, + WarpSchedulerWG1 = 4, + WarpSchedulerWG2 = 5, + WarpSchedulerWG3 = 6, +}; + +} // flash \ No newline at end of file diff --git a/hopper/setup.py b/hopper/setup.py new file mode 100644 index 000000000..35a074af0 --- /dev/null +++ b/hopper/setup.py @@ -0,0 +1,292 @@ +# Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + +import sys +import warnings +import os +import re +import shutil +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME + + +# with open("../README.md", "r", encoding="utf-8") as fh: +with open("../README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + +PACKAGE_NAME = "flashattn-hopper" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FAHOPPER_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FAHOPPER_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FAHOPPER_FORCE_CXX11_ABI", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return "linux_x86_64" + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." 
+ ) + + +def append_nvcc_threads(nvcc_extra_args): + return nvcc_extra_args + ["--threads", "4"] + + +cmdclass = {} +ext_modules = [] + +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +subprocess.run(["git", "submodule", "update", "--init", "../csrc/cutlass"]) + +if not SKIP_CUDA_BUILD: + print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + check_if_cuda_home_none("--fahopper") + cc_flag = [] + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("12.3"): + raise RuntimeError("FA Hopper is only supported on CUDA 12.3 and above") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90a,code=sm_90a") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + repo_dir = Path(this_dir).parent + cutlass_dir = repo_dir / "csrc" / "cutlass" + sources = [ + "flash_api.cpp", + "flash_fwd_hdim64_fp16_sm90.cu", + "flash_fwd_hdim64_bf16_sm90.cu", + "flash_fwd_hdim128_fp16_sm90.cu", + "flash_fwd_hdim128_bf16_sm90.cu", + "flash_fwd_hdim256_fp16_sm90.cu", + "flash_fwd_hdim256_bf16_sm90.cu", + "flash_bwd_hdim64_fp16_sm90.cu", + "flash_bwd_hdim128_fp16_sm90.cu", + "flash_bwd_hdim256_fp16_sm90.cu", + # "flash_fwd_hdim128_e4m3_sm90.cu", + ] + nvcc_flags = [ + "-O3", + # "-O0", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + # "--ptxas-options=-v", # printing out number of registers + "--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage", # printing out number of registers + "-lineinfo", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", # Can toggle for debugging + "-DNDEBUG", # Important, otherwise performance is severely impacted + "-DQBLKSIZE=128", + "-DKBLKSIZE=128", + "-DCTA256", + "-DDQINRMEM", + ] + include_dirs = [ + # Path(this_dir) / "fmha-pipeline", + # repo_dir / "lib", + # repo_dir / "include", + cutlass_dir / "include", + # cutlass_dir / "examples" / "common", + # cutlass_dir / "tools" / "util" / "include", + ] + + ext_modules.append( + CUDAExtension( + name="flashattn_hopper_cuda", + sources=sources, + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"], + # "cxx": ["-O0", "-std=c++17"], + "nvcc": append_nvcc_threads( + nvcc_flags + ["-DEXECMODE=0"] + cc_flag + ), + }, + include_dirs=include_dirs, + # Without this we get and error about cuTensorMapEncodeTiled not defined + libraries=["cuda"] + ) + ) + # ext_modules.append( + # CUDAExtension( + # name="flashattn_hopper_cuda_ws", + # sources=sources, + # extra_compile_args={ + # "cxx": ["-O3", "-std=c++17"], + # "nvcc": append_nvcc_threads( + # nvcc_flags + ["-DEXECMODE=1"] + cc_flag + # ), + # }, + # include_dirs=include_dirs, + # # Without this we get and error about cuTensorMapEncodeTiled not defined + # libraries=["cuda"] + # ) + # ) + + +def get_package_version(): + with open(Path(this_dir) / "__init__.py", "r") as 
f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASHATTN_HOPPER_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +def get_wheel_url(): + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + torch_version_raw = parse(torch.__version__) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + package_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{package_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{package_version}", wheel_name=wheel_filename) + return wheel_url, wheel_filename + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + + def run(self): + if FORCE_BUILD: + return super().run() + + wheel_url, wheel_filename = get_wheel_url() + print("Guessing wheel URL: ", wheel_url) + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + shutil.move(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. 
Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + +setup( + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + ) + ), + description="FlashAttention-3", + long_description=long_description, + long_description_content_type="text/markdown", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"bdist_wheel": CachedWheelsCommand, "build_ext": BuildExtension} + if ext_modules + else { + "bdist_wheel": CachedWheelsCommand, + }, + python_requires=">=3.8", + install_requires=[ + "torch", + "einops", + "packaging", + "ninja", + ], +) diff --git a/hopper/softmax.h b/hopper/softmax.h new file mode 100644 index 000000000..5adf34fa2 --- /dev/null +++ b/hopper/softmax.h @@ -0,0 +1,226 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. + ******************************************************************************/ + +#pragma once + +#include + +#include + +#include + +#include "utils.h" + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__device__ __forceinline__ void thread_reduce_(Tensor const &tensor, Tensor &summary, Operator &op) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(summary) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); mi++) { + summary(mi) = zero_init ? tensor(mi, 0) : op(summary(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + summary(mi) = op(summary(mi), tensor(mi, ni)); + } + } +} + +template +__device__ __forceinline__ void quad_allreduce_(Tensor &dst, Tensor &src, Operator &op) { + CUTE_STATIC_ASSERT_V(size(dst) == size(src)); + #pragma unroll + for (int i = 0; i < size(dst); i++){ + dst(i) = Allreduce<4>::run(src(i), op); + } +} + +template +__device__ __forceinline__ void reduce_(Tensor const& tensor, Tensor &summary, Operator &op) { + thread_reduce_(tensor, summary, op); + quad_allreduce_(summary, summary, op); +} + +template +__device__ __forceinline__ void reduce_max(Tensor const& tensor, Tensor &max){ + MaxOp max_op; + reduce_(tensor, max, max_op); +} + +template +__device__ __forceinline__ void reduce_sum(Tensor const& tensor, Tensor &sum){ + SumOp sum_op; + thread_reduce_(tensor, sum, sum_op); + if constexpr (warp_reduce) { quad_allreduce_(sum, sum, sum_op); } +} + +__forceinline__ __device__ __half2 half_exp(__half2 x) { + uint32_t tmp_out, tmp_in; + tmp_in = reinterpret_cast(x); + asm ("ex2.approx.f16x2 %0, %1;\n" + : "=r"(tmp_out) + : "r"(tmp_in)); + __half2 out = reinterpret_cast<__half2&>(tmp_out); + return out; +} + +// Apply the exp to all the elements. 
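The exponentials below are computed in base 2: with `scale_log2 = softmax_scale * log2(e)` folded in ahead of time, `exp((x - max) * softmax_scale)` becomes `exp2(x * scale_log2 - max * scale_log2)`, which the compiler can lower to one FFMA plus an EX2 per element. A short Python reference of that identity (illustrative; the masked-row guard mirrors the `-INFINITY` check in the code below):

```python
import math
import numpy as np

def softmax_base2(x, softmax_scale):
    scale_log2 = softmax_scale * math.log2(math.e)
    m = x.max(axis=-1, keepdims=True)
    m = np.where(m == -np.inf, 0.0, m)            # fully masked row: avoid -inf - (-inf)
    p = np.exp2(x * scale_log2 - m * scale_log2)  # == exp((x - m) * softmax_scale)
    return p / p.sum(axis=-1, keepdims=True)

x = np.random.randn(4, 16)
ref = np.exp(x * 0.125 - (x * 0.125).max(-1, keepdims=True))
ref /= ref.sum(-1, keepdims=True)
assert np.allclose(softmax_base2(x, softmax_scale=0.125), ref)
```

The `Softmax` struct further down applies the same identity one K/V tile at a time, carrying `row_max`/`row_sum` across tiles and rescaling the output accumulator by `exp2((old_max - new_max) * scale_log2)` whenever the running max grows; `finalize` then divides by the accumulated sum.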
+template +__forceinline__ __device__ void max_scale_exp2_sum(Tensor &tensor, Tensor &max, Tensor &sum, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); static_assert(Layout1::rank == 1, "Only support 1D Tensor"); CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + MaxOp max_op; + max(mi) = zero_init ? tensor(mi, 0) : max_op(max(mi), tensor(mi, 0)); + #pragma unroll + for (int ni = 1; ni < size<1>(tensor); ni++) { + max(mi) = max_op(max(mi), tensor(mi, ni)); + } + max(mi) = Allreduce<4>::run(max(mi), max_op); + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + const float max_scaled = max(mi) == -INFINITY ? 0.f : max(mi) * scale; + sum(mi) = 0; + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + sum(mi) += tensor(mi, ni); + } + } +} + +// Apply the exp to all the elements. +template +__forceinline__ __device__ void scale_apply_exp2(Tensor &tensor, Tensor const &max, const float scale) { + static_assert(Layout0::rank == 2, "Only support 2D Tensor"); + static_assert(Layout1::rank == 1, "Only support 1D Tensor"); + CUTE_STATIC_ASSERT_V(size<0>(max) == size<0>(tensor)); + #pragma unroll + for (int mi = 0; mi < size<0>(tensor); ++mi) { + // If max is -inf, then all elements must have been -inf (possibly due to masking). + // We don't want (-inf - (-inf)) since that would give NaN. + // If we don't have float around M_LOG2E the multiplication is done in fp64. + const float max_scaled = Check_inf + ? (max(mi) == -INFINITY ? 0.f : (max(mi) * (Scale_max ? scale : float(M_LOG2E)))) + : (max(mi) * (Scale_max ? scale : float(M_LOG2E))); + #pragma unroll + for (int ni = 0; ni < size<1>(tensor); ++ni) { + // Instead of computing exp(x - max), we compute exp2(x * log_2(e) - + // max * log_2(e)) This allows the compiler to use the ffma + // instruction instead of fadd and fmul separately. + tensor(mi, ni) = exp2f(tensor(mi, ni) * scale - max_scaled); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Softmax { + + using TensorT = decltype(make_tensor(Shape>{})); + TensorT row_max, row_sum; + + CUTLASS_DEVICE Softmax() {}; + + template + __forceinline__ __device__ TensorT max(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + cute::fill(scores_scale, 1.f); + } else { + Tensor scores_max_prev = make_fragment_like(row_max); + cute::copy(row_max, scores_max_prev); + flash::template reduce_max(scores, row_max); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float scores_max_cur = !Check_inf + ? row_max(mi) + : (row_max(mi) == -INFINITY ? 
0.0f : row_max(mi)); + scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + row_sum(mi) *= scores_scale(mi); + } + } + return scores_scale; + }; + + template + __forceinline__ __device__ TensorT online_softmax(Tensor0 &acc_s, float softmax_scale_log2) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + TensorT scores_scale; + if constexpr (Is_first) { + flash::template reduce_max(scores, row_max); + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + flash::reduce_sum(scores, row_sum); + cute::fill(scores_scale, 1.f); + // if (cute::thread0()) { print_tensor(scores); printf("\n scale = %f\n", softmax_scale_log2); print_tensor(row_sum); } + } else { + // Tensor scores_max_prev = make_fragment_like(row_max); + // cute::copy(row_max, scores_max_prev); + // flash::template reduce_max(scores, row_max); + // // if (cute::thread0()) { print_tensor(scores); printf("\n"); print_tensor(row_max); printf("\n"); } + // #pragma unroll + // for (int mi = 0; mi < size(row_max); ++mi) { + // float scores_max_cur = !Check_inf + // ? row_max(mi) + // : (row_max(mi) == -INFINITY ? 0.0f : row_max(mi)); + // scores_scale(mi) = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2); + // row_sum(mi) *= scores_scale(mi); + // } + flash::template scale_apply_exp2(scores, row_max, softmax_scale_log2); + // We don't do the reduce across threads here since we don't need to use the row_sum. + // We do that reduce at the end when we need to normalize the softmax. + flash::reduce_sum(scores, row_sum); + } + return scores_scale; + }; + + template + __forceinline__ __device__ TensorT finalize(Tensor0 &acc_s, float softmax_scale_log2, float rp_dropout=1.0) { + // Reshape acc_s from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) + Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout())); + static_assert(decltype(size<0>(scores))::value == kNRows); + SumOp sum_op; + quad_allreduce_(row_sum, row_sum, sum_op); + TensorT scores_scale; + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + float sum = row_sum(mi); + float inv_sum = (sum == 0.f || sum != sum) ? 0.f : 1.f / sum; + row_sum(mi) = (sum == 0.f || sum != sum) ? (Split ? -INFINITY : INFINITY) : row_max(mi) * (softmax_scale_log2 * float(M_LN2)) + __logf(sum); + scores_scale(mi) = !Is_dropout ? 
inv_sum : inv_sum * rp_dropout; + } + return scores_scale; + }; + + template + __forceinline__ __device__ void rescale_o(Tensor1 &acc_o, TensorT const &scores_scale) { + // Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K)) + Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout())); + static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows); + #pragma unroll + for (int mi = 0; mi < size(row_max); ++mi) { + #pragma unroll + for (int ni = 0; ni < size<1>(acc_o_rowcol); ++ni) { acc_o_rowcol(mi, ni) *= scores_scale(mi); } + } + }; + +}; + +} // namespace flash diff --git a/hopper/static_switch.h b/hopper/static_switch.h new file mode 100644 index 000000000..e870643e7 --- /dev/null +++ b/hopper/static_switch.h @@ -0,0 +1,83 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +// + +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr static bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr static bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define PREC_SWITCH(PRECTYPE, ...) \ + [&] { \ + if (PRECTYPE == 1) { \ + using kPrecType = cutlass::half_t; \ + constexpr static bool kSoftFp16 = false; \ + constexpr static bool kHybrid = false; \ + return __VA_ARGS__(); \ + } else if (PRECTYPE == 2) { \ + using kPrecType = cutlass::float_e4m3_t; \ + constexpr static bool kSoftFp16 = false; \ + constexpr static bool kHybrid = false; \ + return __VA_ARGS__(); \ + } else if (PRECTYPE == 3) { \ + using kPrecType = cutlass::float_e4m3_t; \ + constexpr static bool kSoftFp16 = false; \ + constexpr static bool kHybrid = true; \ + return __VA_ARGS__(); \ + } else if (PRECTYPE == 4) { \ + using kPrecType = cutlass::float_e4m3_t; \ + constexpr static bool kSoftFp16 = true; \ + constexpr static bool kHybrid = false; \ + return __VA_ARGS__(); \ + } \ + }() + +#define HEADDIM_SWITCH(HEADDIM, ...) \ + [&] { \ + if (HEADDIM == 64) { \ + constexpr static int kHeadSize = 64; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 128) { \ + constexpr static int kHeadSize = 128; \ + return __VA_ARGS__(); \ + } else if (HEADDIM == 256) { \ + constexpr static int kHeadSize = 256; \ + return __VA_ARGS__(); \ + } \ + }() + +#define SEQLEN_SWITCH(USE_VAR_SEQ_LEN, SEQ_LEN_OUT_OF_BOUND_CHECK, ...) 
\ + [&] { \ + if (!USE_VAR_SEQ_LEN) { \ + if (SEQ_LEN_OUT_OF_BOUND_CHECK) { \ + using kSeqLenTraitsType = FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } else { \ + using kSeqLenTraitsType = FixedSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + } else { \ + using kSeqLenTraitsType = VarSeqLenTraits; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py new file mode 100644 index 000000000..97852d47e --- /dev/null +++ b/hopper/test_flash_attn.py @@ -0,0 +1,252 @@ +import math + +import pytest +import torch +import torch.nn.functional as F + +from einops import rearrange, repeat +from flash_attn_interface import flash_attn_func + +ABS_TOL = 5e-3 +REL_TOL = 1e-1 + +def construct_local_mask( + seqlen_q, + seqlen_k, + window_size=(-1, -1), # -1 means infinite window size + query_padding_mask=None, + key_padding_mask=None, + device=None, +): + row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + sk = ( + seqlen_k + if key_padding_mask is None + else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") + ) + sq = ( + seqlen_q + if query_padding_mask is None + else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1") + ) + if window_size[0] < 0: + return col_idx > row_idx + sk - sq + window_size[1] + else: + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + col_idx < row_idx + sk - sq - window_size[0], + ) + +def print_diffs(out, out_ref): + out_1d = out.flatten() + out_ref_1d = out_ref.flatten() + for idx, (e_o, e_o_ref) in enumerate(zip(out_1d, out_ref_1d)): + diff = e_o - e_o_ref + abs_diff = abs(diff) + abs_ref = abs(e_o_ref + 1e-5) + relative_diff = abs_diff / abs_ref + if abs_diff > ABS_TOL or relative_diff > REL_TOL: + print(f"==== diff ==== {idx}, test: {e_o}, ref: {e_o_ref}") + + +def attention_ref( + q, + k, + v, + query_padding_mask=None, + key_padding_mask=None, + attn_bias=None, + dropout_p=0.0, + dropout_mask=None, + causal=False, + upcast=True, + reorder_ops=False, +): + """ + Arguments: + q: (batch_size, seqlen_q, nheads, head_dim) + k: (batch_size, seqlen_k, nheads, head_dim) + v: (batch_size, seqlen_k, nheads, head_dim) + query_padding_mask: (batch_size, seqlen_q) + key_padding_mask: (batch_size, seqlen_k) + attn_bias: broadcastable to (batch_size, nheads, seqlen_q, seqlen_k) + dropout_p: float + dropout_mask: (batch_size, nheads, seqlen_q, seqlen_k) + causal: whether to apply causal masking + upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast + output back to fp16/bf16. + reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + without changing the math. This is to estimate the numerical error from operation + reordering. 
+ Output: + output: (batch_size, seqlen_q, nheads, head_dim) + attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout + """ + dtype_og = q.dtype + if upcast: + q, k, v = q.float(), k.float(), v.float() + seqlen_q, seqlen_k = q.shape[1], k.shape[1] + k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2]) + v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2]) + d = q.shape[-1] + if not reorder_ops: + scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) + else: + scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if key_padding_mask is not None: + scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) + if causal: + local_mask = construct_local_mask( + seqlen_q, + seqlen_k, + (-1, 0), + None, + None, + q.device, + ) + scores.masked_fill_(local_mask, float("-inf")) + if attn_bias is not None: + scores = scores + attn_bias + attention = torch.softmax(scores, dim=-1).to(v.dtype) + # We want to mask here so that the attention matrix doesn't have any NaNs + # Otherwise we'll get NaN in dV + if query_padding_mask is not None: + attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN + if causal: + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) + dropout_scaling = 1.0 / (1 - dropout_p) + # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling + # output = torch.einsum('bhts,bshd->bthd', attention_drop , v) + if dropout_mask is not None: + attention_drop = attention.masked_fill(~dropout_mask, 0.0) + else: + attention_drop = attention + output = torch.einsum("bhts,bshd->bthd", attention_drop, v * dropout_scaling) + if query_padding_mask is not None: + output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + return output.to(dtype=dtype_og), attention.to(dtype=dtype_og) + + + +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +# @pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) +# @pytest.mark.parametrize("mha_type", ["gqa"]) +@pytest.mark.parametrize("causal", [False, True]) +# @pytest.mark.parametrize("causal", [True]) +# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) +# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192]) +# @pytest.mark.parametrize('d', [56, 80]) +@pytest.mark.parametrize("d", [64, 128, 256]) +# @pytest.mark.parametrize("d", [256]) +@pytest.mark.parametrize( + "seqlen_q,seqlen_k", + [ + (64, 128), + (128, 128), + (256, 256), + (113, 203), + (128, 217), + (113, 211), + (108, 256), + (256, 512), + (384, 256), + (640, 128), + (512, 256), + (1024, 1024), + (1023, 1024), + (1024, 1023), + (2048, 2048), + ], +) +# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) +def test_flash_attn_output( + seqlen_q, seqlen_k, d, causal, mha_type, dtype +): + device = "cuda" + # set seed + torch.random.manual_seed(0) + # batch_size = 40 + # nheads = 16 + batch_size = 9 + nheads = 6 + nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1) + # batch_size = 1 + # nheads = 1 + q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True) + v = torch.randn(batch_size, seqlen_k, 
nheads_kv, d, device=device, dtype=dtype, requires_grad=True) + out, lse = flash_attn_func(q, k, v, causal=causal) + out_ref, attn_ref = attention_ref( + q, + k, + v, + None, + None, + causal=causal, + ) + out_pt, attn_pt = attention_ref( + q, + k, + v, + None, + None, + causal=causal, + upcast=False, + reorder_ops=True, + ) + + # qk = torch.einsum('bshd,bthd->bhst', q, k).float() + # m = qk.amax(-1, keepdim=True) + # s_tmp = torch.exp((qk - m) / math.sqrt(d)) + # exp_sum = s_tmp.sum(-1) + # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float()) + # lse_ref = torch.logsumexp(qk, dim=-1) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") + print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # if d <= 128: + # g = torch.randn_like(out) + # do_o = (g.float() * out.float()).sum(-1) + # dq, dk, dv = torch.autograd.grad(out, (q, k, v), g) + # dq_ref, dk_ref, dv_ref = torch.autograd.grad(out_ref, (q, k, v), g) + # dq_pt, dk_pt, dv_pt = torch.autograd.grad(out_pt, (q, k, v), g) + # print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}") + # print(f"dK max diff: {(dk - dk_ref).abs().max().item()}") + # print(f"dV max diff: {(dv - dv_ref).abs().max().item()}") + # print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}") + # print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}") + # print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}") + # print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}") + # print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}") + # print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}") + # print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}") + # print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}") + # print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}") + + # dS = torch.einsum('bthd,bshd->bhts', g.float(), v.float()) + # P = torch.softmax(qk, -1) + # dP = P * (dS - do_o.unsqueeze(1)) + # dQ = torch.einsum('bhts,bshd->bthd', dP, k.float()) + # dV = torch.einsum('bhts,bthd->bshd', P, g.float()) + # dK = torch.einsum('bhts,bthd->bshd', dP, q.float()) + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. + assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + + # if d <= 128: + # assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + # assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + # assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() diff --git a/hopper/tile_scheduler.hpp b/hopper/tile_scheduler.hpp new file mode 100644 index 000000000..b27f179ec --- /dev/null +++ b/hopper/tile_scheduler.hpp @@ -0,0 +1,269 @@ +/****************************************************************************** + * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. 
+ ******************************************************************************/ + +#pragma once + +#include "cutlass/fast_math.h" +#include "cutlass/arch/barrier.h" + +#include "named_barrier.hpp" + +namespace flash { + +/////////////////////////////////////////////////////////////////////////////// + +struct SingleTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_head, num_batch; + int* const tile_count_semaphore = nullptr; + }; + + // Device side kernel params + struct Params {}; + + static Params + to_underlying_arguments(Arguments const& args) { + return {}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(args.num_blocks_m), uint32_t(args.num_head), uint32_t(args.num_batch)}; + } + + struct WorkTileInfo { + int M_idx = 0; + int H_idx = 0; + int B_idx = 0; + bool is_valid_tile = false; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return is_valid_tile; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + return {M_idx, H_idx, B_idx}; + } + + }; + + CUTLASS_DEVICE + SingleTileScheduler(int* tile_count_smem_) { } + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {-1, -1, -1, false}; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +class StaticPersistentTileScheduler { + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_head, num_batch; + int* const tile_count_semaphore = nullptr; + }; + + // Device side kernel params + struct Params { + int total_blocks; + cutlass::FastDivmod m_block_divmod, head_divmod; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + return {args.num_blocks_m * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head)}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int m_block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, bidh, bidb}; + } + + }; + + CUTLASS_DEVICE + StaticPersistentTileScheduler(int* tile_count_smem_) {}; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const {} + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {} + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const {} + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + return {current_work.tile_idx + int(gridDim.x)}; + } + +}; + +template +class DynamicPersistentTileScheduler { + +protected: + int* const 
tile_count_smem; + +public: + + // Host side kernel arguments + struct Arguments { + int const num_blocks_m, num_head, num_batch; + int* const tile_count_semaphore; + }; + + // Device side kernel params + struct Params { + int const total_blocks; + cutlass::FastDivmod const m_block_divmod, head_divmod; + int* const tile_count_semaphore; + }; + + static Params + to_underlying_arguments(Arguments const& args) { + return {args.num_blocks_m * args.num_head * args.num_batch, + cutlass::FastDivmod(args.num_blocks_m), cutlass::FastDivmod(args.num_head), + args.tile_count_semaphore}; + } + + static dim3 + get_grid_dim(Arguments const& args, int num_sm) { + return {uint32_t(num_sm)}; + } + + struct WorkTileInfo { + int tile_idx; + + CUTLASS_DEVICE + bool + is_valid(Params const& params) const { + return tile_idx < params.total_blocks; + } + + CUTLASS_DEVICE + cute::tuple + get_block_coord(Params const& params) const { + int m_block, bidh, bidb; + bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx)); + return {m_block, bidh, bidb}; + } + + }; + + CUTLASS_DEVICE + DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {}; + + CUTLASS_DEVICE + WorkTileInfo + get_initial_work() const { + return {int(blockIdx.x)}; + } + + CUTLASS_DEVICE + void + init_consumer() const { + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + } + + CUTLASS_DEVICE + void + prefetch_next_work(Params const& params, WorkTileInfo& current_work) const { + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x); + } + } + + CUTLASS_DEVICE + void + broadcast_next_work(WorkTileInfo& current_work) const { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) { + *tile_count_smem = current_work.tile_idx; + } + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + } + + template + CUTLASS_DEVICE + WorkTileInfo + get_next_work(Params const& params, WorkTileInfo const& current_work) const { + if constexpr (IsProducer) { + // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0 + return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)}; + } else { + cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemFull) /*id*/); + int tile_idx = *tile_count_smem; + cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast(FwdNamedBarriers::TileCountSmemEmpty) /*id*/); + return {tile_idx}; + } + } + +}; + +} // flash \ No newline at end of file diff --git a/hopper/utils.h b/hopper/utils.h new file mode 100644 index 000000000..64cd20a39 --- /dev/null +++ b/hopper/utils.h @@ -0,0 +1,219 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. 
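All three schedulers above share the same bookkeeping: a flat `tile_idx` is split into `(m_block, head, batch)` with two divmods, and persistent CTAs step through tiles either by a fixed stride of `gridDim.x` (static) or by an `atomicAdd` on `tile_count_semaphore` (dynamic). A small Python sketch of the static variant (the SM count is illustrative; this is host-side pseudo-logic, not the device code):

```python
def decode_tile(tile_idx, num_blocks_m, num_head):
    # Mirrors: bidb = head_divmod.divmod(bidh, m_block_divmod.divmod(m_block, tile_idx))
    m_block = tile_idx % num_blocks_m
    bidh = (tile_idx // num_blocks_m) % num_head
    bidb = tile_idx // (num_blocks_m * num_head)
    return m_block, bidh, bidb

num_blocks_m, num_head, num_batch = 8, 6, 9
total_blocks = num_blocks_m * num_head * num_batch   # 432 tiles
num_sm = 132                                         # illustrative persistent-CTA count

# Static persistent schedule: CTA i handles tiles i, i + num_sm, i + 2*num_sm, ...
for cta in range(2):
    tiles = list(range(cta, total_blocks, num_sm))
    print(cta, [decode_tile(t, num_blocks_m, num_head) for t in tiles])
```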
+ ******************************************************************************/ + +#pragma once + +#include <assert.h> +#include <stdint.h> +#include <stdlib.h> + +#include <cuda_fp16.h> + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#include <cuda_bf16.h> +#endif + +#include <cute/tensor.hpp> + +#include <cutlass/array.h> +#include <cutlass/cutlass.h> +#include <cutlass/numeric_conversion.h> +#include <cutlass/numeric_types.h> + +namespace flash { + +using namespace cute; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename T> +struct MaxOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp<float> { +// This is slightly faster +__device__ __forceinline__ float operator()(float const &x, float const &y) { return max(x, y); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<typename T> +struct SumOp { +__device__ __forceinline__ T operator()(T const & x, T const & y) { return x + y; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<int THREADS> +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template<typename T, typename Operator> + static __device__ __forceinline__ T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce<OFFSET>::run(x, op); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Allreduce<2> { +template<typename T, typename Operator> +static __device__ __forceinline__ T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N)) +// For SM90, convert acc_layout from ((2, 2, V), MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, V, MMA_N)) +template<typename Layout> +__forceinline__ __device__ auto convert_layout_acc_rowcol(Layout acc_layout) { + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l))); + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = logical_divide(acc_layout, Shape<_2>{}); // ((2, 2), MMA_M, MMA_N) + return make_layout(make_layout(get<0, 1>(l), get<1>(l)), make_layout(get<0, 0>(l), get<2>(l))); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM90, convert acc_layout from ((2, 2, V), MMA_N, MMA_M) to (nrow=(2, V, MMA_M), ncol=(2, MMA_N)) +template<typename Layout> +__forceinline__ __device__ auto convert_layout_acc_transposed_rowcol(Layout acc_layout) { + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + auto l = acc_layout; + return make_layout(make_layout(get<0, 0>(l), get<0, 2>(l), get<2>(l)), make_layout(get<0, 1>(l), get<1>(l))); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// For SM80, convert acc_layout from (MMA=4, MMA_M, MMA_N) to ((4, 2),
MMA_M, MMA_N / 2) +// if using m16n8k16, or to (4, MMA_M, MMA_N) if using m16n8k8. +// For SM90, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) +template<typename MMA_traits, typename Layout> +__forceinline__ __device__ auto convert_layout_acc_Aregs(Layout acc_layout) { + using X = Underscore; + if constexpr (decltype(rank<0>(acc_layout))::value == 3) { // SM90 + static_assert(decltype(size<0, 0>(acc_layout))::value == 2); + static_assert(decltype(size<0, 1>(acc_layout))::value == 2); + static_assert(decltype(rank(acc_layout))::value == 3); + static_assert(decltype(rank(get<0>(acc_layout)))::value == 3); + auto l = logical_divide(get<0>(acc_layout), Shape<X, X, _2>{}); // (2, 2, (2, N / 16))) + return make_layout(make_layout(get<0>(l), get<1>(l), get<2, 0>(l)), get<1>(acc_layout), make_layout(get<2, 1>(l), get<2>(acc_layout))); + } else { // SM80 + static_assert(decltype(size<0>(acc_layout))::value == 4); + static_assert(decltype(rank(acc_layout))::value == 3); + constexpr int mma_shape_K = get<2>(typename MMA_traits::Shape_MNK{}); + static_assert(mma_shape_K == 8 || mma_shape_K == 16); + if constexpr (mma_shape_K == 8) { + return acc_layout; + } else { + auto l = logical_divide(acc_layout, Shape<X, X, _2>{}); // (4, MMA_M, (2, MMA_N / 2))) + return make_layout(make_layout(get<0>(l), get<2, 0>(l)), get<1>(l), get<2, 1>(l)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <typename To_type, typename Engine, typename Layout> +__forceinline__ __device__ auto convert_type(Tensor<Engine, Layout> const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op; + // HACK: this requires tensor to be "contiguous" + auto frag = convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(tensor.data())); + return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()); + // Tensor out = make_tensor_like<To_type>(tensor); + // cute::copy(make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout()), out); + // return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<bool zero_init=false, int wg_wait=0, bool arrive=true, bool commit=true, typename Tensor0, typename Tensor1, typename Tensor2, typename TiledMma> +__forceinline__ __device__ void gemm(TiledMma &tiled_mma, Tensor0 const &tCrA, Tensor1 const &tCrB, Tensor2 &tCrC) { + constexpr bool Is_RS = !cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value; + // Need to cast away const on tCrA since warpgroup_fence_operand doesn't take const + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } + warpgroup_fence_operand(tCrC); + if constexpr (arrive) { + warpgroup_arrive(); + } + if constexpr (zero_init) { + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } else { + // cute::gemm(tiled_mma, tCrA, tCrB, tCrC); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + } + if constexpr (commit) { + warpgroup_commit_batch(); + } + if constexpr (wg_wait >= 0) { warpgroup_wait<wg_wait>(); } + warpgroup_fence_operand(tCrC); + if constexpr (Is_RS) { warpgroup_fence_operand(const_cast<Tensor0 &>(tCrA)); } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true, typename TiledCopy, typename Engine0, typename Layout0, typename Engine1, typename Layout1, typename Engine2, typename Layout2, typename Engine3, typename Layout3> +__forceinline__ __device__ void copy(TiledCopy
tiled_copy, Tensor<Engine0, Layout0> const &S, + Tensor<Engine1, Layout1> &D, Tensor<Engine2, Layout2> const &identity_MN, + Tensor<Engine3, Layout3> const &predicate_K, const int max_MN=0) { + CUTE_STATIC_ASSERT_V(rank(S) == Int<3>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<3>{}); + CUTE_STATIC_ASSERT_V(size<0>(S) == size<0>(D)); // MMA + CUTE_STATIC_ASSERT_V(size<1>(S) == size<1>(D)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(S) == size<2>(D)); // MMA_K + // There's no case where !Clear_OOB_K && Clear_OOB_MN + static_assert(!(Clear_OOB_MN && !Clear_OOB_K)); + #pragma unroll + for (int m = 0; m < size<1>(S); ++m) { + if (Is_even_MN || get<0>(identity_MN(0, m, 0)) < max_MN) { + #pragma unroll + for (int k = 0; k < size<2>(S); ++k) { + if (Is_even_K || predicate_K(k)) { + cute::copy(tiled_copy, S(_, m, k), D(_, m, k)); + } else if (Clear_OOB_K) { + cute::clear(D(_, m, k)); + } + } + } else if (Clear_OOB_MN) { + cute::clear(D(_, m, _)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace flash diff --git a/setup.py b/setup.py index 7e929a07b..1051a1ff6 100644 --- a/setup.py +++ b/setup.py @@ -196,6 +196,22 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_causal_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", @@ -228,6 +244,22 @@ def validate_and_update_archs(archs): "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_sm80.cu", "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim32_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim64_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim96_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim128_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim160_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim160_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim192_bf16_causal_sm80.cu", +
"csrc/flash_attn/src/flash_fwd_split_hdim224_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim224_bf16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_fp16_causal_sm80.cu", + "csrc/flash_attn/src/flash_fwd_split_hdim256_bf16_causal_sm80.cu", ], extra_compile_args={ "cxx": ["-O3", "-std=c++17"] + generator_flag, @@ -248,6 +280,7 @@ def validate_and_update_archs(archs): # "-DFLASHATTENTION_DISABLE_BACKWARD", # "-DFLASHATTENTION_DISABLE_DROPOUT", # "-DFLASHATTENTION_DISABLE_ALIBI", + # "-DFLASHATTENTION_DISABLE_SOFTCAP", # "-DFLASHATTENTION_DISABLE_UNEVEN_K", # "-DFLASHATTENTION_DISABLE_LOCAL", ] @@ -381,9 +414,9 @@ def get_wheel_url(): # We're using the CUDA version used to build torch, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) torch_cuda_version = parse(torch.version.cuda) - # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2 + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 # to save CI time. Minor versions should be compatible. - torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.2") + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}" @@ -466,7 +499,7 @@ def __init__(self, *args, **kwargs) -> None: ) ), author="Tri Dao", - author_email="trid@cs.stanford.edu", + author_email="tri@tridao.me", description="Flash Attention: Fast and Memory-Efficient Exact Attention", long_description=long_description, long_description_content_type="text/markdown", @@ -482,7 +515,7 @@ def __init__(self, *args, **kwargs) -> None: else { "bdist_wheel": CachedWheelsCommand, }, - python_requires=">=3.7", + python_requires=">=3.8", install_requires=[ "torch", "einops", diff --git a/tests/test_flash_attn.py b/tests/test_flash_attn.py index 308e30bec..cd9262b86 100644 --- a/tests/test_flash_attn.py +++ b/tests/test_flash_attn.py @@ -182,9 +182,14 @@ def construct_local_mask( query_padding_mask=None, key_padding_mask=None, device=None, + key_leftpad=None, ): row_idx = rearrange(torch.arange(seqlen_q, device=device, dtype=torch.long), "s -> s 1") col_idx = torch.arange(seqlen_k, device=device, dtype=torch.long) + if key_leftpad is not None: + key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") + col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None @@ -216,8 +221,10 @@ def attention_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, upcast=True, reorder_ops=False, + key_leftpad=None, ): """ Arguments: @@ -233,7 +240,7 @@ def attention_ref( window_size: (int, int), left and right window size upcast: whether to cast all inputs to fp32, do all computation in fp32, then cast output back to fp16/bf16. - reorder_ops: whether to change the order of operations (scaling k instead of scaling k, etc.) + reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.) without changing the math. This is to estimate the numerical error from operation reordering. 
Output: @@ -253,6 +260,10 @@ def attention_ref( scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(d), k) else: scores = torch.einsum("bthd,bshd->bhts", q, k / math.sqrt(d)) + if softcap > 0: + scores = scores / softcap + scores = scores.tanh() + scores = scores * softcap if key_padding_mask is not None: scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf")) if window_size[0] >= 0 or window_size[1] >= 0: @@ -263,6 +274,7 @@ def attention_ref( query_padding_mask, key_padding_mask, q.device, + key_leftpad=key_leftpad, ) scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: @@ -298,8 +310,10 @@ def attention_kvpacked_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, upcast=True, reorder_ops=False, + key_leftpad=None, ): return attention_ref( q, @@ -313,7 +327,9 @@ def attention_kvpacked_ref( upcast=upcast, causal=causal, window_size=window_size, + softcap=softcap, reorder_ops=reorder_ops, + key_leftpad=key_leftpad, ) @@ -325,6 +341,7 @@ def attention_qkvpacked_ref( dropout_mask=None, causal=False, window_size=(-1, -1), # -1 means infinite window size + softcap=0.0, upcast=True, reorder_ops=False, ): @@ -340,6 +357,7 @@ def attention_qkvpacked_ref( upcast=upcast, causal=causal, window_size=window_size, + softcap=softcap, reorder_ops=reorder_ops, ) @@ -877,23 +895,29 @@ def test_flash_attn_varlen_qkvpacked( # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) # @pytest.mark.parametrize("dropout_p", [0.17]) +@pytest.mark.parametrize("softcap", [0.0, 50.0]) def test_flash_attn_output( - seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 4 - nheads = 9 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if softcap > 0: + # Ensure the values of qk are at least within softcap range. 
+ q = q * softcap if kvpacked: kv = torch.randn( batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True @@ -918,6 +942,7 @@ def test_flash_attn_output( dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, @@ -930,6 +955,7 @@ def test_flash_attn_output( dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, @@ -984,6 +1010,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -995,6 +1022,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, upcast=False, reorder_ops=True, ) @@ -1010,6 +1038,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, ) out_pt, attn_pt = attention_ref( q, @@ -1022,6 +1051,7 @@ def test_flash_attn_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, upcast=False, reorder_ops=True, ) @@ -1036,7 +1066,7 @@ def test_flash_attn_output( g = torch.randn_like(out) do_o = (g.float() * out.float()).sum(-1) - if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90): + if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0: if kvpacked: ( dq, @@ -1092,7 +1122,7 @@ def test_flash_attn_output( if not alibi: assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) - if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90): + if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)): assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() @@ -1133,24 +1163,31 @@ def test_flash_attn_output( ) # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)]) @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) +@pytest.mark.parametrize("softcap", [0.0, 50.0]) # @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_output( - seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked + seqlen_q, seqlen_k, d, dropout_p, causal, local, alibi, deterministic, mha_type, dtype, kvpacked, softcap ): if ( max(seqlen_q, seqlen_k) >= 2048 and torch.cuda.get_device_properties("cuda").total_memory <= 16 * 2**30 ): pytest.skip() # Reference implementation OOM + if softcap > 0.0 and dropout_p > 0.0: + pytest.skip("Softcap and dropout not supported together") device = "cuda" # set seed torch.random.manual_seed(0) batch_size = 4 - nheads = 9 - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) + nheads = 6 if softcap == 0.0 else 4 # softcap reference impl takes more memory + nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 2) assert nheads % nheads_k == 0 window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True) + if softcap > 0: + # Ensure the values of qk are at least within softcap range. 
+ q = q * softcap + if kvpacked: kv = torch.randn( batch_size, seqlen_k, 2, nheads_k, d, device=device, dtype=dtype, requires_grad=True @@ -1198,6 +1235,7 @@ def test_flash_attn_varlen_output( dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, @@ -1229,6 +1267,7 @@ def test_flash_attn_varlen_output( dropout_p, causal=causal, window_size=window_size, + softcap=softcap, alibi_slopes=alibi_slopes, deterministic=deterministic, return_attn_probs=True, @@ -1288,6 +1327,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, ) out_pt, attn_pt = attention_kvpacked_ref( q, @@ -1299,6 +1339,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, upcast=False, reorder_ops=True, ) @@ -1314,6 +1355,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, ) out_pt, attn_pt = attention_ref( q, @@ -1326,6 +1368,7 @@ def test_flash_attn_varlen_output( dropout_mask, causal=causal, window_size=window_size, + softcap=softcap, upcast=False, reorder_ops=True, ) @@ -1339,7 +1382,7 @@ def test_flash_attn_varlen_output( print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}") g = torch.randn_like(out) - if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90): + if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)): if kvpacked: ( dq_unpad, @@ -1396,9 +1439,9 @@ def test_flash_attn_varlen_output( assert (attn - attn_ref).abs().max().item() <= 2 * (attn_pt - attn_ref).abs().max().item() # With alibi, many of the prob values are 0.0 & -0.0 so dropout_fraction isn't accurate if not alibi: - assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025) + assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.04) - if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90): + if ((d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)) and softcap == 0.0: assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item() assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item() assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item() @@ -1834,9 +1877,11 @@ def test_flash_attn_splitkv( # @pytest.mark.parametrize("rotary_fraction", [0.0]) @pytest.mark.parametrize("paged_kv_block_size", [None, 256]) # @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) -# @pytest.mark.parametrize("paged_kv_block_size", [256]) -@pytest.mark.parametrize("has_batch_idx", [False, True]) -# @pytest.mark.parametrize("has_batch_idx", [False]) +# @pytest.mark.parametrize("paged_kv_block_size", [None]) +@pytest.mark.parametrize("has_leftpad", [False, True]) +# @pytest.mark.parametrize("has_leftpad", [True]) +# @pytest.mark.parametrize("has_batch_idx", [False, True]) +@pytest.mark.parametrize("has_batch_idx", [False]) @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) @@ -1864,6 +1909,7 @@ def test_flash_attn_kvcache( seqlen_k, d, has_batch_idx, + has_leftpad, paged_kv_block_size, rotary_fraction, rotary_interleaved, @@ -1882,6 +1928,8 @@ def test_flash_attn_kvcache( pytest.skip() if has_batch_idx and 
paged_kv_block_size is not None: pytest.skip() + if has_leftpad and paged_kv_block_size is not None: + pytest.skip() device = "cuda" # set seed torch.random.manual_seed(0) @@ -1918,16 +1966,28 @@ def test_flash_attn_kvcache( cache_seqlens = torch.randint( 0 if new_kv else 1, # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) - if new_kv - else (seqlen_k + 1), + ( + (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) + if new_kv + else (seqlen_k + 1) + ), (batch_size,), dtype=torch.int32, device=device, ) + if has_leftpad: + cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) + if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) + for i in range(batch_size)]) + else: + cache_leftpad = None arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) + if has_leftpad: + key_padding_mask = torch.logical_and( + key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) + ) if has_batch_idx: cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ :batch_size @@ -2002,6 +2062,7 @@ def test_flash_attn_kvcache( rotary_sin=sin, cache_seqlens=cache_seqlens, cache_batch_idx=cache_batch_idx, + cache_leftpad=cache_leftpad, block_table=block_table, causal=causal, window_size=window_size, @@ -2030,6 +2091,7 @@ def test_flash_attn_kvcache( None, causal=causal, window_size=window_size, + key_leftpad=cache_leftpad, ) out_pt, _ = attention_ref( q_ro, @@ -2044,6 +2106,7 @@ def test_flash_attn_kvcache( window_size=window_size, upcast=False, reorder_ops=True, + key_leftpad=cache_leftpad, ) print(f"Output max diff: {(out - out_ref).abs().max().item()}") print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") @@ -2456,9 +2519,9 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus g = torch.randn_like(out) if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90): - dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) + dq0, dk0, dv0 = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) for _ in range(50): dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True) - assert torch.equal(dv, dv) - assert torch.equal(dk, dk) - assert torch.equal(dq, dq) + assert torch.equal(dv, dv0) + assert torch.equal(dk, dk0) + assert torch.equal(dq, dq0) diff --git a/training/Dockerfile b/training/Dockerfile index 0baec9278..c2412c059 100644 --- a/training/Dockerfile +++ b/training/Dockerfile @@ -85,7 +85,7 @@ RUN pip install transformers==4.25.1 datasets==2.8.0 pytorch-lightning==1.8.6 tr RUN pip install git+https://github.com/mlcommons/logging.git@2.1.0 # Install FlashAttention -RUN pip install flash-attn==2.5.9.post1 +RUN pip install flash-attn==2.6.1 # Install CUDA extensions for fused dense -RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.5.9.post1#subdirectory=csrc/fused_dense_lib +RUN pip install git+https://github.com/HazyResearch/flash-attention@v2.6.1#subdirectory=csrc/fused_dense_lib
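
Note on the new `softcap` parameter (a minimal sketch, not part of the diff above): the reference implementation added to `tests/test_flash_attn.py` squashes the pre-softmax attention scores through `tanh` so their magnitude stays below the cap, and the tests pass `softcap=` through to `flash_attn_func` / `flash_attn_kvpacked_func` (`softcap=0.0` disables it; the tests skip the softcap + dropout combination). The helper name below is hypothetical.

```python
import math
import torch

def softcapped_scores(q, k, softcap=50.0):
    """Softcapping math mirrored from the updated attention_ref.

    q: (batch, seqlen_q, nheads, d), k: (batch, seqlen_k, nheads, d).
    Returns pre-softmax scores of shape (batch, nheads, seqlen_q, seqlen_k).
    """
    scores = torch.einsum("bthd,bshd->bhts", q / math.sqrt(q.shape[-1]), k)
    if softcap > 0:
        # tanh keeps |scores| <= softcap while staying smooth around zero
        scores = torch.tanh(scores / softcap) * softcap
    return scores
```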
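
Note on `cache_leftpad` / `key_leftpad` (a sketch under the same assumptions as the updated tests; the helper name is hypothetical): `flash_attn_with_kvcache` now accepts `cache_leftpad`, and the reference masks out every key slot before `cache_leftpad[b]`, so only positions in `[cache_leftpad[b], cache_seqlens[b])` are attended to.

```python
import torch

def key_padding_mask_with_leftpad(cache_seqlens, cache_leftpad, seqlen_k):
    """Boolean (batch, seqlen_k) mask of valid key positions.

    cache_seqlens: (batch,) int32, one past the last valid key per sequence.
    cache_leftpad: (batch,) int32 or None, first valid key per sequence.
    """
    arange = torch.arange(seqlen_k, device=cache_seqlens.device).unsqueeze(0)  # (1, seqlen_k)
    mask = arange < cache_seqlens.unsqueeze(1)          # right boundary (valid length)
    if cache_leftpad is not None:
        mask = torch.logical_and(mask, arange >= cache_leftpad.unsqueeze(1))  # left padding
    return mask
```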