From f31e09a13a9ec7ef23e346a50798f852db640d05 Mon Sep 17 00:00:00 2001
From: rocking
Date: Mon, 7 Oct 2024 19:07:49 +0000
Subject: [PATCH 1/4] update ck

---
 csrc/composable_kernel                | 2 +-
 csrc/flash_attn_ck/mha_bwd.cpp        | 9 ++++-----
 csrc/flash_attn_ck/mha_fwd.cpp        | 9 ++++-----
 csrc/flash_attn_ck/mha_varlen_bwd.cpp | 9 ++++-----
 csrc/flash_attn_ck/mha_varlen_fwd.cpp | 9 ++++-----
 5 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/csrc/composable_kernel b/csrc/composable_kernel
index 11b7a4db0..0023f01ab 160000
--- a/csrc/composable_kernel
+++ b/csrc/composable_kernel
@@ -1 +1 @@
-Subproject commit 11b7a4db005dc38e60b1ea045d03a92d2a8f9cd0
+Subproject commit 0023f01ab02b9cc05a98ae1a7753df1481252e4d
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index 1859137f8..9aa9992d1 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -49,8 +49,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    at::Tensor dv,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t, uint64_t> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, hdim)
     ck_tile::index_t batch_stride_q = q.stride(0);
@@ -191,7 +190,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
         static_cast<ck_tile::index_t>(mask.type),
         p_dropout,
         p_undrop,
-        {drop_seed, drop_offset}};
+        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -352,6 +351,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     }

     if (seqlen_q > 0) {
+        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
         ck_tile::stream_config stream_config{stream};

         auto traits =
@@ -380,8 +380,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
                 dv_expanded,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index a6b33b4ab..88c0aac0e 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -46,8 +46,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                    at::Tensor dropout_randval,
                                    float softmax_scale,
                                    float p_dropout,
-                                   uint64_t drop_seed,
-                                   uint64_t drop_offset)
+                                   std::pair<uint64_t, uint64_t> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, d)
     // k: (batch_size, seqlen_k, nheads_k, d)
@@ -137,7 +136,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
         static_cast<ck_tile::index_t>(mask.type),
         p_dropout,
         has_dropout_randval,
-        {drop_seed, drop_offset}};
+        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -273,6 +272,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
     rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));

     if (seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};

@@ -305,8 +305,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
                 p,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");
diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 531d735ed..600ae623d 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -51,8 +51,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           at::Tensor dv,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t, uint64_t> drop_seed_offset)
 {
     ck_tile::index_t total_q = q.size(0);
     ck_tile::index_t total_k = k.size(0);
@@ -197,7 +196,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
         static_cast<ck_tile::index_t>(mask.type),
         p_dropout,
         p_undrop,
-        {drop_seed, drop_offset}};
+        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -377,6 +376,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     }

     if (max_seqlen_q > 0) {
+        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
         ck_tile::stream_config stream_config{stream};

         auto traits =
@@ -407,8 +407,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                 dv_expanded,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_bwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd");
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index 6e30aa74a..a1b1a7402 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -47,8 +47,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                           at::Tensor dropout_randval,
                                           float softmax_scale,
                                           float p_dropout,
-                                          uint64_t drop_seed,
-                                          uint64_t drop_offset)
+                                          std::pair<uint64_t, uint64_t> drop_seed_offset)
 {
     // q: (total_q, nheads, d)
     // k: (total_k, nheads_k, d)
@@ -140,7 +139,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
         static_cast<ck_tile::index_t>(mask.type),
         p_dropout,
         has_dropout_randval,
-        {drop_seed, drop_offset}};
+        drop_seed_offset};
 }

 std::vector<at::Tensor>
@@ -299,6 +298,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
     rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));

     if (max_seqlen_k > 0) {
+        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};

@@ -332,8 +332,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
                 p,
                 softmax_scale,
                 p_dropout,
-                drop_seed,
-                drop_offset);
+                drop_seed_offset);

         float t = fmha_fwd(traits, args, stream_config);
         TORCH_CHECK(t >= 0, "invalid argument for fmha_fwd");

From 5cb68bb833f17d1059bb96e70b8871d32b34af98 Mon Sep 17 00:00:00 2001
From: rocking
Date: Tue, 8 Oct 2024 09:46:03 +0000
Subject: [PATCH 2/4] use pointer as seed and offset

---
 csrc/flash_attn_ck/flash_common.hpp   | 22 +++++++++++-----------
 csrc/flash_attn_ck/mha_bwd.cpp        | 22 +++++++++++++---------
 csrc/flash_attn_ck/mha_fwd.cpp        | 15 ++++++---------
 csrc/flash_attn_ck/mha_varlen_bwd.cpp | 22 +++++++++++++---------
 csrc/flash_attn_ck/mha_varlen_fwd.cpp | 15 ++++++---------
 5 files changed, 49 insertions(+), 47 deletions(-)

diff --git a/csrc/flash_attn_ck/flash_common.hpp b/csrc/flash_attn_ck/flash_common.hpp
index 1c7c2f062..cc86546ea 100644
--- a/csrc/flash_attn_ck/flash_common.hpp
+++ b/csrc/flash_attn_ck/flash_common.hpp
@@ -22,17 +22,17 @@
 #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

 namespace flash {
-// Copy from PyTorch
-// https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
-inline std::tuple<uint64_t, uint64_t> unpack(at::PhiloxCudaState arg) {
-    if (arg.captured_) {
-        // static_cast avoids "warning: invalid narrowing conversion from "long" to "unsigned long".
-        // *(arg.offset_.ptr) is a broadcast load of a single int64_t to the entire kernel.
-        // For most threads' reads it will hit in cache, so it shouldn't hurt performance.
-        return std::make_tuple(static_cast<uint64_t>(*arg.seed_.ptr), static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_));
-    } else {
-        return std::make_tuple(arg.seed_.val, arg.offset_.val);
-    }
+inline __global__ void ParsePhiloxCudaState(at::PhiloxCudaState arg, uint64_t* rng_state)
+{
+    // Imitate from PyTorch
+    // https://github.com/pytorch/pytorch/blob/8b61daaf7349e9102117e1aeefaa51666d887547/aten/src/ATen/cuda/detail/UnpackRaw.cuh#L17
+    if (arg.captured_) {
+        rng_state[0] = static_cast<uint64_t>(*arg.seed_.ptr);
+        rng_state[1] = static_cast<uint64_t>(*(arg.offset_.ptr) + arg.offset_intragraph_);
+    } else {
+        rng_state[0] = arg.seed_.val;
+        rng_state[1] = arg.offset_.val;
+    }
 }

 inline int num_splits_heuristic_ck(int batch_nheads_mblocks, int num_SMs, int num_n_blocks, int max_splits) {
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index 9aa9992d1..b6ad44dd1 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -49,7 +49,7 @@ fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask,
                                    at::Tensor dv,
                                    float softmax_scale,
                                    float p_dropout,
-                                   std::pair<uint64_t, uint64_t> drop_seed_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, hdim)
     ck_tile::index_t batch_stride_q = q.stride(0);
@@ -212,7 +212,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         const float /*softcap*/,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
-        c10::optional<at::Tensor> &rng_state)
+        c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -336,22 +336,26 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;

-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
+    } else {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     }

     if (seqlen_q > 0) {
-        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};

         auto traits =
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index 88c0aac0e..7202cf2c8 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -46,7 +46,7 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse,
                                    at::Tensor dropout_randval,
                                    float softmax_scale,
                                    float p_dropout,
-                                   std::pair<uint64_t, uint64_t> drop_seed_offset)
+                                   std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (batch_size, seqlen_q, nheads, d)
     // k: (batch_size, seqlen_k, nheads_k, d)
@@ -254,10 +254,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         p = torch::empty({ 0 }, opts);
     }

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -265,14 +264,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }

-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-
     if (seqlen_k > 0) {
-        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};

diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 600ae623d..2e5dd7b51 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -51,7 +51,7 @@ fmha_bwd_args get_ck_fmha_varlen_bwd_args(const mask_info &mask,
                                           at::Tensor dv,
                                           float softmax_scale,
                                           float p_dropout,
-                                          std::pair<uint64_t, uint64_t> drop_seed_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     ck_tile::index_t total_q = q.size(0);
     ck_tile::index_t total_k = k.size(0);
@@ -223,7 +223,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
         const float /*softcap*/,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
-        c10::optional<at::Tensor> &rng_state)
+        c10::optional<at::Tensor> &rng_state_)
 {
 #ifdef FLASHATTENTION_DISABLE_BACKWARD
     TORCH_CHECK(false, "This flash attention build does not support backward.");
@@ -361,22 +361,26 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
     auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
         gen_, at::cuda::detail::getDefaultCUDAGenerator());

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
+    at::Tensor rng_state;

-    if (rng_state.has_value()) {
-        uint64_t* d = reinterpret_cast<uint64_t*>(rng_state.value().data_ptr());
-        drop_seed = d[0];
-        drop_offset = d[1];
+    if (rng_state_.has_value()) {
+        rng_state = rng_state_.value();
     } else if(is_dropout) {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
+            philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
+    } else {
+        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     }

     if (max_seqlen_q > 0) {
-        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
+        auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         ck_tile::stream_config stream_config{stream};

         auto traits =
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index a1b1a7402..7e8a347d4 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -47,7 +47,7 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse,
                                           at::Tensor dropout_randval,
                                           float softmax_scale,
                                           float p_dropout,
-                                          std::pair<uint64_t, uint64_t> drop_seed_offset)
+                                          std::pair<uint64_t*, uint64_t*> drop_seed_offset)
 {
     // q: (total_q, nheads, d)
     // k: (total_k, nheads_k, d)
@@ -280,10 +280,9 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         if (return_dropout_randval) {p.zero_();}
     }

-    uint64_t drop_seed = 1, drop_offset = 0;
     int64_t counter_offset = batch_size * num_heads * ck_tile::get_warp_size();
-    auto options = torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
-    auto rng_state = torch::empty({2}, options.dtype(torch::kInt64));
+    auto rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
+    auto rng_state_ptr = reinterpret_cast<uint64_t*>(rng_state.data_ptr());

     if (p_dropout > 0.0) {
         auto gen = at::get_generator_or_default<at::CUDAGeneratorImpl>(
@@ -291,14 +290,12 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
         // See Note [Acquire lock when using random generators]
         std::lock_guard<std::mutex> lock(gen->mutex_);
         auto philox_args = gen->philox_cuda_state(counter_offset);
-        std::tie(drop_seed, drop_offset) = flash::unpack(philox_args);
+        hipLaunchKernelGGL(
+            flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0, philox_args, rng_state_ptr);
     }

-    rng_state[0] = *(reinterpret_cast<int64_t*>(&drop_seed));
-    rng_state[1] = *(reinterpret_cast<int64_t*>(&drop_offset));
-
     if (max_seqlen_k > 0) {
-        auto drop_seed_offset = std::make_pair(drop_seed, drop_offset);
+        auto drop_seed_offset = std::make_pair(rng_state_ptr, rng_state_ptr + 1);
         auto stream = at::cuda::getCurrentHIPStream().stream();
         ck_tile::stream_config stream_config{stream};

From 5467ceaa7f4169297136251a0ffee437e6f8a04b Mon Sep 17 00:00:00 2001
From: rocking
Date: Tue, 5 Nov 2024 16:18:14 +0000
Subject: [PATCH 3/4] update CK

---
 csrc/composable_kernel | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/composable_kernel b/csrc/composable_kernel
index 0023f01ab..464abd235 160000
--- a/csrc/composable_kernel
+++ b/csrc/composable_kernel
@@ -1 +1 @@
-Subproject commit 0023f01ab02b9cc05a98ae1a7753df1481252e4d
+Subproject commit 464abd235e27c33422aa52ed2044af8fbcc3a88d

From 86f9b1b3f0644b9452d40acd10dc340679b2af9a Mon Sep 17 00:00:00 2001
From: rocking
Date: Tue, 5 Nov 2024 16:21:16 +0000
Subject: [PATCH 4/4] Remove useless "else"

---
 csrc/flash_attn_ck/mha_bwd.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index b6ad44dd1..e4a4b2a6b 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -349,8 +349,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         hipLaunchKernelGGL(
             flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, 0,
             philox_args, reinterpret_cast<uint64_t*>(rng_state.data_ptr()));
-    } else {
-        rng_state = torch::empty({2}, opts.dtype(torch::kInt64));
     }

     if (seqlen_q > 0) {
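
Note on the mechanism in PATCH 2/4: the series stops unpacking the Philox seed/offset on the host and instead has a small kernel write them into a 2-element int64 tensor, with the attention kernels receiving pointers (rng_state_ptr, rng_state_ptr + 1) so the values are resolved on the GPU at execution time. The following is a minimal standalone HIP sketch of that pattern, not the flash-attention sources: PhiloxState, ParsePhiloxState, and ConsumeSeedOffset are invented stand-ins for at::PhiloxCudaState, flash::ParsePhiloxCudaState, and the CK fmha kernels, and error checking is omitted.

// sketch.hip.cpp -- illustrative only, assumptions as described above
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>

struct PhiloxState {              // stand-in for at::PhiloxCudaState
    uint64_t seed_val;
    uint64_t offset_val;
    const int64_t* seed_ptr;      // non-null when the generator state lives on device
    const int64_t* offset_ptr;
    uint64_t offset_intragraph;
    bool captured;
};

// Resolve seed/offset on device and store them into a 2-element buffer,
// mirroring the idea behind flash::ParsePhiloxCudaState.
__global__ void ParsePhiloxState(PhiloxState arg, uint64_t* rng_state) {
    if (arg.captured) {
        rng_state[0] = static_cast<uint64_t>(*arg.seed_ptr);
        rng_state[1] = static_cast<uint64_t>(*arg.offset_ptr) + arg.offset_intragraph;
    } else {
        rng_state[0] = arg.seed_val;
        rng_state[1] = arg.offset_val;
    }
}

// Stand-in for a kernel that consumes the pair of pointers: it dereferences
// seed/offset at run time instead of taking them by value.
__global__ void ConsumeSeedOffset(const uint64_t* seed, const uint64_t* offset, uint64_t* out) {
    out[0] = *seed;
    out[1] = *offset;
}

int main() {
    uint64_t* rng_state = nullptr;   // plays the role of the {2}-element kInt64 tensor
    uint64_t* out = nullptr;
    (void)hipMalloc(reinterpret_cast<void**>(&rng_state), 2 * sizeof(uint64_t));
    (void)hipMalloc(reinterpret_cast<void**>(&out), 2 * sizeof(uint64_t));

    // Host-side values; the captured/pointer path would be exercised under graph capture.
    PhiloxState state{/*seed_val=*/1234, /*offset_val=*/56, nullptr, nullptr, 0, /*captured=*/false};
    hipLaunchKernelGGL(ParsePhiloxState, dim3(1), dim3(1), 0, 0, state, rng_state);

    // The consumer gets (rng_state, rng_state + 1), mirroring
    // std::make_pair(rng_state_ptr, rng_state_ptr + 1) in the patch.
    hipLaunchKernelGGL(ConsumeSeedOffset, dim3(1), dim3(1), 0, 0, rng_state, rng_state + 1, out);

    uint64_t host_out[2] = {0, 0};
    (void)hipMemcpy(host_out, out, sizeof(host_out), hipMemcpyDeviceToHost);
    printf("seed=%llu offset=%llu\n",
           (unsigned long long)host_out[0], (unsigned long long)host_out[1]);

    (void)hipFree(rng_state);
    (void)hipFree(out);
    return 0;
}

Under this scheme the host never reads the Philox state back, which is why the rng_state_ / drop_seed_offset plumbing in mha_*_fwd/bwd can work even when the seed and offset exist only in device memory.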