Skip to content

Commit

Permalink
remove lambda (Dao-AILab#1056)
Browse files Browse the repository at this point in the history
  • Loading branch information
youkaichao authored Jul 22, 2024
1 parent 4df62e1 commit ef3e358
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
7 changes: 2 additions & 5 deletions flash_attn/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

# isort: on

def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

def _get_block_size_n(device, head_dim, is_dropout, is_causal):
# This should match the block sizes in the CUDA kernel
Expand Down Expand Up @@ -46,7 +48,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
def _flash_attn_forward(
q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
q,
Expand Down Expand Up @@ -85,7 +86,6 @@ def _flash_attn_varlen_forward(
leftpad_k=None,
seqused_k=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
q,
Expand Down Expand Up @@ -134,7 +134,6 @@ def _flash_attn_backward(
deterministic,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
(
Expand Down Expand Up @@ -189,7 +188,6 @@ def _flash_attn_varlen_backward(
deterministic,
rng_state=None,
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
(
Expand Down Expand Up @@ -1253,7 +1251,6 @@ def flash_attn_with_kvcache(
"""
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)
Expand Down
4 changes: 2 additions & 2 deletions hopper/flash_attn_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@

# isort: on

def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

def _flash_attn_forward(q, k, v, softmax_scale, causal):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
q,
Expand All @@ -39,7 +40,6 @@ def _flash_attn_backward(
softmax_scale,
causal
):
maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
# dq, dk, dv are allocated by us so they should already be contiguous
dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(
Expand Down

0 comments on commit ef3e358

Please sign in to comment.