Enable sequence_parallel in bwd (#89)
* sequence_parallel working on bwd_impl test

* fix qkv error

* save

* save

* save

* bwd 3 times faster

* clean up

* fix varlen bug

* use copy back dict

* fix qkvpacked bug

* reduce bench sizes

* print copy back
micmelesse authored Oct 30, 2024
1 parent 730d260 commit 4e1d6ef
Showing 6 changed files with 76 additions and 52 deletions.
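
At a high level, the sequence-parallel backward pass lets each K/V column block run in its own program without racing on dq: every block writes into its own replica of dq, and the replicas are reduced afterwards. That is what the (num_blocks_n,) + q.shape allocation and the later dq.sum(dim=0) in the diff below implement. A minimal PyTorch sketch of the reduction pattern (illustrative shapes only, not the Triton kernel itself):

    import torch

    # Illustrative sizes; the real values come from the launch grid and q.shape.
    num_blocks_n = 4                 # number of K/V column blocks running in parallel
    q_shape = (2, 16, 1024, 64)      # (batch, heads, seqlen_q, head_dim)

    # One dq buffer per column block so parallel programs never write to the same memory.
    dq_replicas = torch.zeros((num_blocks_n,) + q_shape)

    # ... each column block i accumulates its partial gradient into dq_replicas[i] ...

    # The final gradient is the sum over the replicas, matching dq = dq.sum(dim=0) in the diff.
    dq = dq_replicas.sum(dim=0)
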
4 changes: 3 additions & 1 deletion .gitignore
@@ -19,14 +19,16 @@ var/
*.egg-info/
.installed.cfg
*.egg
.eggs

# IDE-related
.idea/

# Dev
venv

# AMD
scripts
.eggs
*.log
core.*
*.csv
16 changes: 8 additions & 8 deletions flash_attn/flash_attn_triton_amd/bench.py
@@ -50,21 +50,21 @@ def get_benchmark_configs(args, varlen=False):
(16, 16, 16, 1024, 1024),
(8, 16, 16, 2048, 2048),
(4, 16, 16, 4096, 4096),
(2, 16, 16, 8192, 8192),
(1, 16, 16, 16384, 16384),
(1, 8, 8, 8192, 8192),
(1, 2, 2, 16384, 16384),
(2, 48, 48, 1024, 1024),
(2, 48, 48, 2048, 1024),
(2, 48, 48, 4096, 8192),
(2, 48, 48, 8192, 4096),
(2, 48, 48, 16384, 8192),
(8, 16, 16, 1989, 15344),
(1, 8, 8, 4096, 8192),
(1, 8, 8, 8192, 4096),
(2, 4, 4, 16384, 8192),
(2, 8, 8, 1989, 15344),
(4, 16, 16, 4097, 163),
(2, 16, 16, 8122, 2159),
(1, 16, 16, 16281, 7),
(2, 48, 48, 1021, 1020),
(2, 48, 48, 2001, 2048),
(2, 48, 48, 3996, 9639),
(2, 48, 48, 8181, 1021),
(2, 8, 8, 3996, 9639),
(2, 8, 8, 8181, 1021),
]

def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal):
89 changes: 54 additions & 35 deletions flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -1,7 +1,7 @@
import torch
import triton
import triton.language as tl
from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG
from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF

@triton.jit
def _bwd_preprocess_use_o(
@@ -74,7 +74,6 @@ def _bwd_preprocess_use_o(
tl.store(delta_ptrs, delta, mask=mask_m)



@triton.jit
def _bwd_kernel_one_col_block(
Q,
@@ -311,7 +310,7 @@ def _bwd_kernel(
dk_offset = DK + off_z * stride_kz + off_h * stride_kh + k_start * stride_kn
dv_offset = DV + off_z * stride_vz + off_h * stride_vh + k_start * stride_vn
if SEQUENCE_PARALLEL:
dq_offset = DQ + stride_dq_all * start_n + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
dq_offset = DQ + start_n * stride_dq_all + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm
else:
dq_offset = DQ + off_z * stride_qz + off_h * stride_qh + q_start * stride_qm

@@ -451,7 +450,7 @@ def attention_prefill_backward_triton_impl(
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
sequence_parallel = False,
sequence_parallel = True,
):
if DEBUG:
print()
@@ -489,7 +488,6 @@ def attention_prefill_backward_triton_impl(
stride_kz, stride_kh, stride_kn, stride_kk = k_strides
stride_vz, stride_vh, stride_vn, stride_vk = v_strides
stride_oz, stride_oh, stride_om, stride_ok = o_strides
stride_dq_all = q.numel()
batch_headsize = batch * nheads_q
is_varlen = layout == "thd"

@@ -515,31 +513,46 @@
ACTUAL_BLOCK_DMODEL = head_size

do = do.contiguous()
if sequence_parallel:
# replicate q for each parallel sequence
replicas = num_blocks_n
dq_shape = (replicas,) + q.shape
# NOTE: we might need to copy the output tensors back if they are not contiguous or have other issues
copy_back = {"dq": False, "dk": False, "dv": False}

# deal with dq
if dq is None:
if sequence_parallel:
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
else:
dq = torch.zeros(q.shape, device=q.device, dtype=q.dtype)
else:
dq_shape = q.shape
dq_og = dq
if (not dq.is_contiguous()):
dq = dq.contiguous()
copy_back["dq"] = True

if sequence_parallel:
dq = torch.zeros((num_blocks_n,) + q.shape, device=q.device, dtype=q.dtype)
copy_back["dq"] = True
else:
# NOTE: the kernel does in-place accumulation so dq has to be zeros. This avoids the case where we are passed an empty dq that is not all zeros
dq.zero_()
stride_dq_all = dq.stride()[0]

is_qkvpacked = False
if dq is None or dk is None or dv is None:
dq = torch.zeros(dq_shape, device=q.device, dtype=q.dtype)
# deal with dk, dv
if (dk is None) or (dv is None):
dk = torch.empty_like(k)
dv = torch.empty_like(v)
elif (not dq.is_contiguous()) or (not dq.is_contiguous()) or (not dq.is_contiguous()):
if DEBUG:
print("Not contigious and setting is packed to True")
is_qkvpacked = True
dq_og = dq
dq = dq.contiguous()
else:
if (not dk.is_contiguous()):
dk_og = dk
dk = dk.contiguous()
copy_back["dk"] = True

if (not dv.is_contiguous()):
dv_og = dv
dv = dv.contiguous()

# NOTE: the kernel does inplace accumlation so dq has to be zeros. This avoids the case where we are passed empty dq and it is not all zeros
dq.zero_()
dv = dv.contiguous()
copy_back["dv"] = True

if DEBUG:
print("copy_back:", copy_back)

# assert contiguous
assert do.is_contiguous()
@@ -647,26 +660,32 @@ def attention_prefill_backward_triton_impl(
IS_VARLEN=is_varlen
)

if len(dq.shape) == 5:
if DEBUG:
print("_bwd_kernel outputs")
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)
print("delta:", delta, delta.shape)

if sequence_parallel:
dq = dq.sum(dim=0)

if DEBUG:
print("_bwd_kernel outputs")
print("attention_prefill_backward_triton_new_impl outputs")
print("dq:", dq, dq.shape)
print("dk:", dk, dk.shape)
print("dv:", dv, dv.shape)
print("delta:", delta, delta.shape)

if is_qkvpacked:
if DEBUG:
print("Copying back to original tensors due to ispacked")

# copy back results to og tensors
print("copy_back:", copy_back)

if copy_back["dq"]:
dq_og.copy_(dq)
dq = dq_og
if copy_back["dk"]:
dk_og.copy_(dk)
dk = dk_og
if copy_back["dv"]:
dv_og.copy_(dv)
return dq_og, dk_og, dv_og, delta, None, None
else:
return dq, dk, dv, delta, None, None

dv = dv_og

return dq, dk, dv, delta, None, None
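
The copy_back dict in the hunk above effectively replaces the old is_qkvpacked flag: any gradient output that arrives non-contiguous (or that has to be reallocated for sequence parallelism) is computed into a contiguous working tensor and copied back into the caller's tensor at the end. A minimal sketch of that pattern, with hypothetical helper and variable names:

    import torch

    def ensure_contiguous(t):
        # Return a contiguous working copy, the original tensor, and a copy-back flag.
        if t.is_contiguous():
            return t, None, False
        return t.contiguous(), t, True

    dq_user = torch.zeros(2, 16, 128, 64).transpose(1, 2)   # deliberately non-contiguous
    dq, dq_og, needs_copy_back = ensure_contiguous(dq_user)

    # ... the kernel writes its results into the contiguous dq ...

    if needs_copy_back:
        dq_og.copy_(dq)   # propagate results into the tensor the caller handed in
        dq = dq_og
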
12 changes: 7 additions & 5 deletions flash_attn/flash_attn_triton_amd/test.py
@@ -505,8 +505,8 @@ def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor
(1, 1, 4, 4, 32),
(1, 1, 16, 16, 16),
(1, 1, 32, 32, 16),
(1, 1, 64, 64, 16), # pass # smallest head_size = 16
(1, 1, 64, 64, 64), # pass # smallest seq len seems to be 64
(1, 1, 64, 64, 16),
(1, 1, 64, 64, 64),
(1, 1, 64, 128, 32),
(1, 1, 128, 128, 64),
(1, 1, 128, 256, 45),
@@ -529,10 +529,11 @@ def test_op_prefill_fwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, return_scor
(1, 16, 1024, 1024, 128),
])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('use_exp2', [False]) # using exp2 causas issue with exp2
@pytest.mark.parametrize('use_exp2', [False]) # FIXME: using exp2 causes issue when used with causal
@pytest.mark.parametrize('layout', ["bhsd", "bshd", "thd"])
@pytest.mark.parametrize('sequence_parallel', [True, False])
@pytest.mark.parametrize('DEBUG_INPUT', [False]) # debug output causes nans in both new and old backend
def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, DEBUG_INPUT):
def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, layout, sequence_parallel, DEBUG_INPUT):
dtype = torch.float16
torch.manual_seed(20) # seed from test_op_bwd

@@ -619,7 +620,8 @@ def test_op_prefill_bwd_impl(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_exp2, l
metadata.cu_seqlens_k,
metadata.max_seqlens_q,
metadata.max_seqlens_k,
use_exp2
use_exp2,
sequence_parallel=sequence_parallel
)

# =============================================== Check ==============================================================
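
With sequence_parallel added to the parametrize list, every backward-impl case now runs with both code paths. One way to run just this test locally, assuming pytest is installed and the command is issued from the repository root (the selection expression is only illustrative):

    import pytest

    # Collects both the sequence_parallel=True and sequence_parallel=False variants.
    pytest.main(["flash_attn/flash_attn_triton_amd/test.py", "-k", "test_op_prefill_bwd_impl", "-q"])
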
1 change: 1 addition & 0 deletions flash_attn/flash_attn_triton_amd/utils.py
@@ -5,6 +5,7 @@

AUTOTUNE = os.environ.get('FLASH_ATTENTION_TRITON_AMD_AUTOTUNE', '1').lower() in ('1', 'true', 'yes')
DEBUG = os.environ.get('FLASH_ATTENTION_TRITON_AMD_DEBUG', '0').lower() in ('1', 'true', 'yes')
PERF = os.environ.get('FLASH_ATTENTION_TRITON_AMD_PERF', '0').lower() in ('1', 'true', 'yes')

class MetaData():
cu_seqlens_q = None
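
Like AUTOTUNE and DEBUG, the new PERF flag is read from the environment once at module import, so it has to be set before flash_attn.flash_attn_triton_amd.utils is first imported. A small usage sketch, assuming the package is installed:

    import os

    # Set the flags before the first import of the module; they are evaluated at load time.
    os.environ["FLASH_ATTENTION_TRITON_AMD_PERF"] = "1"
    os.environ["FLASH_ATTENTION_TRITON_AMD_DEBUG"] = "0"

    from flash_attn.flash_attn_triton_amd import utils
    print(utils.PERF, utils.DEBUG)   # -> True False
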
6 changes: 3 additions & 3 deletions tests/test_flash_attn_triton_amd.py
@@ -585,10 +585,10 @@ def get_dropout_fraction(
@pytest.mark.parametrize("d", [32, 40, 59, 64, 80, 96, 111, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128])
# @pytest.mark.parametrize("d", [64])
# @pytest.mark.parametrize("d", [32])
# @pytest.mark.parametrize('seqlen', [128, 256, 384, 512, 768, 1024, 2048])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 384, 768, 1024, 1025, 2048])
# @pytest.mark.parametrize("seqlen", [512])
# @pytest.mark.parametrize("seqlen", [128])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize("dropout_p", [0.0])
def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, deterministic, dtype):
@@ -746,7 +746,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32])
@pytest.mark.parametrize("seqlen", [97, 128, 200, 257, 384, 512, 768, 1025, 2048])
# @pytest.mark.parametrize('seqlen', [2])
# @pytest.mark.parametrize('seqlen', [128])
# @pytest.mark.parametrize("dropout_p", [0.0, 0.17])
@pytest.mark.parametrize('dropout_p', [0.0])
def test_flash_attn_varlen_qkvpacked(
