From ef3e358a2561474b319a9c3fbcb694ba6be9bc43 Mon Sep 17 00:00:00 2001
From: youkaichao
Date: Sun, 21 Jul 2024 23:24:38 -0700
Subject: [PATCH] remove lambda (#1056)

---
 flash_attn/flash_attn_interface.py | 7 ++-----
 hopper/flash_attn_interface.py     | 4 ++--
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py
index 8e7076d8c..ecb3515c0 100644
--- a/flash_attn/flash_attn_interface.py
+++ b/flash_attn/flash_attn_interface.py
@@ -11,6 +11,8 @@
 
 # isort: on
 
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 def _get_block_size_n(device, head_dim, is_dropout, is_causal):
     # This should match the block sizes in the CUDA kernel
@@ -46,7 +48,6 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 def _flash_attn_forward(
     q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
@@ -85,7 +86,6 @@ def _flash_attn_varlen_forward(
     leftpad_k=None,
     seqused_k=None,
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
@@ -134,7 +134,6 @@ def _flash_attn_backward(
     deterministic,
     rng_state=None,
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
@@ -189,7 +188,6 @@ def _flash_attn_varlen_backward(
     deterministic,
     rng_state=None,
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     (
@@ -1253,7 +1251,6 @@ def flash_attn_with_kvcache(
     """
     assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
     assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
-    maybe_contiguous = lambda x: x.contiguous() if x is not None and x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     if softmax_scale is None:
         softmax_scale = q.shape[-1] ** (-0.5)
diff --git a/hopper/flash_attn_interface.py b/hopper/flash_attn_interface.py
index 561fc6fad..c09342826 100644
--- a/hopper/flash_attn_interface.py
+++ b/hopper/flash_attn_interface.py
@@ -11,9 +11,10 @@
 
 # isort: on
 
+def maybe_contiguous(x):
+    return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 def _flash_attn_forward(q, k, v, softmax_scale, causal):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
         q,
@@ -39,7 +40,6 @@ def _flash_attn_backward(
     softmax_scale,
     causal
 ):
-    maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     # dq, dk, dv are allocated by us so they should already be contiguous
     dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
     dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(