Fix stride issue in flash_attn_interface
clintg6 committed May 29, 2024
1 parent 2554f49 commit 530f407
Showing 1 changed file with 9 additions and 3 deletions.
flash_attn/flash_attn_interface.py (12 changes: 9 additions & 3 deletions)
@@ -60,9 +60,15 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q
 
 def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                          dropout_p, softmax_scale, causal, rng_state=None):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
+
+    if out.stride() != dout.stride():
+        out = out.as_strided(dout.size(),dout.stride())
+    if dq.stride() != q.stride():
+        dq = dq.as_strided(q.size(),q.stride())
+
     dq, dk, dv, softmax_d, = flash_attn_cuda.bwd(
         dout, q, k, v, out, softmax_lse, dq, dk, dv, dropout_p,
         softmax_scale, causal, None, rng_state
@@ -73,7 +79,7 @@ def _flash_attn_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
 def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv,
                                 cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k,
                                 dropout_p, softmax_scale, causal, rng_state=None):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
+    maybe_contiguous = lambda x: x.contiguous() if not x.is_contiguous() else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flash_attn_cuda.varlen_bwd(
@@ -232,7 +238,7 @@ def forward(ctx, q, k, v, dropout_p, softmax_scale, causal, return_softmax):
     @staticmethod
     def backward(ctx, dout, *args):
         q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors
-        dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v)
+        dq, dk, dv = torch.empty_strided(q.size(),q.stride(), dtype=q.dtype, device=q.device), torch.empty_strided(k.size(), k.stride(), dtype=k.dtype, device=k.device), torch.empty_strided(v.size(), v.stride(), dtype=v.dtype, device=v.device)
         _flash_attn_backward(
             dout, q, k, v, out, softmax_lse,
             dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal,
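
For context on the change itself: the old maybe_contiguous check only looked at the last-dimension stride, so a tensor sliced out of a packed qkv buffer (stride(-1) == 1 but not contiguous) passed through untouched, while torch.empty_like gave the gradient buffers a contiguous layout that no longer matched it. The sketch below is not part of the commit; it is a standalone illustration of that situation using made-up shapes and plain PyTorch calls.

import torch

# Hypothetical shapes, chosen only for illustration.
batch, seqlen, nheads, headdim = 2, 128, 8, 64

# Slicing q out of a packed qkv buffer leaves gaps between rows:
# the last dimension still has stride 1, but the tensor is not contiguous.
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)
q = qkv[:, :, 0]

print(q.stride(-1) == 1)   # True  -> the old check passes q through unchanged
print(q.is_contiguous())   # False -> the new check calls .contiguous() on it

# torch.empty_like cannot preserve this gapped layout and falls back to a
# contiguous allocation, so its strides differ from q's ...
dq_like = torch.empty_like(q)
print(dq_like.stride() == q.stride())   # False

# ... while torch.empty_strided reproduces q's sizes and strides exactly,
# which is what the commit now uses to allocate dq, dk and dv.
dq = torch.empty_strided(q.size(), q.stride(), dtype=q.dtype, device=q.device)
print(dq.stride() == q.stride())        # True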

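The new stride checks in _flash_attn_backward rely on Tensor.as_strided, which re-views an already-allocated tensor's storage with another tensor's sizes and strides before everything is handed to flash_attn_cuda.bwd. A minimal sketch of that mechanic, again with made-up shapes and not part of the commit:

import torch

# Hypothetical tensors: same shape, different memory layouts.
out = torch.randn(2, 128, 8, 64)                    # contiguous
dout = torch.randn(2, 8, 128, 64).transpose(1, 2)   # shape matches, strides do not

print(out.shape == dout.shape)        # True
print(out.stride() == dout.stride())  # False

# Re-view out's underlying storage with dout's sizes and strides (a view,
# not a copy), mirroring the check the commit adds before the kernel call.
if out.stride() != dout.stride():
    out = out.as_strided(dout.size(), dout.stride())

print(out.stride() == dout.stride())  # True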