From d83c4129a92e4258081f92dfafd34345b3b06130 Mon Sep 17 00:00:00 2001
From: rocking
Date: Mon, 22 Jul 2024 19:00:07 +0000
Subject: [PATCH] Sync the api with upstream FA

---
 csrc/flash_attn_ck/flash_api.cpp      | 5 +++++
 csrc/flash_attn_ck/mha_bwd.cpp        | 1 +
 csrc/flash_attn_ck/mha_fwd.cpp        | 3 ++-
 csrc/flash_attn_ck/mha_varlen_bwd.cpp | 1 +
 csrc/flash_attn_ck/mha_varlen_fwd.cpp | 2 ++
 5 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/csrc/flash_attn_ck/flash_api.cpp b/csrc/flash_attn_ck/flash_api.cpp
index 734208402..0c7474b97 100644
--- a/csrc/flash_attn_ck/flash_api.cpp
+++ b/csrc/flash_attn_ck/flash_api.cpp
@@ -15,6 +15,7 @@ mha_fwd(at::Tensor &q,
         bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float softcap,
         const bool return_softmax,
         c10::optional<at::Generator> gen_);
 
@@ -26,6 +27,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
                c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+               c10::optional<at::Tensor> &leftpad_k_, // batch_size
                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
@@ -36,6 +38,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads
                bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float softcap,
                const bool return_softmax,
                c10::optional<at::Generator> gen_);
 
@@ -55,6 +58,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         const bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float softcap,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state);
@@ -80,6 +84,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                const bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float softcap,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state);
diff --git a/csrc/flash_attn_ck/mha_bwd.cpp b/csrc/flash_attn_ck/mha_bwd.cpp
index c2ae09208..884215adf 100644
--- a/csrc/flash_attn_ck/mha_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_bwd.cpp
@@ -180,6 +180,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
         const bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float /*softcap*/,
         const bool deterministic,
         c10::optional<at::Generator> gen_,
         c10::optional<at::Tensor> &rng_state)
diff --git a/csrc/flash_attn_ck/mha_fwd.cpp b/csrc/flash_attn_ck/mha_fwd.cpp
index 8d1fce0f8..c1eeba507 100644
--- a/csrc/flash_attn_ck/mha_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_fwd.cpp
@@ -161,6 +161,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
         bool is_causal,
         int window_size_left,
         int window_size_right,
+        const float /*softcap*/,
         const bool return_dropout_randval,
         c10::optional<at::Generator> gen_)
 {
@@ -187,7 +188,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
-    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
 
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
diff --git a/csrc/flash_attn_ck/mha_varlen_bwd.cpp b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
index 63b5d1140..d8eabab15 100644
--- a/csrc/flash_attn_ck/mha_varlen_bwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -190,6 +190,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
                const bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float /*softcap*/,
                const bool deterministic,
                c10::optional<at::Generator> gen_,
                c10::optional<at::Tensor> &rng_state)
diff --git a/csrc/flash_attn_ck/mha_varlen_fwd.cpp b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
index cab0dd942..2d2f4cfef 100644
--- a/csrc/flash_attn_ck/mha_varlen_fwd.cpp
+++ b/csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -162,6 +162,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
                const at::Tensor &cu_seqlens_q, // b+1
                const at::Tensor &cu_seqlens_k, // b+1
                c10::optional<at::Tensor> & /*seqused_k*/,
+               c10::optional<at::Tensor> &/*leftpad_k_*/, // batch_size
                c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
                c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
                int max_seqlen_q,
@@ -172,6 +173,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
                bool is_causal,
                int window_size_left,
                int window_size_right,
+               const float /*softcap*/,
                const bool return_dropout_randval,
                c10::optional<at::Generator> gen_)
 {
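
Note: the CK kernels do not consume softcap yet; each definition takes it as a
commented-out parameter (/*softcap*/) purely to match the upstream
FlashAttention signature, and the new leftpad_k_ argument is likewise accepted
but unused (/*leftpad_k_*/). For reference, below is a minimal sketch of the
transform upstream FA applies to attention scores when softcap > 0; the helper
name apply_softcap is illustrative and not part of this patch:

    #include <cmath>

    // Upstream semantics: squash each attention score into (-softcap, softcap)
    // with tanh before the softmax; softcap <= 0.0f disables the feature.
    inline float apply_softcap(float score, float softcap)
    {
        return softcap > 0.0f ? softcap * std::tanh(score / softcap)
                              : score;
    }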