Sync the api with upstream FA
rocking5566 committed Jul 22, 2024
1 parent a4417c7 commit d83c412
Showing 5 changed files with 11 additions and 1 deletion.
5 changes: 5 additions & 0 deletions csrc/flash_attn_ck/flash_api.cpp
@@ -15,6 +15,7 @@ mha_fwd(at::Tensor &q,
bool is_causal,
int window_size_left,
int window_size_right,
+ const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);

@@ -26,6 +27,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> &seqused_k, // b. If given, only this many elements of each batch element's keys are used.
+ c10::optional<const at::Tensor> &leftpad_k_, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
@@ -36,6 +38,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads
bool is_causal,
int window_size_left,
int window_size_right,
+ const float softcap,
const bool return_softmax,
c10::optional<at::Generator> gen_);

@@ -55,6 +58,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
const bool is_causal,
int window_size_left,
int window_size_right,
+ const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);
@@ -80,6 +84,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
const bool is_causal,
int window_size_left,
int window_size_right,
+ const float softcap,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state);
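
Note: softcap mirrors the logit soft-capping knob added to upstream FlashAttention (used by models such as Gemma-2). Here it is threaded through the declarations purely for API parity; the CK kernels do not act on it yet. A minimal sketch of the upstream semantics, assuming the convention that a value of 0 disables capping (apply_softcap is a hypothetical helper, not part of this diff):

#include <cmath>

// Each attention logit is squashed into (-softcap, softcap) with a tanh
// before the softmax; softcap <= 0.0f leaves the logit untouched.
float apply_softcap(float logit, float softcap) {
    if (softcap <= 0.0f) {
        return logit;  // capping disabled
    }
    return softcap * std::tanh(logit / softcap);
}
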
1 change: 1 addition & 0 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -180,6 +180,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
const bool is_causal,
int window_size_left,
int window_size_right,
+ const float /*softcap*/,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state)
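
As in the other CK implementation files below, the parameter name is commented out (const float /*softcap*/): the standard C++ idiom for accepting an argument only for signature compatibility. A small illustration with hypothetical names:

// The commented-out name keeps the upstream signature intact while making
// the argument unnameable in the body, so -Wunused-parameter stays quiet.
int scale(int x, float /*softcap*/) {
    return 2 * x;  // the capping knob is intentionally ignored
}
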
3 changes: 2 additions & 1 deletion csrc/flash_attn_ck/mha_fwd.cpp
@@ -161,6 +161,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
bool is_causal,
int window_size_left,
int window_size_right,
+ const float /*softcap*/,
const bool return_dropout_randval,
c10::optional<at::Generator> gen_)
{
@@ -187,7 +188,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
const int head_size_og = sizes[3];
const int seqlen_k = k.size(1);
const int num_heads_k = k.size(2);
- TORCH_CHECK(batch_size > 0, "batch size must be postive");
+ TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size_og <= 256, "CK only supports head dimension at most 256");
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

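
The divisibility check in the hunk above (num_heads % num_heads_k == 0) admits plain multi-head, grouped-query, and multi-query attention alike. A sketch with made-up head counts:

#include <cassert>

int main() {
    const int num_heads   = 32;  // query heads
    const int num_heads_k = 8;   // KV heads: GQA, 4 query heads per KV head
    assert(num_heads % num_heads_k == 0);  // MHA (32/32) and MQA (32/1) pass too
    return 0;
}
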
1 change: 1 addition & 0 deletions csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -190,6 +190,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
const bool is_causal,
int window_size_left,
int window_size_right,
+ const float /*softcap*/,
const bool deterministic,
c10::optional<at::Generator> gen_,
c10::optional<at::Tensor> &rng_state)
2 changes: 2 additions & 0 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -162,6 +162,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
const at::Tensor &cu_seqlens_q, // b+1
const at::Tensor &cu_seqlens_k, // b+1
c10::optional<at::Tensor> & /*seqused_k*/,
+ c10::optional<const at::Tensor> &/*leftpad_k_*/, // batch_size
c10::optional<at::Tensor> &block_table_, // batch_size x max_num_blocks_per_seq
c10::optional<at::Tensor> &alibi_slopes_, // num_heads or b x num_heads
int max_seqlen_q,
@@ -172,6 +173,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
bool is_causal,
int window_size_left,
int window_size_right,
+ const float /*softcap*/,
const bool return_dropout_randval,
c10::optional<at::Generator> gen_)
{
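
leftpad_k_ likewise comes from the upstream varlen API, where it appears to carry a per-batch offset into the keys (padding at the front of each batch element); the CK backend takes it only for signature parity, with the name commented out. A hedged sketch of that indexing convention, assuming leftpad_k[b] counts padded slots before batch b's first real key:

#include <cstdint>

// First key index that participates in attention for batch element b;
// a null pointer means no left padding anywhere.
int32_t first_valid_key(const int32_t* leftpad_k, int64_t b) {
    return leftpad_k == nullptr ? 0 : leftpad_k[b];
}
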
