Commit bb5ba12: "save"

alexkranias-amd committed Nov 18, 2024
1 parent 2a14c01
Showing 6 changed files with 315 additions and 8 deletions.
294 changes: 294 additions & 0 deletions benchmarks/benchmark_rotary.py
@@ -0,0 +1,294 @@
import argparse
import math
import torch
import triton
from flash_attn.flash_attn_triton_amd.utils import (
MetaData,
input_helper,
varlen_input_helper,
)
from flash_attn.flash_attn_triton_amd.interface_torch import attention_decode

ARGS_TO_TORCH_DTYPE = {
"fp16": torch.float16,
"bf16": torch.bfloat16,
"fp32": torch.float32,
}

FUNCTIONS = {
"decode": attention_decode
}
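# only the decode path is registered here; prefill-style benchmarks live in
# flash_attn/flash_attn_triton_amd/bench.py (see the gen_fn_inputs change below)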

def get_benchmark_configs(args, varlen=False):
"""
Returns benchmark configurations based on whether variable-length sequences are used.
"""
if args.custom_config:
hk = args.hq if not args.hk else args.hk
sk = args.sq if not args.sk else args.sk
return [(args.b, args.hq, hk, args.sq, sk)]
elif varlen:
return [
# (2, 16, 4, 1024, 1024),
(8, 16, 2, 2048, 2048),
# (4, 16, 8, 4096, 4096),
# (2, 16, 4, 8192, 8192),
# (2, 16, 8, 16384, 16384),
# (2, 48, 12, 1024, 1024),
# (2, 48, 24, 2048, 2048),
# (2, 48, 8, 4096, 4096),
# (2, 48, 4, 8192, 8192),
# (2, 48, 2, 16384, 16384),
# (2, 64, 32, 1024, 1024),
# (4, 64, 16, 2048, 2048),
# (4, 64, 8, 4096, 4096),
# (4, 64, 32, 8192, 8192),
# (4, 128, 16, 16384, 16384),
]
else:
return [
(16, 16, 16, 1024, 1024),
# (8, 16, 16, 2048, 2048),
# (4, 16, 16, 4096, 4096),
# (1, 8, 8, 8192, 8192),
# (1, 2, 2, 16384, 16384),
# (2, 48, 48, 1024, 1024),
# (2, 48, 48, 2048, 1024),
# (1, 8, 8, 4096, 8192),
# (1, 8, 8, 8192, 4096),
# (2, 4, 4, 16384, 8192),
# (2, 8, 8, 1989, 15344),
# (4, 16, 16, 4097, 163),
# (2, 16, 16, 8122, 2159),
# (1, 16, 16, 16281, 7),
# (2, 48, 48, 1021, 1020),
# (2, 48, 48, 2001, 2048),
# (2, 8, 8, 3996, 9639),
# (2, 8, 8, 8181, 1021),
]

def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal, rotary_fraction=0.0, rotary_interleaved=False):
flops_per_matmul = 0
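    # Q, K and V are created directly in the grouped "bsghd" layout:
    # [batch, seqlen, num_kv_heads (HK), group_size (HQ // HK), head_dim].
    # K and V hold a single entry per group and are expanded across the group
    # dimension, so HQ query heads share HK key/value heads (GQA/MQA style)
    # without materializing extra memory.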

q = torch.randn(
[BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
)
k = torch.randn(
[BATCH, N_CTX_K, HK, 1, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
).expand(-1, -1, -1, HQ // HK, -1)
v = torch.randn(
[BATCH, N_CTX_K, HK, 1, D_HEAD],
device=device,
dtype=dtype,
requires_grad=False,
).expand(-1, -1, -1, HQ // HK, -1)
input_metadata = MetaData(sm_scale=1.3)
input_metadata.layout = "bsghd"

    # use only a multiple-of-16 slice of the head dimension for rotary embeddings;
    # rotary_fraction controls how much of the head is rotated
    rotary_dim = math.floor(int(rotary_fraction * D_HEAD) / 16) * 16
    if rotary_dim > 0:
angle = (
torch.rand(
N_CTX_K,
rotary_dim // 2,
device=device,
)
* 2
* math.pi
)
cos = torch.cos(angle).to(dtype=dtype)
sin = torch.sin(angle).to(dtype=dtype)

        # attach the rotary cos/sin tables (one angle per KV position and
        # rotary-dim pair) to the metadata consumed by the decode kernel
        input_metadata.need_rotary(rotary_dim, sin, cos, rotary_interleaved=rotary_interleaved)

    # 2 * B * HQ * Sq * Sk * D FLOPs per attention matmul (QK^T or PV)
    flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD

input_data = (q, k, v, input_metadata)

print('meta', input_metadata)

return input_data, flops_per_matmul

def run_benchmark(args, fn_name, fn, mode):
"""
Runs the benchmark for the provided function based on the provided arguments.
"""
print(f"Benchmarking {fn_name} in {mode} mode...")

dtype = ARGS_TO_TORCH_DTYPE[args.dtype]
head_size = args.d if args.d else 128
causal = args.causal
rotary_fraction = args.rotary_fraction
rotary_interleaved = args.rotary_interleaved
varlen = args.layout == "thd"
return_tflops = args.return_tflops
line_names = "TFLOPS" if return_tflops else "Time (ms)"

# Determine configurations
x_vals_list = get_benchmark_configs(args, varlen=varlen)

# Setup benchmark configurations
configs = [
triton.testing.Benchmark(
x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"],
x_vals=x_vals_list,
line_arg="provider",
line_vals=["triton"],
line_names=[line_names],
styles=[("red", "-")],
ylabel="ms",
plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}",
args={
"D_HEAD": head_size,
"dtype": dtype,
"causal": causal,
"rotary_fraction": rotary_fraction,
"rotary_interleaved": rotary_interleaved,
"mode": mode,
},
)
]

@triton.testing.perf_report(configs)
def bench_function(
BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, rotary_fraction, rotary_interleaved, mode, provider, device="cuda"
):
warmup = 25
rep = 100
flops_per_matmul = 0

# generate function inputs
fn_inputs, flops_per_matmul = gen_fn_inputs(
fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal, rotary_fraction, rotary_interleaved
)

# define the function to benchmark
if mode == "fwd":
benchmark_fn = lambda: fn(*fn_inputs)
total_flops = 2 * flops_per_matmul
elif mode == "bwd":
outputs = fn(*fn_inputs)
output = outputs[0]
grad_output = torch.randn_like(output)
benchmark_fn = lambda: output.backward(grad_output, retain_graph=True)
total_flops = 2 * flops_per_matmul * 2.5
else:
raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.")

if causal:
total_flops *= 0.5

# Run the benchmark
ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep)

if return_tflops:
            return total_flops / ms * 1e-9  # FLOPs / (ms * 1e-3 s) / 1e12 = TFLOP/s
else:
return ms

bench_function.run(save_path=".", print_data=True)

def supported_layouts():
"""
Returns a string describing the supported layouts.
"""
return (
"bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n"
"bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n"
"thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n"
'This layout is sometimes called "varlen" or "grouped" layout.'
)

def parse_args():
"""
Parses command-line arguments.
"""
parser = argparse.ArgumentParser(
prog="Benchmark FlashAttention",
allow_abbrev=False,
)
parser.add_argument("-b", type=int, default=0)
parser.add_argument("-hq", type=int, default=0)
parser.add_argument("-hk", type=int, default=0)
parser.add_argument("-sq", type=int, default=0)
parser.add_argument("-sk", type=int, default=0)
parser.add_argument(
"-equal_seqlens",
action="store_true",
default=False,
help="If specified, each context within the thd layout has same seqlen as sq and sk",
)
parser.add_argument("-d", type=int, default=0)
parser.add_argument("-causal", action="store_true", default=False)
parser.add_argument("-rotary_fraction", type=float, default=0.0)
parser.add_argument("-rotary_interleaved", action="store_true", default=False)
parser.add_argument("-dtype", default="fp16")
parser.add_argument("-return_tflops", action="store_true", default=False)
parser.add_argument(
"-layout",
type=str,
default="bhsd",
help=supported_layouts(),
)
parser.add_argument(
"-benchmark_fn",
type=str,
nargs="*",
choices=FUNCTIONS.keys(),
help="Function(s) to benchmark: prefill, decode, or both",
)
parser.add_argument(
"-mode",
type=str,
nargs='*',
default=["fwd", "bwd"],
choices=["fwd", "bwd"],
help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass",
)
return parser.parse_args()

def main():
"""
Main function to run benchmarks.
"""
args = parse_args()

# Validate arguments
assert (
args.layout == "thd" or not args.equal_seqlens
), "Equal sequence lengths arg must be used with the thd layout."
args.custom_config = False
if args.b or args.hq or args.hk or args.sq or args.sk or args.d:
args.custom_config = True
assert args.b and args.hq and args.sq and args.d, (
"If custom config is specified, please provide all of batch, "
"number of Q heads, Q sequence length, and head size."
)
assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported."

# determine the functions to benchmark
if args.benchmark_fn is None or len(args.benchmark_fn) == 0:
bench_fn_list = FUNCTIONS.keys()
else:
bench_fn_list = args.benchmark_fn

# benchmark functions
for fn_name in bench_fn_list:
if fn_name not in FUNCTIONS:
raise ValueError(f"Invalid benchmark function specified: {fn_name}")
for mode in args.mode:
if fn_name == "decode" and mode == "bwd":
print(f"Decode kernel doesnot have a backward pass")
continue
run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)

if __name__ == "__main__":
main()
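A usage sketch (not part of the commit): with the package installed and a GPU visible to PyTorch, the new benchmark could be invoked roughly as follows, using the flags defined in parse_args above; the sizes are illustrative only.

    python benchmarks/benchmark_rotary.py -benchmark_fn decode -mode fwd \
        -b 2 -hq 16 -hk 2 -sq 1 -sk 2048 -d 128 \
        -rotary_fraction 1.0 -rotary_interleaved -return_tflops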
4 changes: 3 additions & 1 deletion flash_attn/flash_attn_triton_amd/bench.py
@@ -67,7 +67,7 @@ def get_benchmark_configs(args, varlen=False):
(2, 8, 8, 8181, 1021),
]

-def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal):
+def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal, rotary):
flops_per_matmul = 0

if fn_name.startswith("prefill"):
@@ -110,6 +110,8 @@ def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, devic
).expand(-1, -1, -1, HQ // HK, -1)
input_metadata = MetaData(sm_scale=1.3)
input_metadata.layout = "bsghd"



# Adjust flops calculation if needed
flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
8 changes: 5 additions & 3 deletions flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -43,7 +43,7 @@ def rotary_kernel_splitk(
Note:
- for K in splitk let BLOCK_M = BLOCK_N, and start_m=start_n
"""
-    # pdb.set_trace()
+    # import pdb; pdb.set_trace()
range_m = start_m + tl.arange(0, BLOCK_M)
range_d = tl.arange(0, BLOCK_K)

@@ -240,6 +240,8 @@ def _fwd_kernel_splitK(
off_g_q = off_zhg % G_q
splitk_idx = tl.program_id(2)

+    # import pdb; pdb.set_trace()

# pick batch index
if USE_CACHE_BATCH_IDX:
cache_batch_idx = tl.load(Cache_batch_idx + off_z)
Expand Down Expand Up @@ -279,7 +281,7 @@ def _fwd_kernel_splitK(
# Copy new Keys and Values into Cache
if NEW_KV:
knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g

# Determine the starting position for new data in the cache
if USE_CACHE_SEQLENS:
start_idx = tl.load(Cache_seqlens + off_z)
@@ -847,7 +849,7 @@ def attention_decode_forward_triton_impl(q, k, v,
Metadata=metadata,
K_new = k_new,
V_new = v_new,
-        Cache_seqlens=cache_seqlens,
+        Cache_seqlens=cache_seqlens if use_cache_seqlens else 0,
Cache_batch_idx=cache_batch_idx,
Alibi_slopes=alibi_slopes,
Rotary_cos=rotary_cos,
2 changes: 1 addition & 1 deletion flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -581,6 +581,6 @@ def fwd_kvcache(
metadata.rotary_sin,
metadata.rotary_dim,
metadata.rotary_interleaved,
-        metadata.rotary_conjunction
+        metadata.rotary_conjugate
)
return output, softmax_lse
5 changes: 5 additions & 0 deletions flash_attn/flash_attn_triton_amd/interface_torch.py
@@ -91,6 +91,11 @@ def forward(ctx, q, k, v, metadata):
metadata.new_kv,
metadata.k_new,
metadata.v_new,
+            metadata.rotary_cos,
+            metadata.rotary_sin,
+            metadata.rotary_dim,
+            metadata.rotary_interleaved,
+            metadata.rotary_conjugate
)
return output, softmax_lse
