From 02effd70e9b600c89cc102121a4dbf15efadfddf Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Fri, 3 Jan 2025 19:18:56 +0800 Subject: [PATCH 1/6] update config and fix dp router error. (#695) Co-authored-by: wangzaijun --- ...ad_num=128,q_rope_dim=64}_NVIDIA_H800.json | 1 + ...ead_num=16,q_rope_dim=64}_NVIDIA_H800.json | 2 +- lightllm/common/vllm_kernel/_ops.py | 228 ------------------ .../router/req_queue/continues_batch/impl.py | 15 +- 4 files changed, 11 insertions(+), 235 deletions(-) create mode 100644 lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json diff --git a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json new file mode 100644 index 000000000..7272d3f06 --- /dev/null +++ b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json index 4abf69e8b..26143b609 100644 --- a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json +++ b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json @@ -1 +1 @@ -{"256": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 5}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 4}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 16, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}} \ No newline at end of file +{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 1, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/vllm_kernel/_ops.py b/lightllm/common/vllm_kernel/_ops.py index d8c8bd4b6..00fdcb748 100644 --- a/lightllm/common/vllm_kernel/_ops.py +++ b/lightllm/common/vllm_kernel/_ops.py @@ -37,11 +37,6 @@ supports_moe_ops = True -try: - from torch.library import register_fake -except ImportError: - from torch.library import impl_abstract as register_fake - # for vllm_quant.py torch.ops._C.cutlass_scaled_mm = torch.ops.vllm_total.cutlass_scaled_mm @@ -327,21 +322,6 @@ def gptq_gemm( return torch.ops.vllm_total.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit) -if hasattr(torch.ops.vllm_total, "gptq_gemm"): - - @register_fake("vllm_total::gptq_gemm") - def _gptq_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, - use_exllama: bool, - bit: int, - ) -> torch.Tensor: - return torch.empty((a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device) - - def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: torch.ops.vllm_total.gptq_shuffle(q_weight, q_perm, bit) @@ -376,179 +356,6 @@ def gptq_marlin_24_gemm( ) -if hasattr(torch.ops.vllm_total, "gptq_marlin_24_gemm"): - - @register_fake("vllm_total::gptq_marlin_24_gemm") - def _gptq_marlin_24_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("vllm_total::gptq_marlin_gemm") - def _gptq_marlin_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("vllm_total::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: torch.SymInt, n: torch.SymInt) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("vllm_total::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((1, row), dtype=torch.float16, device=W.device) - - @register_fake("vllm_total::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=torch.float16, device=W.device) - - @register_fake("vllm_total::marlin_qqq_gemm") - def _marlin_qqq_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: torch.Tensor, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device) - - @register_fake("vllm_total::marlin_gemm") - def _marlin_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device) - - @register_fake("vllm_total::awq_dequantize") - def _awq_dequantize_fake( - qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - split_k_iters: torch.SymInt, - thx: int, - thy: int, - ) -> torch.Tensor: - in_c = qweight.size(0) - qout_c = qweight.size(1) - out_c = qout_c * 8 - return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) - - @register_fake("vllm_total::awq_gemm") - def _awq_gemm_fake( - input: torch.Tensor, - qweight: torch.Tensor, - qzeros: torch.Tensor, - scales: torch.Tensor, - split_k_iters: torch.SymInt, - ) -> torch.Tensor: - num_in_feats = input.size(0) - return torch.empty( - (split_k_iters, num_in_feats, qweight.size(1) * 8), dtype=input.dtype, device=input.device - ).sum(0) - - @register_fake("vllm_total::aqlm_gemm") - def _aqlm_gemm_fake( - input: torch.Tensor, - codes: torch.Tensor, - codebooks: torch.Tensor, - scales: torch.Tensor, - codebook_partition_sizes: List[int], - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - out_features = codes.size(0) * codebooks.size(2) - flat_input = input.reshape((-1, input.size(-1))) - flat_output = torch.empty((flat_input.size(0), out_features), dtype=input.dtype, device=input.device) - - output_sizes = list(input.shape) - output_sizes.pop() - output_sizes.append(-1) - return flat_output.reshape(tuple(output_sizes)) - - @register_fake("vllm_total::aqlm_dequant") - def _aqlm_dequant_fake( - codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: List[int] - ) -> torch.Tensor: - in_features = codes.size(1) * 8 - out_features = codes.size(0) - return torch.empty((out_features, in_features), dtype=codebooks.dtype, device=codebooks.device) - - @register_fake("vllm_total::fp8_marlin_gemm") - def _fp8_marlin_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake("vllm_total::machete_gemm") - def machete_gemm_fake( - a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - return torch.empty((m, n), device=a.device, dtype=a.dtype) - - @register_fake("vllm_total::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type) -> torch.Tensor: - return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - - # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return torch.ops.vllm_total.cutlass_scaled_mm_supports_fp8(cuda_device_capability) @@ -732,13 +539,6 @@ def machete_prepack_B(b_q_weight: torch.Tensor, b_type) -> torch.Tensor: return torch.ops.vllm_total.machete_prepack_B(b_q_weight, b_type.id) -if hasattr(torch.ops.vllm_total, "permute_cols"): - - @register_fake("vllm_total::permute_cols") - def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: - return torch.empty_like(a) - - def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.ops.vllm_total.permute_cols(a, perm) @@ -961,34 +761,6 @@ def topk_softmax( torch.ops.vllm_moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output) -if supports_moe_ops and hasattr(torch.ops.vllm_moe, "marlin_gemm_moe"): - - @register_fake("vllm_moe::marlin_gemm_moe") - def marlin_gemm_moe_fake( - a: torch.Tensor, - b_q_weights: torch.Tensor, - sorted_ids: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - b_scales: torch.Tensor, - b_zero_points: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - num_experts: int, - topk: int, - moe_block_size: int, - replicate_input: bool, - apply_weights: bool, - ) -> torch.Tensor: - return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device) - - def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, diff --git a/lightllm/server/router/req_queue/continues_batch/impl.py b/lightllm/server/router/req_queue/continues_batch/impl.py index 07cc83379..3c6b2d38d 100644 --- a/lightllm/server/router/req_queue/continues_batch/impl.py +++ b/lightllm/server/router/req_queue/continues_batch/impl.py @@ -104,12 +104,15 @@ def generate_new_batch(self, current_batch: Batch): def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): is_busy = self.is_busy() self._init_cache_list(current_batch, is_busy) - self.cache_len_list.sort(key=lambda x: -x[1]) - left_out_len_array = np.array([e[1] for e in self.cache_len_list]) - has_run_len_array = np.array([e[0] for e in self.cache_len_list]) - cum_run_len_array = np.cumsum(has_run_len_array) - size_array = np.arange(1, len(self.cache_len_list) + 1, 1) - need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + if len(self.cache_len_list) != 0: + self.cache_len_list.sort(key=lambda x: -x[1]) + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + else: + need_max_token_num = 0 return ( need_max_token_num, (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index)) From 6ede09e14a839314ff4991c504030caa412a4fc6 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Sat, 4 Jan 2025 17:34:04 +0800 Subject: [PATCH 2/6] udpate mla decode attention. (#696) --- lightllm/common/basemodel/cuda_graph.py | 11 ++- .../triton_kernel/gqa_flash_decoding.py | 78 +++++++++++++++---- .../gqa_flash_decoding_stage1.py | 25 +----- 3 files changed, 75 insertions(+), 39 deletions(-) diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 85475c3ee..82e4a9c9a 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -1,5 +1,6 @@ import os import torch +import copy from lightllm.utils.log_utils import init_logger from lightllm.distributed import custom_comm_ops @@ -27,10 +28,18 @@ def capture_decode(self, decode_func, input_ids, infer_state): infer_state.max_len_in_batch = self.graph_max_len_in_batch infer_state.total_token_num = self.graph_max_len_in_batch * batch_size # warmup + # 因为有些推理过程的代码,会通过判断infer_state中是否存在某些属性来在一层上 + # 做一些初始化的操作,后续层可以复用这些计算的结果,如 + # lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py + # 中做的一些操作,所以在 warmup 的时候,需要调用infer_state的copy函数做一个 + # 浅拷贝,不然后续传入到cuda graph捕获过程中后,infer_state因为提前拥有了这些属性, + # 导致不会重新初始化,这样捕获过程中会不能捕获这些临时添加到 infer_state 管理对象 + # 中的 tensor。 for _ in range(1): torch.cuda.synchronize() - decode_func(input_ids, infer_state) + decode_func(input_ids, copy.copy(infer_state)) # infer_state must copy() torch.cuda.synchronize() + with custom_comm_ops.lightllm_capture_graph(): with torch.cuda.graph(graph_obj, pool=self.mempool): predict_logics = decode_func(input_ids, infer_state) diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py index 8b688ddd3..256dfce5a 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py @@ -1,6 +1,8 @@ import os import torch import torch.multiprocessing as mp +import triton +import triton.language as tl from typing import List from lightllm.utils.log_utils import init_logger from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig @@ -44,27 +46,19 @@ def gqa_token_decode_attention_flash_decoding( out_dtype=torch.bfloat16, ) + BLOCK_N = run_config["BLOCK_N"] + from .gqa_flash_decoding_stage1 import flash_decode_stage1 from .gqa_flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out - mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda") - mid_o_batch_start_index = alloc_tensor_func( - [ - batch_size, - ], - dtype=torch.int64, - device="cuda", - ) - + fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda") mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda") mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda") vsm_count = flash_decode_stage1( - infer_state.total_token_num_tensor, - mid_o_block_seq, - mid_o_batch_start_index, + fake_decode_att_block_seq, q_nope.view(calcu_shape1), q_rope.view(calcu_shape2), kv_nope, @@ -79,13 +73,40 @@ def gqa_token_decode_attention_flash_decoding( **run_config ) + if not hasattr(infer_state, "decode_att_block_seq"): + assert batch_size <= 2048 + decode_att_block_seq = torch.empty( + [ + 1, + ], + dtype=torch.int64, + device="cuda", + ) + mid_o_batch_start_index = torch.empty( + [ + batch_size, + ], + dtype=torch.int64, + device="cuda", + ) + _fwd_kernel_calcu_index_and_block_seq[(1,)]( + infer_state.b_seq_len, + decode_att_block_seq, + mid_o_batch_start_index, + vsm_count, + batch_size, + BLOCK_N=BLOCK_N, + num_warps=4, + ) + + infer_state.decode_att_block_seq = decode_att_block_seq + infer_state.mid_o_batch_start_index = mid_o_batch_start_index + mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda") mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda") flash_decode_stage1( - infer_state.total_token_num_tensor, - mid_o_block_seq, - mid_o_batch_start_index, + infer_state.decode_att_block_seq, q_nope.view(calcu_shape1), q_rope.view(calcu_shape2), kv_nope, @@ -101,8 +122,8 @@ def gqa_token_decode_attention_flash_decoding( ) flash_decode_stage2( - mid_o_block_seq, - mid_o_batch_start_index, + infer_state.decode_att_block_seq, + infer_state.mid_o_batch_start_index, mid_o, mid_o_logexpsum, infer_state.b_seq_len, @@ -110,3 +131,26 @@ def gqa_token_decode_attention_flash_decoding( **run_config ) return o_tensor + + +@triton.jit +def _fwd_kernel_calcu_index_and_block_seq( + b_seq_len_ptr, + mid_o_decode_att_block_seq_ptr, + mid_o_batch_start_index_ptr, + num_sm, + batch_size, + BLOCK_N: tl.constexpr, +): + b_seq_len = tl.load(b_seq_len_ptr + tl.arange(0, 2048), mask=tl.arange(0, 2048) < batch_size, other=0) + total_token_num = tl.sum(b_seq_len) + + block_seq = tl.cast(total_token_num / (num_sm * 4), dtype=tl.int32) + 1 + block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N + + block_seq_len = tl.cdiv(b_seq_len, block_seq) + cumsum_seq_len = tl.cumsum(block_seq_len) + batch_start_index = cumsum_seq_len - block_seq_len + tl.store(mid_o_batch_start_index_ptr + tl.arange(0, 2048), batch_start_index, mask=tl.arange(0, 2048) < batch_size) + tl.store(mid_o_decode_att_block_seq_ptr, block_seq) + return diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py index 9028c6e85..9191de6d1 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py @@ -35,9 +35,7 @@ def _fwd_kernel_flash_decode_stage1_padding( stride_mid_od, stride_mid_o_eh, stride_mid_o_es, - total_token_ptr, block_size_ptr, - batch_start_index_ptr, num_sm, head_group_num, head_num, @@ -51,15 +49,9 @@ def _fwd_kernel_flash_decode_stage1_padding( ): # cur_kv_head = 0 sm_id = tl.program_id(0).to(tl.int64) - grid_id = sm_id out_batch_start_index = tl.cast(0, tl.int64) - total_token_num = tl.load(total_token_ptr, eviction_policy="evict_last") + block_seq = tl.load(block_size_ptr, eviction_policy="evict_last") - block_seq = tl.cast(total_token_num / num_sm / 4, dtype=tl.int32) + 1 - block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N - - if grid_id == 0: - tl.store(block_size_ptr, block_seq) cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) offs_d = tl.arange(0, BLOCK_DMODEL) offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) @@ -163,9 +155,6 @@ def _fwd_kernel_flash_decode_stage1_padding( ) sm_id += num_sm - if grid_id == 0: - tl.store(batch_start_index_ptr + cur_batch, out_batch_start_index) - out_batch_start_index += cur_block_num // head_group_num sm_id -= cur_block_num return @@ -173,9 +162,7 @@ def _fwd_kernel_flash_decode_stage1_padding( @torch.no_grad() def flash_decode_stage1( - total_token_num_tensor: torch.Tensor, - out_block_seq: torch.Tensor, - batch_start_index: torch.Tensor, + in_block_seq: torch.Tensor, q_nope, q_rope, kv_nope, @@ -227,9 +214,7 @@ def flash_decode_stage1( *kv_rope.stride(), *mid_out.stride(), *mid_out_logsumexp.stride(), - total_token_num_tensor, - out_block_seq, - batch_start_index, + in_block_seq, num_sm=1, head_group_num=head_group_num, head_num=q_head_num, @@ -271,9 +256,7 @@ def flash_decode_stage1( *kv_rope.stride(), *mid_out.stride(), *mid_out_logsumexp.stride(), - total_token_num_tensor, - out_block_seq, - batch_start_index, + in_block_seq, num_sm=num_sm, head_group_num=head_group_num, head_num=q_head_num, From 7243b6018c3fa8ef61dd372ec18ad3a911611bb3 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Mon, 6 Jan 2025 11:42:32 +0800 Subject: [PATCH 3/6] overlap post sample. (#697) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit batch 1, decode faster 1ms. --------- Co-authored-by: wangzaijun --- .../continues_batch/post_process.py | 53 ++++++++++++------- 1 file changed, 35 insertions(+), 18 deletions(-) diff --git a/lightllm/server/router/model_infer/mode_backend/continues_batch/post_process.py b/lightllm/server/router/model_infer/mode_backend/continues_batch/post_process.py index a602fcdd3..79f634727 100644 --- a/lightllm/server/router/model_infer/mode_backend/continues_batch/post_process.py +++ b/lightllm/server/router/model_infer/mode_backend/continues_batch/post_process.py @@ -1,27 +1,44 @@ -import re import torch -from typing import List, Tuple -from lightllm.server.router.model_infer.infer_batch import InferBatch +from typing import List from lightllm.common.basemodel.triton_kernel.apply_penalty import apply_penalty +from dataclasses import dataclass + + +@dataclass +class OverlapStream: + overlap_stream: torch.cuda.Stream = None + + def get_overlap_stream(self): + if self.overlap_stream is None: + self.overlap_stream = torch.cuda.Stream() + return self.overlap_stream + + +g_single_overlap_stream = OverlapStream() def sample(logits, reqs, eos_id: List[int] = [2]): + + with torch.cuda.stream(g_single_overlap_stream.get_overlap_stream()): + ( + presence_penalties, + frequency_penalties, + repetition_penalties, + exponential_decay_length_penalties, + temperatures, + top_ps, + top_ks, + p_token_ids, + p_token_counts, + p_cumsum_seq_len, + p_max_len_in_batch, + length_penalty_idx, + mask_eos_reqs, + ) = _get_post_sample_tensors(reqs) + + torch.cuda.current_stream().wait_stream(g_single_overlap_stream.get_overlap_stream()) + logits = logits.contiguous() - ( - presence_penalties, - frequency_penalties, - repetition_penalties, - exponential_decay_length_penalties, - temperatures, - top_ps, - top_ks, - p_token_ids, - p_token_counts, - p_cumsum_seq_len, - p_max_len_in_batch, - length_penalty_idx, - mask_eos_reqs, - ) = _get_post_sample_tensors(reqs) apply_penalty( logits, From 777fc041cc550bca9d1583c31c3b72def65fea48 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 6 Jan 2025 17:57:18 +0800 Subject: [PATCH 4/6] add sample_param logs and env: HEALTH_TIMEOUT (#698) Co-authored-by: shihaobai --- lightllm/server/httpserver/manager.py | 44 ++++++++++++++------------- lightllm/server/sampling_params.py | 6 ++++ lightllm/utils/health_check.py | 4 ++- 3 files changed, 32 insertions(+), 22 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 7bde5724f..598d38929 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -358,27 +358,29 @@ async def _wait_to_token_package( prompt_cache_len = metadata.pop("prompt_cache_len", 0) prompt_cache_ratio = prompt_cache_len / prompt_tokens format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( - f"X-Request-Id:{x_request_id} " - f"X-Session-Id:{x_session_id} start_time:{format_start_time} " - f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " - f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter} " - f"mean_per_token_cost_time: {mean_per_token_cost_time_ms}ms " - f"prompt_token_num:{prompt_tokens} " - f"prompt_cache_len:{prompt_cache_len} " - f"prompt_cache_ratio:{prompt_cache_ratio} " - ) - self.metric_client.histogram_observe( - "lightllm_request_inference_duration", total_cost_time_ms / 1000.0 - ) - self.metric_client.histogram_observe( - "lightllm_request_mean_time_per_token_duration", mean_per_token_cost_time_ms / 1000.0 - ) - self.metric_client.histogram_observe( - "lightllm_request_first_token_duration", first_token_cost_ms / 1000.0 - ) - self.metric_client.histogram_observe("lightllm_request_generated_tokens", out_token_counter) - self.metric_client.counter_inc("lightllm_request_success") + if request is not None: + logger.info( + f"X-Request-Id:{x_request_id} " + f"X-Session-Id:{x_session_id} start_time:{format_start_time} " + f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " + f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter} " + f"mean_per_token_cost_time: {mean_per_token_cost_time_ms}ms " + f"prompt_token_num:{prompt_tokens} " + f"prompt_cache_len:{prompt_cache_len} " + f"prompt_cache_ratio:{prompt_cache_ratio} " + f"sampling_params: {{{sampling_params.to_string()}}}" + ) + self.metric_client.histogram_observe( + "lightllm_request_inference_duration", total_cost_time_ms / 1000.0 + ) + self.metric_client.histogram_observe( + "lightllm_request_mean_time_per_token_duration", mean_per_token_cost_time_ms / 1000.0 + ) + self.metric_client.histogram_observe( + "lightllm_request_first_token_duration", first_token_cost_ms / 1000.0 + ) + self.metric_client.histogram_observe("lightllm_request_generated_tokens", out_token_counter) + self.metric_client.counter_inc("lightllm_request_success") return req_status.out_token_info_list.clear() diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py index af8e6215f..1b8b16272 100644 --- a/lightllm/server/sampling_params.py +++ b/lightllm/server/sampling_params.py @@ -260,3 +260,9 @@ def to_origin_dict(self): ret["group_request_id"] = self.group_request_id ret["suggested_dp_index"] = self.suggested_dp_index return ret + + def to_string(self): + output_str = "" + for name, value in vars(self).items(): + output_str += f"{name}: {value} " + return output_str diff --git a/lightllm/utils/health_check.py b/lightllm/utils/health_check.py index fecf59843..cb5a755b5 100644 --- a/lightllm/utils/health_check.py +++ b/lightllm/utils/health_check.py @@ -1,3 +1,4 @@ +import os import asyncio import numpy as np from dataclasses import dataclass @@ -19,6 +20,7 @@ class HealthObj: _is_health: bool = True _is_health_checking: bool = False + timeout: int = int(os.getenv("HEALTH_TIMEOUT", 100)) def begin_check(self): self._is_health_checking = True @@ -65,7 +67,7 @@ async def check_timeout(results_generator): pass try: - await asyncio.wait_for(check_timeout(results_generator), timeout=88) + await asyncio.wait_for(check_timeout(results_generator), timeout=health_obj.timeout) health_obj.set_health() except asyncio.TimeoutError: health_obj.set_unhealth() From 2c64be6d943d6108065d03cdce0f656d34a956f7 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Mon, 6 Jan 2025 18:58:14 +0800 Subject: [PATCH 5/6] [debug]: add some debug log (#699) Co-authored-by: shihaobai --- .../server/router/model_infer/model_rpc.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index d5f1a6106..064cad0ae 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -93,15 +93,23 @@ def exposed_add_batch(self, batch_id, reqs): # @calculate_time(show=False, min_cost_ms=300) def exposed_prefill_batch(self, batch_id): - if self.world_size != 1: - batch_id = obtain(batch_id) - return self.backend.prefill_batch(batch_id) + try: + if self.world_size != 1: + batch_id = obtain(batch_id) + return self.backend.prefill_batch(batch_id) + except Exception as e: + err_msg = str(e) + logger.exception(f"Batch prefill encountered an unexpected ERROR: {err_msg}") # @calculate_time(show=True, min_cost_ms=200) def exposed_decode_batch(self, batch_id): - if self.world_size != 1: - batch_id = obtain(batch_id) - return self.backend.decode_batch(batch_id) + try: + if self.world_size != 1: + batch_id = obtain(batch_id) + return self.backend.decode_batch(batch_id) + except Exception as e: + err_msg = str(e) + logger.exception(f"Batch decode encountered an unexpected ERROR: {err_msg}") # @calculate_time(show=True, min_cost_ms=0.1) def exposed_filter_batch(self, batch_id, req_id_list, finished_req_id_list): From 269631d60021399182ca5c3f0b5b6ab5cc4a8503 Mon Sep 17 00:00:00 2001 From: blueswhen Date: Mon, 13 Jan 2025 09:45:32 +0800 Subject: [PATCH 6/6] feat: add _context_attention_kernel_with_CC in deepseek2 (#693) --- .../layer_infer/transformer_layer_infer.py | 71 ++++- .../context_flashattention_nopad_with_v.py | 269 ++---------------- .../deepseek2/triton_kernel/sample_kv.py | 104 +++++++ 3 files changed, 198 insertions(+), 246 deletions(-) create mode 100644 lightllm/models/deepseek2/triton_kernel/sample_kv.py diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index f8eaa0865..87777463b 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -10,6 +10,8 @@ context_attention_fwd, context_attention_fwd_no_prompt_cache, ) +from lightllm.models.deepseek2.triton_kernel.context_flashattention_nopad_with_v import context_attention_fwd_with_v +from lightllm.models.deepseek2.triton_kernel.sample_kv import sample_kv from lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer @@ -54,6 +56,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): if mscale_all_dim: mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) self.softmax_scale = self.softmax_scale * mscale * mscale + self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"] super().__init__(layer_num, tp_rank, world_size, network_config, mode) self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"] if self.enable_dp: @@ -65,7 +68,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): return def _bind_attention(self): - self._context_attention_kernel = partial(Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self) + if self.enable_cc_method: + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_kernel_with_CC, self + ) + else: + self._context_attention_kernel = partial( + Deepseek2TransformerLayerInfer._context_attention_kernel_origin, self + ) self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashdecoding, self ) @@ -123,6 +133,65 @@ def _get_o( o_tensor = layer_weight.o_weight_.mm(input.reshape(-1, self.tp_q_head_num_ * self.qk_nope_head_dim)) return o_tensor + def _context_attention_kernel_with_CC( + self, + q: torch.Tensor, + kv, + infer_state: Deepseek2InferStateInfo, + layer_weight: Deepseek2TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + if infer_state.use_dynamic_prompt_cache: + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + compressed_kv = self.alloc_tensor( + [infer_state.total_token_num, 1, layer_weight.kv_lora_rank], dtype=kv.dtype + ) + k_rope = self.alloc_tensor([infer_state.total_token_num, 1, self.qk_rope_head_dim], dtype=kv.dtype) + sample_kv( + kv, + compressed_kv, + k_rope, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.req_manager.req_to_token_indexs, + ) + else: + compressed_kv, k_rope = torch.split( # (b*s, 1, kv_lora + qk_r) + kv, [layer_weight.kv_lora_rank, layer_weight.qk_rope_head_dim], dim=-1 + ) + + # CC + k_nope = self.alloc_tensor( + [compressed_kv.shape[0], q.shape[1], self.qk_nope_head_dim], + dtype=compressed_kv.dtype, + ) + v = self.alloc_tensor( + k_nope.shape, + dtype=compressed_kv.dtype, + ) + compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank) + wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank) + wv = layer_weight.v_b_proj_.weight.transpose(1, 2).view(-1, layer_weight.kv_lora_rank) + torch.mm(compressed_kv, wk.transpose(0, 1), out=k_nope.reshape(compressed_kv.shape[0], -1)) + torch.mm(compressed_kv, wv.transpose(0, 1), out=v.reshape(compressed_kv.shape[0], -1)) + + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) if out is None else out + context_attention_fwd_with_v( + q_nope, + q_rope, + k_nope, + k_rope, + v, + o_tensor.view(-1, self.tp_q_head_num_, q_nope.shape[-1]), + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + self.softmax_scale, + ) + return o_tensor + def _context_attention_kernel_origin( self, q: torch.Tensor, diff --git a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py index 49cd34556..04e6facbb 100644 --- a/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py +++ b/lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py @@ -20,28 +20,18 @@ def _fwd_kernel_with_v( B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 Out, - Req_to_tokens, - B_req_idx, stride_q_bs, stride_q_h, - stride_q_d, stride_q_rope_bs, stride_q_rope_h, - stride_q_rope_d, stride_k_bs, stride_k_h, - stride_k_d, stride_k_rope_bs, stride_k_rope_h, - stride_k_rope_d, stride_vbs, stride_vh, - stride_vd, stride_obs, stride_oh, - stride_od, - stride_req_to_tokens_b, - stride_req_to_tokens_s, b_prompt_cache_len, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -57,7 +47,6 @@ def _fwd_kernel_with_v( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) block_start_loc = BLOCK_M * start_m @@ -66,20 +55,23 @@ def _fwd_kernel_with_v( offs_d = tl.arange(0, BLOCK_DMODEL) offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs - + cur_head * stride_q_h - + offs_d[None, :] * stride_q_d - ) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs + cur_head * stride_q_h + offs_d[None, :] off_q_rope = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs + cur_head * stride_q_rope_h - + offs_rope_d[None, :] * stride_q_rope_d + + offs_rope_d[None, :] ) + off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] + off_k_rope = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] + off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) q_rope = tl.load(Q_rope + off_q_rope, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + k_ptrs = K_nope + off_k + k_rope_ptrs = K_rope + off_k_rope + v_ptrs = V + off_v + # initialize pointer to m and l m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) @@ -91,20 +83,20 @@ def _fwd_kernel_with_v( for start_n in range(0, block_mask * block_end_loc, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < block_end_loc, - other=0, + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs, + mask=(start_n + offs_n[None, :]) < block_end_loc, + other=0.0, + ) + k_rope = tl.load( + k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs, + mask=(start_n + offs_n[None, :]) < block_end_loc, + other=0.0, ) - off_k = k_loc[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d - off_k_rope = k_loc[None, :] * stride_k_rope_bs + offs_rope_d[:, None] * stride_k_rope_d - k = tl.load(K_nope + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) - k_rope = tl.load(K_rope + off_k_rope, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, k) qk += tl.dot(q_rope, k_rope) - qk *= sm_scale qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) @@ -126,19 +118,18 @@ def _fwd_kernel_with_v( acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) acc = acc * acc_scale[:, None] # update acc - off_v = k_loc[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < block_end_loc, + other=0.0, + ) p = p.to(v.dtype) acc += tl.dot(p, v) # update m_i and l_i l_i = l_i_new m_i = m_i_new # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return @@ -152,12 +143,10 @@ def context_attention_fwd_with_v( k_rope, v, o, - b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, - req_to_token_indexs, softmax_scale, ): @@ -194,28 +183,18 @@ def context_attention_fwd_with_v( b_start_loc, b_seq_len, o, - req_to_token_indexs, - b_req_idx, q_nope.stride(0), q_nope.stride(1), - q_nope.stride(2), q_rope.stride(0), q_rope.stride(1), - q_rope.stride(2), k_nope.stride(0), k_nope.stride(1), - k_nope.stride(2), k_rope.stride(0), k_rope.stride(1), - k_rope.stride(2), v.stride(0), v.stride(1), - v.stride(2), o.stride(0), o.stride(1), - o.stride(2), - req_to_token_indexs.stride(0), - req_to_token_indexs.stride(1), b_prompt_cache_len=b_prompt_cache_len, BLOCK_M=BLOCK, BLOCK_DMODEL=q_nope_dim, @@ -225,203 +204,3 @@ def context_attention_fwd_with_v( num_stages=1, ) return - - -@triton.jit -def _fwd_kernel_no_prompt_cache_with_v( - Q_nope, - Q_rope, - K_nope, - K_rope, - V, - sm_scale, - B_Start_Loc, - B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 - Out, - stride_q_bs, - stride_q_h, - stride_q_d, - stride_q_rope_bs, - stride_q_rope_h, - stride_q_rope_d, - stride_k_bs, - stride_k_h, - stride_k_d, - stride_k_rope_bs, - stride_k_rope_h, - stride_k_rope_d, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_k_head = cur_head - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_bs - + cur_head * stride_q_h - + offs_d[None, :] * stride_q_d - ) - off_rope_q = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_q_rope_bs - + cur_head * stride_q_rope_h - + offs_rope_d[None, :] * stride_q_rope_d - ) - off_k = offs_n[None, :] * stride_k_bs + cur_k_head * stride_k_h + offs_d[:, None] * stride_k_d - off_rope_k = offs_n[None, :] * stride_k_rope_bs + offs_rope_d[:, None] * stride_k_rope_d - off_v = offs_n[:, None] * stride_vbs + cur_k_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q_nope + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - q_rope = tl.load(Q_rope + off_rope_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K_nope + off_k - k_rope_ptrs = K_rope + off_rope_k - v_ptrs = V + off_v - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load( - k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_bs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - k_rope = tl.load( - k_rope_ptrs + (cur_batch_in_all_start_index + start_n) * stride_k_rope_bs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, - other=0.0, - ) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk += tl.dot(q_rope, k_rope) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - # -- compute m_ij, p, l_ij - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - acc = acc * acc_scale[:, None] - # update acc - v = tl.load( - v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, - other=0.0, - ) - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = ( - (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs - + cur_head * stride_oh - + offs_d[None, :] * stride_od - ) - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - return - - -@torch.no_grad() -def context_attention_fwd_no_prompt_cache_with_v( - q_nope, q_rope, k_nope, k_rope, v, o, b_start_loc, b_seq_len, max_input_len, softmax_scale -): - q_nope_dim = q_nope.shape[-1] - q_rope_dim = q_rope.shape[-1] - assert q_nope_dim == k_nope.shape[-1] - assert q_rope_dim == k_rope.shape[-1] - assert q_nope_dim in {16, 32, 64, 128, 256, 512} - assert q_rope_dim in {16, 32, 64, 128, 256} - assert q_nope_dim == v.shape[-1] - - if q_nope_dim >= 512: - BLOCK = 64 if not TESLA else 32 - else: - BLOCK = 128 if not TESLA else 64 - - if q_nope.dtype == torch.float32: - BLOCK = BLOCK // 4 - - sm_scale = softmax_scale - batch, head = b_seq_len.shape[0], q_nope.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, - - num_warps = 4 if q_nope_dim <= 64 else 8 - _fwd_kernel_no_prompt_cache_with_v[grid]( - q_nope, - q_rope, - k_nope, - k_rope, - v, - sm_scale, - b_start_loc, - b_seq_len, - o, - q_nope.stride(0), - q_nope.stride(1), - q_nope.stride(2), - q_rope.stride(0), - q_rope.stride(1), - q_rope.stride(2), - k_nope.stride(0), - k_nope.stride(1), - k_nope.stride(2), - k_rope.stride(0), - k_rope.stride(1), - k_rope.stride(2), - v.stride(0), - v.stride(1), - v.stride(2), - o.stride(0), - o.stride(1), - o.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=q_nope_dim, - BLOCK_ROPE_DMODEL=q_rope_dim, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return diff --git a/lightllm/models/deepseek2/triton_kernel/sample_kv.py b/lightllm/models/deepseek2/triton_kernel/sample_kv.py new file mode 100644 index 000000000..912f8603d --- /dev/null +++ b/lightllm/models/deepseek2/triton_kernel/sample_kv.py @@ -0,0 +1,104 @@ +import torch + +import triton +import triton.language as tl + +TESLA = "Tesla" in torch.cuda.get_device_name(0) +CUDA_CAPABILITY = torch.cuda.get_device_capability() + + +@triton.jit +def _sample_kv_kernel( + KV_input, + KV_nope, + KV_rope, + B_start_loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + stride_input_dim, + stride_nope_dim, + stride_rope_dim, + stride_req_to_tokens_b, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_ROPE_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_m = tl.program_id(1) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_loc = tl.load(B_start_loc + cur_batch) + + offs_nope_d = tl.arange(0, BLOCK_DMODEL) + offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M, cur_batch_seq_len) + + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_m, + mask=offs_m < block_end_loc, + other=0, + ) + off_kv_nope = kv_loc[:, None] * stride_input_dim + offs_nope_d[None, :] + off_kv_rope = kv_loc[:, None] * stride_input_dim + (offs_rope_d + BLOCK_DMODEL)[None, :] + kv_nope = tl.load(KV_input + off_kv_nope, mask=offs_m[:, None] < block_end_loc, other=0.0) + kv_rope = tl.load(KV_input + off_kv_rope, mask=offs_m[:, None] < block_end_loc, other=0.0) + off_nope = (offs_m + cur_batch_start_loc)[:, None] * stride_nope_dim + offs_nope_d[None, :] + off_rope = (offs_m + cur_batch_start_loc)[:, None] * stride_rope_dim + offs_rope_d[None, :] + nope_ptrs = KV_nope + off_nope + rope_ptrs = KV_rope + off_rope + tl.store(nope_ptrs, kv_nope, mask=offs_m[:, None] < block_end_loc) + tl.store(rope_ptrs, kv_rope, mask=offs_m[:, None] < block_end_loc) + return + + +@torch.no_grad() +def sample_kv( + kv_input, + kv_nope, + kv_rope, + b_req_idx, + b_seq_len, + req_to_token_indexs, +): + BLOCK = 128 if not TESLA else 64 + + nope_dim = kv_nope.shape[-1] + rope_dim = kv_rope.shape[-1] + if nope_dim >= 512: + BLOCK = 64 if not TESLA else 32 + else: + BLOCK = 128 if not TESLA else 64 + + batch = b_seq_len.shape[0] + + max_input_len = b_seq_len.max() + grid = ( + batch, + triton.cdiv(max_input_len, BLOCK), + ) + num_warps = 4 if nope_dim <= 64 else 8 + + b_start_loc = torch.cat([torch.zeros([1], device=b_seq_len.device, dtype=b_seq_len.dtype), b_seq_len[1:].cumsum(0)]) + _sample_kv_kernel[grid]( + kv_input, + kv_nope, + kv_rope, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + kv_input.stride(0), + kv_nope.stride(0), + kv_rope.stride(0), + req_to_token_indexs.stride(0), + BLOCK_M=BLOCK, + BLOCK_DMODEL=nope_dim, + BLOCK_ROPE_DMODEL=rope_dim, + num_warps=num_warps, + num_stages=1, + ) + return