diff --git a/benchmarks/benchmark_rotary.py b/benchmarks/benchmark_rotary.py
new file mode 100644
index 000000000..11f63a5f1
--- /dev/null
+++ b/benchmarks/benchmark_rotary.py
@@ -0,0 +1,292 @@
+import argparse
+import math
+import torch
+import triton
+from flash_attn.flash_attn_triton_amd.utils import (
+    MetaData,
+    input_helper,
+    varlen_input_helper,
+)
+from flash_attn.flash_attn_triton_amd.interface_torch import attention_decode
+
+ARGS_TO_TORCH_DTYPE = {
+    "fp16": torch.float16,
+    "bf16": torch.bfloat16,
+    "fp32": torch.float32,
+}
+
+FUNCTIONS = {
+    "decode": attention_decode
+}
+
+def get_benchmark_configs(args, varlen=False):
+    """
+    Returns benchmark configurations based on whether variable-length sequences are used.
+    """
+    if args.custom_config:
+        hk = args.hq if not args.hk else args.hk
+        sk = args.sq if not args.sk else args.sk
+        return [(args.b, args.hq, hk, args.sq, sk)]
+    elif varlen:
+        return [
+            # (2, 16, 4, 1024, 1024),
+            (8, 16, 2, 2048, 2048),
+            # (4, 16, 8, 4096, 4096),
+            # (2, 16, 4, 8192, 8192),
+            # (2, 16, 8, 16384, 16384),
+            # (2, 48, 12, 1024, 1024),
+            # (2, 48, 24, 2048, 2048),
+            # (2, 48, 8, 4096, 4096),
+            # (2, 48, 4, 8192, 8192),
+            # (2, 48, 2, 16384, 16384),
+            # (2, 64, 32, 1024, 1024),
+            # (4, 64, 16, 2048, 2048),
+            # (4, 64, 8, 4096, 4096),
+            # (4, 64, 32, 8192, 8192),
+            # (4, 128, 16, 16384, 16384),
+        ]
+    else:
+        return [
+            (16, 16, 16, 1024, 1024),
+            # (8, 16, 16, 2048, 2048),
+            # (4, 16, 16, 4096, 4096),
+            # (1, 8, 8, 8192, 8192),
+            # (1, 2, 2, 16384, 16384),
+            # (2, 48, 48, 1024, 1024),
+            # (2, 48, 48, 2048, 1024),
+            # (1, 8, 8, 4096, 8192),
+            # (1, 8, 8, 8192, 4096),
+            # (2, 4, 4, 16384, 8192),
+            # (2, 8, 8, 1989, 15344),
+            # (4, 16, 16, 4097, 163),
+            # (2, 16, 16, 8122, 2159),
+            # (1, 16, 16, 16281, 7),
+            # (2, 48, 48, 1021, 1020),
+            # (2, 48, 48, 2001, 2048),
+            # (2, 8, 8, 3996, 9639),
+            # (2, 8, 8, 8181, 1021),
+        ]
+
+def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal, rotary_fraction=0.0, rotary_interleaved=False):
+    flops_per_matmul = 0
+
+    q = torch.randn(
+        [BATCH, N_CTX_Q, HK, HQ // HK, D_HEAD],
+        device=device,
+        dtype=dtype,
+        requires_grad=False,
+    )
+    k = torch.randn(
+        [BATCH, N_CTX_K, HK, 1, D_HEAD],
+        device=device,
+        dtype=dtype,
+        requires_grad=False,
+    ).expand(-1, -1, -1, HQ // HK, -1)
+    v = torch.randn(
+        [BATCH, N_CTX_K, HK, 1, D_HEAD],
+        device=device,
+        dtype=dtype,
+        requires_grad=False,
+    ).expand(-1, -1, -1, HQ // HK, -1)
+    input_metadata = MetaData(sm_scale=1.3)
+    input_metadata.layout = "bsghd"
+
+    rotary_dim = math.floor(int(rotary_fraction * D_HEAD) / 16) * 16
+    if rotary_dim > 0:
+        angle = (
+            torch.rand(
+                N_CTX_K,
+                rotary_dim // 2,
+                device=device,
+            )
+            * 2
+            * math.pi
+        )
+        cos = torch.cos(angle).to(dtype=dtype)
+        sin = torch.sin(angle).to(dtype=dtype)
+
+        # add rotary
+        input_metadata.need_rotary(rotary_dim, sin, cos, rotary_interleaved=rotary_interleaved)
+
+    # flops for a single attention matmul (QK^T or PV)
+    flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
+
+    input_data = (q, k, v, input_metadata)
+
+    return input_data, flops_per_matmul
+
+def run_benchmark(args, fn_name, fn, mode):
+    """
+    Runs the benchmark for the given function and mode using the parsed command-line arguments.
+ """ + print(f"Benchmarking {fn_name} in {mode} mode...") + + dtype = ARGS_TO_TORCH_DTYPE[args.dtype] + head_size = args.d if args.d else 128 + causal = args.causal + rotary_fraction = args.rotary_fraction + rotary_interleaved = args.rotary_interleaved + varlen = args.layout == "thd" + return_tflops = args.return_tflops + line_names = "TFLOPS" if return_tflops else "Time (ms)" + + # Determine configurations + x_vals_list = get_benchmark_configs(args, varlen=varlen) + + # Setup benchmark configurations + configs = [ + triton.testing.Benchmark( + x_names=["BATCH", "HQ", "HK", "N_CTX_Q", "N_CTX_K"], + x_vals=x_vals_list, + line_arg="provider", + line_vals=["triton"], + line_names=[line_names], + styles=[("red", "-")], + ylabel="ms", + plot_name=f"benchmark-{fn_name}-d{head_size}-layout{args.layout}-mode{mode}", + args={ + "D_HEAD": head_size, + "dtype": dtype, + "causal": causal, + "rotary_fraction": rotary_fraction, + "rotary_interleaved": rotary_interleaved, + "mode": mode, + }, + ) + ] + + @triton.testing.perf_report(configs) + def bench_function( + BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, rotary_fraction, rotary_interleaved, mode, provider, device="cuda" + ): + warmup = 25 + rep = 100 + flops_per_matmul = 0 + + # generate function inputs + fn_inputs, flops_per_matmul = gen_fn_inputs( + fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, args.layout, causal, rotary_fraction, rotary_interleaved + ) + + # define the function to benchmark + if mode == "fwd": + benchmark_fn = lambda: fn(*fn_inputs) + total_flops = 2 * flops_per_matmul + elif mode == "bwd": + outputs = fn(*fn_inputs) + output = outputs[0] + grad_output = torch.randn_like(output) + benchmark_fn = lambda: output.backward(grad_output, retain_graph=True) + total_flops = 2 * flops_per_matmul * 2.5 + else: + raise ValueError("Unsupported mode. Choose 'fwd' or 'bwd'.") + + if causal: + total_flops *= 0.5 + + # Run the benchmark + ms = triton.testing.do_bench(benchmark_fn, warmup=warmup, rep=rep) + + if return_tflops: + return total_flops / ms * 1e-9 + else: + return ms + + bench_function.run(save_path=".", print_data=True) + +def supported_layouts(): + """ + Returns a string describing the supported layouts. + """ + return ( + "bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]\n" + "bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]\n" + "thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]\n" + 'This layout is sometimes called "varlen" or "grouped" layout.' + ) + +def parse_args(): + """ + Parses command-line arguments. 
+ """ + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument( + "-equal_seqlens", + action="store_true", + default=False, + help="If specified, each context within the thd layout has same seqlen as sq and sk", + ) + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action="store_true", default=False) + parser.add_argument("-rotary_fraction", type=float, default=0.0) + parser.add_argument("-rotary_interleaved", action="store_true", default=False) + parser.add_argument("-dtype", default="fp16") + parser.add_argument("-return_tflops", action="store_true", default=False) + parser.add_argument( + "-layout", + type=str, + default="bhsd", + help=supported_layouts(), + ) + parser.add_argument( + "-benchmark_fn", + type=str, + nargs="*", + choices=FUNCTIONS.keys(), + help="Function(s) to benchmark: prefill, decode, or both", + ) + parser.add_argument( + "-mode", + type=str, + nargs='*', + default=["fwd", "bwd"], + choices=["fwd", "bwd"], + help="Mode(s) to run: 'fwd' for forward pass, 'bwd' for backward pass", + ) + return parser.parse_args() + +def main(): + """ + Main function to run benchmarks. + """ + args = parse_args() + + # Validate arguments + assert ( + args.layout == "thd" or not args.equal_seqlens + ), "Equal sequence lengths arg must be used with the thd layout." + args.custom_config = False + if args.b or args.hq or args.hk or args.sq or args.sk or args.d: + args.custom_config = True + assert args.b and args.hq and args.sq and args.d, ( + "If custom config is specified, please provide all of batch, " + "number of Q heads, Q sequence length, and head size." + ) + assert args.dtype in ARGS_TO_TORCH_DTYPE, "Only fp16, bf16 and fp32 types currently supported." 
+
+    # determine the functions to benchmark
+    if args.benchmark_fn is None or len(args.benchmark_fn) == 0:
+        bench_fn_list = FUNCTIONS.keys()
+    else:
+        bench_fn_list = args.benchmark_fn
+
+    # benchmark functions
+    for fn_name in bench_fn_list:
+        if fn_name not in FUNCTIONS:
+            raise ValueError(f"Invalid benchmark function specified: {fn_name}")
+        for mode in args.mode:
+            if fn_name == "decode" and mode == "bwd":
+                print("Decode kernel does not have a backward pass")
+                continue
+            run_benchmark(args, fn_name, FUNCTIONS[fn_name], mode)
+
+if __name__ == "__main__":
+    main()
diff --git a/flash_attn/flash_attn_triton_amd/bench.py b/flash_attn/flash_attn_triton_amd/bench.py
index 91939f831..bf06bd15b 100644
--- a/flash_attn/flash_attn_triton_amd/bench.py
+++ b/flash_attn/flash_attn_triton_amd/bench.py
@@ -67,7 +67,7 @@ def get_benchmark_configs(args, varlen=False):
             (2, 8, 8, 8181, 1021),
         ]
 
-def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal):
+def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, device, layout, causal, rotary):
     flops_per_matmul = 0
 
     if fn_name.startswith("prefill"):
@@ -110,6 +110,8 @@ def gen_fn_inputs(fn_name, BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, devic
     ).expand(-1, -1, -1, HQ // HK, -1)
     input_metadata = MetaData(sm_scale=1.3)
     input_metadata.layout = "bsghd"
+
+    # flops for a single attention matmul (QK^T or PV)
     flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD
diff --git a/flash_attn/flash_attn_triton_amd/fwd_decode.py b/flash_attn/flash_attn_triton_amd/fwd_decode.py
index e9b59d814..d582eecda 100644
--- a/flash_attn/flash_attn_triton_amd/fwd_decode.py
+++ b/flash_attn/flash_attn_triton_amd/fwd_decode.py
@@ -43,7 +43,7 @@ def rotary_kernel_splitk(
     Note:
     - for K in splitk let BLOCK_M = BLOCK_N, and start_m=start_n
     """
-    # pdb.set_trace()
+    # import pdb; pdb.set_trace()
 
     range_m = start_m + tl.arange(0, BLOCK_M)
     range_d = tl.arange(0, BLOCK_K)
@@ -240,6 +240,8 @@ def _fwd_kernel_splitK(
     off_g_q = off_zhg % G_q
     splitk_idx = tl.program_id(2)
 
+    # import pdb; pdb.set_trace()
+
     # pick batch index
     if USE_CACHE_BATCH_IDX:
         cache_batch_idx = tl.load(Cache_batch_idx + off_z)
@@ -279,7 +281,7 @@ def _fwd_kernel_splitK(
     # Copy new Keys and Values into Cache
     if NEW_KV:
         knew_base = K_new + k_head_idx * stride_kn_h + off_z * stride_kn_z + off_g_q * stride_kn_g
-        
+
         # Determine the starting position for new data in the cache
         if USE_CACHE_SEQLENS:
             start_idx = tl.load(Cache_seqlens + off_z)
@@ -847,7 +849,7 @@ def attention_decode_forward_triton_impl(q, k, v,
         Metadata=metadata,
         K_new = k_new,
         V_new = v_new,
-        Cache_seqlens=cache_seqlens,
+        Cache_seqlens=cache_seqlens if use_cache_seqlens else 0,
         Cache_batch_idx=cache_batch_idx,
         Alibi_slopes=alibi_slopes,
         Rotary_cos=rotary_cos,
diff --git a/flash_attn/flash_attn_triton_amd/interface_fa.py b/flash_attn/flash_attn_triton_amd/interface_fa.py
index 35e9aaa23..685c7a162 100644
--- a/flash_attn/flash_attn_triton_amd/interface_fa.py
+++ b/flash_attn/flash_attn_triton_amd/interface_fa.py
@@ -581,6 +581,6 @@ def fwd_kvcache(
         metadata.rotary_sin,
         metadata.rotary_dim,
         metadata.rotary_interleaved,
-        metadata.rotary_conjunction
+        metadata.rotary_conjugate
     )
     return output, softmax_lse
diff --git a/flash_attn/flash_attn_triton_amd/interface_torch.py b/flash_attn/flash_attn_triton_amd/interface_torch.py
index d4906606e..8d0638fb8 100644
--- a/flash_attn/flash_attn_triton_amd/interface_torch.py
+++ b/flash_attn/flash_attn_triton_amd/interface_torch.py
@@ -91,6 +91,11 @@ def forward(ctx, q, k, v, metadata):
         metadata.new_kv,
         metadata.k_new,
         metadata.v_new,
+        metadata.rotary_cos,
+        metadata.rotary_sin,
+        metadata.rotary_dim,
+        metadata.rotary_interleaved,
+        metadata.rotary_conjugate
     )
     return output, softmax_lse
 
diff --git a/flash_attn/flash_attn_triton_amd/utils.py b/flash_attn/flash_attn_triton_amd/utils.py
index a3b3e925e..fef7211d3 100644
--- a/flash_attn/flash_attn_triton_amd/utils.py
+++ b/flash_attn/flash_attn_triton_amd/utils.py
@@ -31,7 +31,7 @@ class MetaData():
     rotary_sin = None
     rotary_cos = None
     rotary_interleaved = False
-    rotary_conjunction = False
+    rotary_conjugate = False
 
 
     def __repr__(self) -> str:
@@ -44,6 +44,10 @@ def __repr__(self) -> str:
            f"  bias={self.bias},\n"
            f"  alibi_slopes={self.alibi_slopes},\n"
            f"  causal={self.causal},\n"
+           f"  rotary_dim={self.rotary_dim},\n"
+           f"  rotary_interleaved={self.rotary_interleaved},\n"
+           f"  rotary_sin={self.rotary_sin.shape if self.rotary_sin is not None else None},\n"
+           f"  rotary_cos={self.rotary_cos.shape if self.rotary_cos is not None else None},\n"
            f"  num_contexts={self.num_contexts},\n"
            f"  varlen={self.varlen},\n"
            f"  layout={self.layout},\n"
@@ -90,12 +94,12 @@ def need_alibi(self, alibi_slopes, batch, nheads):
     def need_causal(self):
         self.causal = True
 
-    def need_rotary(self, rotary_dim, sin, cos, rotary_interleaved, rotary_conjunction=False):
+    def need_rotary(self, rotary_dim, sin, cos, rotary_interleaved, rotary_conjugate=False):
         self.rotary_dim = rotary_dim
         self.rotary_sin = sin
         self.rotary_cos = cos
         self.rotary_interleaved = rotary_interleaved
-        self.rotary_conjunction = rotary_conjunction
+        self.rotary_conjugate = rotary_conjugate
 
     def need_dropout(self, dropout_p, return_scores):
         self.dropout_p = dropout_p
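
Usage note (illustrative sketch, not part of the patch): the new decode benchmark can be invoked directly, for example

    python benchmarks/benchmark_rotary.py -benchmark_fn decode -mode fwd -rotary_fraction 0.5 -return_tflops

and the renamed rotary metadata API is exercised roughly as follows. The shapes, sm_scale, and rotary fraction below are arbitrary placeholders chosen for illustration; the call mirrors gen_fn_inputs in benchmark_rotary.py and assumes a GPU with the Triton decode kernels available.

    import math
    import torch
    from flash_attn.flash_attn_triton_amd.utils import MetaData
    from flash_attn.flash_attn_triton_amd.interface_torch import attention_decode

    batch, seqlen_q, seqlen_k, hq, hk, d = 2, 4, 2048, 16, 2, 128
    dtype, device = torch.float16, "cuda"

    # bsghd layout: [batch, seqlen, kv_heads, group_size, head_dim]
    q = torch.randn(batch, seqlen_q, hk, hq // hk, d, device=device, dtype=dtype)
    k = torch.randn(batch, seqlen_k, hk, 1, d, device=device, dtype=dtype).expand(-1, -1, -1, hq // hk, -1)
    v = torch.randn(batch, seqlen_k, hk, 1, d, device=device, dtype=dtype).expand(-1, -1, -1, hq // hk, -1)

    metadata = MetaData(sm_scale=1.0 / math.sqrt(d))  # benchmark above uses an arbitrary 1.3
    metadata.layout = "bsghd"

    # rotate half the head dim, rounded down to a multiple of 16 (as in the benchmark)
    rotary_dim = math.floor(int(0.5 * d) / 16) * 16
    angle = torch.rand(seqlen_k, rotary_dim // 2, device=device) * 2 * math.pi
    metadata.need_rotary(
        rotary_dim,
        torch.sin(angle).to(dtype),
        torch.cos(angle).to(dtype),
        rotary_interleaved=False,
        rotary_conjugate=False,
    )

    output, softmax_lse = attention_decode(q, k, v, metadata)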