From 4e1d6ef8d3f4127123d29f0b71da9889e6230c01 Mon Sep 17 00:00:00 2001 From: Michael Melesse Date: Wed, 30 Oct 2024 11:52:14 -0400 Subject: [PATCH] Enable sequence_parallel in bwd (#89) * sequence_parallel working on bwd_impl test * fix qkv error * save * save * save * bwd 3 times faster * clean up * fix varlen bug * use copy back dict * fix qkvpacked bug * reduce bench sizes * print copy back --- .gitignore | 4 +- flash_attn/flash_attn_triton_amd/bench.py | 16 ++-- .../flash_attn_triton_amd/bwd_prefill.py | 89 +++++++++++-------- flash_attn/flash_attn_triton_amd/test.py | 12 +-- flash_attn/flash_attn_triton_amd/utils.py | 1 + tests/test_flash_attn_triton_amd.py | 6 +- 6 files changed, 76 insertions(+), 52 deletions(-) diff --git a/.gitignore b/.gitignore index 131fe3ca5..dede7ecf0 100644 --- a/.gitignore +++ b/.gitignore @@ -19,14 +19,16 @@ var/ *.egg-info/ .installed.cfg *.egg -.eggs # IDE-related .idea/ # Dev venv + +# AMD scripts +.eggs *.log core.* *.csv diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py index d510d41e8..91939f831 100644 --- a/flash_attn/flash_attn_triton_amd/bench.py +++ b/flash_attn/flash_attn_triton_amd/bench.py @@ -50,21 +50,21 @@ def get_benchmark_configs(args, varlen=False): (16, 16, 16, 1024, 1024), (8, 16, 16, 2048, 2048), (4, 16, 16, 4096, 4096), - (2, 16, 16, 8192, 8192), - (1, 16, 16, 16384, 16384), + (1, 8, 8, 8192, 8192), + (1, 2, 2, 16384, 16384), (2, 48, 48, 1024, 1024), (2, 48, 48, 2048, 1024), - (2, 48, 48, 4096, 8192), - (2, 48, 48, 8192, 4096), - (2, 48, 48, 16384, 8192), - (8, 16, 16, 1989, 15344), + (1, 8, 8, 4096, 8192), + (1, 8, 8, 8192, 4096), + (2, 4, 4, 16384, 8192), + (2, 8, 8, 1989, 15344), (4, 16, 16, 4097, 163), (2, 16, 16, 8122, 2159), (1, 16, 16, 16281, 7), (2, 48, 48, 1021, 1020), (2, 48, 48, 2001, 2048), - (2, 48, 48, 3996, 9639), - (2, 48, 48, 8181, 1021), + (2, 8, 8, 3996, 9639), + (2, 8, 8, 8181, 1021), ] def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal): diff --git a/flash_attn/flash_attn_triton_amd/bwd_prefill.py b/flash_attn/flash_attn_triton_amd/bwd_prefill.py index c2f009a82..84212235a 100644 --- a/flash_attn/flash_attn_triton_amd/bwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/bwd_prefill.py @@ -1,7 +1,7 @@ import torch import triton import triton.language as tl -from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG +from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF @triton.jit def _bwd_preprocess_use_o( @@ -74,7 +74,6 @@ def _bwd_preprocess_use_o( tl.store(delta_ptrs, delta, mask=mask_m) - @triton.jit def _bwd_kernel_one_col_block( Q, @@ -311,7 +310,7 @@ def _bwd_kernel( dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn if SEQUENCE_PARALLEL: - dq_offset = DQ + stride_dq_all * start_n + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm + dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm else: dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm @@ -451,7 +450,7 @@ def attention_prefill_backward_triton_impl( max_seqlen_q: int, max_seqlen_k: int, use_exp2: bool, - sequence_parallel = False, + sequence_parallel = True, ): if DEBUG: print() @@ -489,7 +488,6 @@ def attention_prefill_backward_triton_impl( stride_kz, stride_kh, stride_kn, stride_kk = k_strides stride_vz, stride_vh, stride_vn, 
stride_vk = v_strides stride_oz, stride_oh, stride_om, stride_ok = o_strides - stride_dq_all = q.numel() batch_headsize = batch * nheads_q is_varlen = layout == "thd" @@ -515,31 +513,46 @@ def attention_prefill_backward_triton_impl( ACTUAL_BLOCK_DMODEL = head_size do = do.contiguous() - if sequence_parallel: - # replicate q for each parallel sequence - replicas = num_blocks_n - dq_shape = (replicas,) + q.shape + # NOTE: we might need to copy the output tensor if they are not continuous or have other issues + copy_back = {"dq": False, "dk": False, "dv": False} + + # deal with dq + if dq is None: + if sequence_parallel: + dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) + else: + dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype) else: - dq_shape = q.shape + dq_og = dq + if (not dq.is_contiguous()): + dq = dq.contiguous() + copy_back["dq"] = True + + if sequence_parallel: + dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype) + copy_back["dq"] = True + else: + # NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros + dq.zero_() + stride_dq_all = dq.stride()[0] - is_qkvpacked = False - if dq is None or dk is None or dv is None: - dq = torch.zeros(dq_shape, device=q.device, dtype=q.dtype) + # deal with dk, dv + if (dk is None) or (dv is None): dk = torch.empty_like(k) dv = torch.empty_like(v) - elif (not dq.is_contiguous()) or (not dq.is_contiguous()) or (not dq.is_contiguous()): - if DEBUG: - print("Not contigious and setting is packed to True") - is_qkvpacked = True - dq_og = dq - dq = dq.contiguous() + else: + if (not dk.is_contiguous()): dk_og = dk dk = dk.contiguous() + copy_back["dk"] = True + + if (not dv.is_contiguous()): dv_og = dv - dv = dv.contiguous() - - # NOTE: the kernel does inplace accumlation so dq has to be zeros. 
This avoids the case where we are passed empty dq and it is not all zeros - dq.zero_() + dv = dv.contiguous() + copy_back["dv"] = True + + if DEBUG: + print("copy_back:", copy_back) # assert contigious assert do.is_contiguous() @@ -647,26 +660,32 @@ def attention_prefill_backward_triton_impl( IS_VARLEN=is_varlen ) - if len(dq.shape) == 5: + if DEBUG: + print("_bwd_kernel outputs") + print("dq:", dq, dq.shape) + print("dk:", dk, dk.shape) + print("dv:", dv, dv.shape) + print("delta:", delta, delta.shape) + + if sequence_parallel: dq = dq.sum(dim=0) if DEBUG: - print("_bwd_kernel outputs") + print("attention_prefill_backward_triton_new_impl outputs") print("dq:", dq, dq.shape) print("dk:", dk, dk.shape) print("dv:", dv, dv.shape) print("delta:", delta, delta.shape) - - if is_qkvpacked: - if DEBUG: - print("Copying back to original tensors due to ispacked") - - # copy back results to og tensors + print("copy_back:", copy_back) + + if copy_back["dq"]: dq_og.copy_(dq) + dq = dq_og + if copy_back["dk"]: dk_og.copy_(dk) + dk = dk_og + if copy_back["dv"]: dv_og.copy_(dv) - return dq_og, dk_og, dv_og, delta, None, None - else: - return dq, dk, dv, delta, None, None - + dv = dv_og + return dq, dk, dv, delta, None, None diff --git a/flash_attn/flash_attn_triton_amd/test.py b/flash_attn/flash_attn_triton_amd/test.py index 515c98e71..9a6ab8dab 100644 --- a/flash_attn/flash_attn_triton_amd/test.py +++ b/flash_attn/flash_attn_triton_amd/test.py @@ -505,8 +505,8 @@ def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor (1, 1, 4, 4, 32), (1, 1, 16, 16, 16), (1, 1, 32, 32, 16), - (1, 1, 64, 64, 16), # pass # smallest head_size = 16 - (1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64 + (1, 1, 64, 64, 16), + (1, 1, 64, 64, 64), (1, 1, 64, 128, 32), (1, 1, 128, 128, 64), (1, 1, 128, 256, 45), @@ -529,10 +529,11 @@ def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor (1, 16, 1024, 1024, 128), ]) @pytest.mark.parametrize('causal', [True, False]) -@pytest.mark.parametrize('use_exp2', [False]) # using exp2 causas issue with exp2 +@pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal @pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"]) +@pytest.mark.parametrize('sequence_parallel', [True, False]) @pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend -def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, DEBUG_INPUT): +def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, sequence_parallel, DEBUG_INPUT): dtype = torch.float16 torch.manual_seed(20) # seed from test_op_bwd @@ -619,7 +620,8 @@ def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, l metadata.cu_seqlens_k, metadata.max_seqlens_q, metadata.max_seqlens_k, - use_exp2 + use_exp2, + sequence_parallel=sequence_parallel ) # =============================================== Check ============================================================== diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py index 41fb83156..b59486495 100644 --- a/flash_attn/flash_attn_triton_amd/utils.py +++ b/flash_attn/flash_attn_triton_amd/utils.py @@ -5,6 +5,7 @@ AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '1').lower() in ('1', 'true', 'yes') DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes') +PERF = 
os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes') class MetaData(): cu_seqlens_q = None diff --git a/tests/test_flash_attn_triton_amd.py b/tests/test_flash_attn_triton_amd.py index ee17697ad..d64246f95 100644 --- a/tests/test_flash_attn_triton_amd.py +++ b/tests/test_flash_attn_triton_amd.py @@ -585,10 +585,10 @@ def get_dropout_fraction( @pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32, 64, 96, 128]) -# @pytest.mark.parametrize("d", [64]) +# @pytest.mark.parametrize("d", [32]) # @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048]) -# @pytest.mark.parametrize("seqlen", [512]) +# @pytest.mark.parametrize("seqlen", [128]) # @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize("dropout_p", [0.0]) def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype): @@ -746,7 +746,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) # @pytest.mark.parametrize('d', [32]) @pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048]) -# @pytest.mark.parametrize('seqlen', [2]) +# @pytest.mark.parametrize('seqlen', [128]) # @pytest.mark.parametrize("dropout_p", [0.0, 0.17]) @pytest.mark.parametrize('dropout_p', [0.0]) def test_flash_attn_varlen_qkvpacked(
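
A note on the sequence_parallel path this patch enables by default: the backward kernel parallelizes over K/V column blocks (indexed by start_n in the kernel), so each column block gets its own dq slice to avoid write conflicts. The host therefore allocates dq with an extra leading dimension of size num_blocks_n, passes that buffer's stride(0) to the kernel as stride_dq_all (so a program indexes DQ + start_n * stride_dq_all + ...), and reduces with dq.sum(dim=0) after the launch. The sketch below is a minimal host-side illustration of that idea, not the library API; block_n and the helper names are invented for the example.

    import math
    import torch

    def sequence_parallel_dq_buffer(q: torch.Tensor, max_seqlen_k: int, block_n: int = 64):
        """Illustrative only: one dq slice per K/V column block."""
        num_blocks_n = math.ceil(max_seqlen_k / block_n)
        # Each column-block program writes into its own slice, so parallel
        # programs never contend on the same dq memory.
        dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
        stride_dq_all = dq.stride()[0]
        return dq, stride_dq_all

    def reduce_dq(dq: torch.Tensor) -> torch.Tensor:
        # Collapse the per-block partial gradients once the kernel is done.
        return dq.sum(dim=0)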
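The copy_back dict replaces the old is_qkvpacked flag: whenever a caller-provided gradient has to be swapped for a contiguous working buffer (or dq is re-allocated for the sequence-parallel layout), the corresponding entry is set to True and the kernel's result is copied back into the original tensor before returning. A minimal sketch of the pattern, with hypothetical names:

    import torch

    def to_contiguous_with_copy_back(grad: torch.Tensor):
        # Return a working buffer plus a flag saying whether the result must
        # be copied back into the caller's original tensor afterwards.
        if grad.is_contiguous():
            return grad, grad, False
        return grad.contiguous(), grad, True

    # usage sketch:
    # dv_work, dv_og, need_copy = to_contiguous_with_copy_back(dv)
    # ... kernel writes into dv_work ...
    # if need_copy:
    #     dv_og.copy_(dv_work)
    #     dv = dv_og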
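The new PERF switch in utils.py mirrors the existing DEBUG and AUTOTUNE flags and is read from the environment at import time; this patch only imports it in bwd_prefill.py, so its actual consumers are not shown here. If you want it set while experimenting, exporting it before the backend is imported should work, e.g. (illustrative snippet):

    import os

    # utils.py evaluates FLASH_ATTENTION_TRITON_AMD_PERF at import time,
    # so set it before importing flash_attn's Triton AMD backend.
    os.environ["FLASH_ATTENTION_TRITON_AMD_PERF"] = "1"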