Commit: save
alexkranias-amd committed Nov 8, 2024
1 parent bf67fc9 commit 3836ec3
Showing 8 changed files with 857 additions and 21 deletions.
68 changes: 67 additions & 1 deletion flash_attn/flash_attn_triton_amd/bwd_prefill.py
@@ -3,6 +3,30 @@
import triton.language as tl
from .utils import get_shape_from_layout, get_strides_from_layout, DEBUG, PERF

@triton.jit
def cdiv_fn(x, y):
return (x + y - 1) // y

@triton.jit
def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride):
ms = tl.arange(0, m)
ns = tl.arange(0, n)
return philox_offset + ms[:, None] * stride + ns[None, :]


@triton.jit
def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32)
# TODO: use tl.randint for better performance
return tl.rand(philox_seed, rng_offsets)


@triton.jit
def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride):
rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride)
rng_keep = rng_output > dropout_p
return rng_keep
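
# The three helpers above give each (i, j) element of a BLOCK_M x BLOCK_N tile its own
# Philox counter (philox_offset + i * stride + j), draw a uniform number from the stream
# seeded with philox_seed, and keep the element when the draw exceeds dropout_p. Below is
# a host-side mirror of that indexing for illustration only: it is not bit-exact with
# tl.rand's Philox stream, and dropout_mask_host is a hypothetical name, not part of this
# file.
import numpy as np

def dropout_mask_host(philox_seed, philox_offset, dropout_p, m, n, stride):
    # Same offset grid as dropout_offsets: row-major counters spaced by `stride`.
    offsets = philox_offset + np.arange(m)[:, None] * stride + np.arange(n)[None, :]
    # One uniform draw per (seed, counter) pair, standing in for tl.rand(seed, offsets).
    draws = np.array([
        np.random.Generator(np.random.Philox(key=philox_seed, counter=int(o))).random()
        for o in offsets.ravel()
    ]).reshape(m, n)
    return draws > dropout_p  # True where the attention probability is kept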

@triton.jit
def _bwd_preprocess_use_o(
Out,
@@ -122,12 +146,14 @@ def _bwd_kernel_one_col_block(
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, philox_offset_base,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
DROPOUT: tl.constexpr,
USE_EXP2: tl.constexpr,
):
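# dropout_p, the Philox seed/offset base, and the DROPOUT flag are the arguments added to
# this kernel for dropout; DROPOUT is a tl.constexpr, so Triton compiles the dropout
# branches below in only when the launcher passes dropout_p > 0.0.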
if CAUSAL:
@@ -196,13 +222,30 @@ def _bwd_kernel_one_col_block(
# mask block in the cases where the data is smaller than the block size
p_mask = mask_m[:, None] & mask_n[None, :]
p = tl.where(p_mask, p, 0.0)

# NOTE: must create a new var p_drop to prevent p (which is used later to compute ds) from changing
if DROPOUT:
philox_offset = philox_offset_base + start_m * N_CTX_K + start_n * BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX_K)
p_drop = tl.where(keep, p, 0.0)

p_drop = p_drop / (1 - dropout_p)
p_drop = p_drop.to(Q.dtype.element_ty)

# compute dv
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
dv += tl.dot(tl.trans(p_drop.to(Q.dtype.element_ty)), do)

# compute dp
dp = tl.dot(do, tl.trans(v))

if DROPOUT:
philox_offset = philox_offset_base + start_m * N_CTX_K + start_n * BLOCK_N
keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, N_CTX_K)
dp = tl.where(keep, dp, 0.0)

dp = dp / (1 - dropout_p)
dp = dp.to(Q.dtype.element_ty)

# compute ds: ds = p * (dp - delta[:, None])
d_ptrs = d_offset + offs_m * stride_deltam
Di = tl.load(d_ptrs, mask=mask_m)
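
# Illustrative sketch of the dropout handling above (hypothetical helper, not part of
# this kernel): the backward pass regenerates the forward keep-mask from the same Philox
# seed/offset, zeroes the dropped positions, and rescales the survivors by
# 1 / (1 - dropout_p) (inverted dropout). p_drop feeds the dV update, and the masked,
# rescaled dp is used when forming ds = p * (dp - delta).
import numpy as np

def apply_bwd_dropout(p, dp, keep, dropout_p):
    scale = 1.0 / (1.0 - dropout_p)             # inverted-dropout rescaling
    p_drop = np.where(keep, p, 0.0) * scale     # accumulated into dv via p_drop^T @ do
    dp_drop = np.where(keep, dp, 0.0) * scale   # used in place of dp when forming ds
    return p_drop, dp_drop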
@@ -265,12 +308,14 @@ def _bwd_kernel(
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p, philox_seed, philox_offset,
BLOCK_M: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
ACTUAL_BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
SEQUENCE_PARALLEL: tl.constexpr,
CAUSAL: tl.constexpr,
DROPOUT: tl.constexpr,
USE_EXP2: tl.constexpr,
IS_VARLEN: tl.constexpr,
):
@@ -281,6 +326,12 @@ def _bwd_kernel(
off_z = off_hz // H
off_h = off_hz % H

if DROPOUT:
off_hz = off_z * H + off_h
batch_philox_offset = philox_offset + off_hz * max_seqlen_q * max_seqlen_k
else:
batch_philox_offset = 0
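# Each (batch, head) pair gets its own max_seqlen_q * max_seqlen_k slice of the Philox
# offset space; regenerating the mask from the seed/offset handed over in rng_state lets
# the backward pass reproduce the per-element dropout decisions made in the forward pass.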

if IS_VARLEN:
# Compute sequence lengths for the current batch
q_start = tl.load(cu_seqlens_q + off_z)
@@ -363,12 +414,14 @@ def _bwd_kernel(
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, batch_philox_offset,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
DROPOUT=dropout_p>0.0,
USE_EXP2=USE_EXP2,
)
else:
@@ -420,12 +473,14 @@ def _bwd_kernel(
start_n,
num_block_m,
num_block_n,
dropout_p, philox_seed, batch_philox_offset,
BLOCK_M=BLOCK_M,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
BLOCK_N=BLOCK_N,
SEQUENCE_PARALLEL=SEQUENCE_PARALLEL,
CAUSAL=CAUSAL,
DROPOUT=dropout_p>0.0,
USE_EXP2=USE_EXP2,
)
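# Both the sequence-parallel path and the default path above pass the same
# (dropout_p, philox_seed, batch_philox_offset) triple into _bwd_kernel_one_col_block,
# with DROPOUT=dropout_p > 0.0 acting as the compile-time switch.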

@@ -444,12 +499,14 @@ def attention_prefill_backward_triton_impl(
sm_scale: float,
alibi_slopes,
causal,
dropout_p,
layout: str,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q: int,
max_seqlen_k: int,
use_exp2: bool,
rng_state: torch.Tensor,
sequence_parallel = True,
):
if DEBUG:
@@ -473,6 +530,7 @@ def attention_prefill_backward_triton_impl(
print("max_seqlen_q:", max_seqlen_q)
print("max_seqlen_k:", max_seqlen_k)
print("use_exp2:", use_exp2)
print("rng_state", rng_state)
print("sequence_parallel:", sequence_parallel)

# make contiguous
@@ -491,6 +549,9 @@
batch_headsize = batch * nheads_q
is_varlen = layout == "thd"

# get dropout metadata
philox_seed, philox_offset = rng_state[0].item(), rng_state[1].item()
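# rng_state is expected to hold the (philox_seed, philox_offset) pair captured during the
# forward pass; reusing it here makes the backward kernels replay the identical dropout
# pattern.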

# FIXME: some configs lead to oom for some reason when using 64 x 64 blocks
if max_seqlen_q <= 32 or max_seqlen_k <= 32:
BLOCK_M = 32
Expand Down Expand Up @@ -610,6 +671,9 @@ def attention_prefill_backward_triton_impl(
print("heads_q:",nheads_q)
print("max_seqlen_q:",max_seqlen_q)
print("max_seqlen_k:",max_seqlen_k)
print("dropout_p:",dropout_p)
print("philox_seed:", philox_seed)
print("philox_offset:",philox_offset)
print("BLOCK_M:",BLOCK_M)
print("BLOCK_N:",BLOCK_M)
print("BLOCK_DMODEL:",BLOCK_DMODEL)
@@ -647,12 +711,14 @@
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
dropout_p, philox_seed, philox_offset,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
BLOCK_DMODEL=BLOCK_DMODEL,
ACTUAL_BLOCK_DMODEL=ACTUAL_BLOCK_DMODEL,
SEQUENCE_PARALLEL=sequence_parallel,
CAUSAL=causal,
DROPOUT=dropout_p>0.0,
USE_EXP2=use_exp2,
num_warps=num_warps,
num_stages=num_stages,