diff --git a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json new file mode 100644 index 000000000..7272d3f06 --- /dev/null +++ b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=128,q_rope_dim=64}_NVIDIA_H800.json @@ -0,0 +1 @@ +{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, 
"stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json index 4abf69e8b..26143b609 100644 --- a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json +++ b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_H800.json @@ -1 +1 @@ -{"256": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": 
{"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 5}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, 
"stage2_num_warps": 8, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 64, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 4}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 64, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 16, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 16, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 16, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}} \ No newline at end of file +{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, 
"stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 1, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 3}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 3}, "128": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 
1}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/vllm_kernel/_ops.py b/lightllm/common/vllm_kernel/_ops.py index d8c8bd4b6..00fdcb748 100644 --- a/lightllm/common/vllm_kernel/_ops.py +++ b/lightllm/common/vllm_kernel/_ops.py @@ -37,11 +37,6 @@ supports_moe_ops = True -try: - from torch.library import register_fake -except ImportError: - from torch.library import impl_abstract as register_fake - # for vllm_quant.py torch.ops._C.cutlass_scaled_mm = torch.ops.vllm_total.cutlass_scaled_mm @@ -327,21 +322,6 @@ def gptq_gemm( return torch.ops.vllm_total.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit) -if hasattr(torch.ops.vllm_total, "gptq_gemm"): - - @register_fake("vllm_total::gptq_gemm") - def _gptq_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_gptq_qzeros: torch.Tensor, - b_gptq_scales: torch.Tensor, - b_g_idx: torch.Tensor, - use_exllama: bool, - bit: int, - ) -> torch.Tensor: - return torch.empty((a.size(0), b_q_weight.size(1)), dtype=a.dtype, device=a.device) - - def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None: torch.ops.vllm_total.gptq_shuffle(q_weight, q_perm, bit) @@ -376,179 +356,6 @@ def gptq_marlin_24_gemm( ) -if hasattr(torch.ops.vllm_total, "gptq_marlin_24_gemm"): - - @register_fake("vllm_total::gptq_marlin_24_gemm") - def _gptq_marlin_24_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_meta: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - b_q_type, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("vllm_total::gptq_marlin_gemm") - def _gptq_marlin_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - b_zeros: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - has_zp: bool = False, - use_fp32_reduce: bool = False, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - - @register_fake("vllm_total::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, m: torch.SymInt, n: torch.SymInt) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("vllm_total::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((1, row), dtype=torch.float16, device=W.device) - - @register_fake("vllm_total::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=torch.float16, device=W.device) - - @register_fake("vllm_total::marlin_qqq_gemm") - def _marlin_qqq_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - s_tok: torch.Tensor, - s_ch: torch.Tensor, - s_group: torch.Tensor, - workspace: 
torch.Tensor, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device) - - @register_fake("vllm_total::marlin_gemm") - def _marlin_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=torch.float16, device=a.device) - - @register_fake("vllm_total::awq_dequantize") - def _awq_dequantize_fake( - qweight: torch.Tensor, - scales: torch.Tensor, - zeros: torch.Tensor, - split_k_iters: torch.SymInt, - thx: int, - thy: int, - ) -> torch.Tensor: - in_c = qweight.size(0) - qout_c = qweight.size(1) - out_c = qout_c * 8 - return torch.empty((in_c, out_c), dtype=scales.dtype, device=scales.device) - - @register_fake("vllm_total::awq_gemm") - def _awq_gemm_fake( - input: torch.Tensor, - qweight: torch.Tensor, - qzeros: torch.Tensor, - scales: torch.Tensor, - split_k_iters: torch.SymInt, - ) -> torch.Tensor: - num_in_feats = input.size(0) - return torch.empty( - (split_k_iters, num_in_feats, qweight.size(1) * 8), dtype=input.dtype, device=input.device - ).sum(0) - - @register_fake("vllm_total::aqlm_gemm") - def _aqlm_gemm_fake( - input: torch.Tensor, - codes: torch.Tensor, - codebooks: torch.Tensor, - scales: torch.Tensor, - codebook_partition_sizes: List[int], - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - out_features = codes.size(0) * codebooks.size(2) - flat_input = input.reshape((-1, input.size(-1))) - flat_output = torch.empty((flat_input.size(0), out_features), dtype=input.dtype, device=input.device) - - output_sizes = list(input.shape) - output_sizes.pop() - output_sizes.append(-1) - return flat_output.reshape(tuple(output_sizes)) - - @register_fake("vllm_total::aqlm_dequant") - def _aqlm_dequant_fake( - codes: torch.Tensor, codebooks: torch.Tensor, codebook_partition_sizes: List[int] - ) -> torch.Tensor: - in_features = codes.size(1) * 8 - out_features = codes.size(0) - return torch.empty((out_features, in_features), dtype=codebooks.dtype, device=codebooks.device) - - @register_fake("vllm_total::fp8_marlin_gemm") - def _fp8_marlin_gemm_fake( - a: torch.Tensor, - b_q_weight: torch.Tensor, - b_scales: torch.Tensor, - workspace: torch.Tensor, - num_bits: int, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((size_m, size_n), dtype=a.dtype, device=a.device) - - @register_fake("vllm_total::machete_gemm") - def machete_gemm_fake( - a: torch.Tensor, - # Should be the tensor returned by machete_prepack_B - b_q: torch.Tensor, - b_type, - b_scales: Optional[torch.Tensor] = None, - b_zeros: Optional[torch.Tensor] = None, - b_group_size: Optional[int] = None, - c: Optional[torch.Tensor] = None, - alpha: Optional[float] = None, - beta: Optional[float] = None, - schedule: Optional[str] = None, - ) -> torch.Tensor: - m = a.size(0) - n = b_q.size(1) - return torch.empty((m, n), device=a.device, dtype=a.dtype) - - @register_fake("vllm_total::machete_prepack_B") - def machete_prepack_B_fake(b_q_weight: torch.Tensor, b_type) -> torch.Tensor: - return torch.empty_like(b_q_weight, memory_format=torch.contiguous_format) - - # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return torch.ops.vllm_total.cutlass_scaled_mm_supports_fp8(cuda_device_capability) @@ -732,13 +539,6 @@ def machete_prepack_B(b_q_weight: 
torch.Tensor, b_type) -> torch.Tensor: return torch.ops.vllm_total.machete_prepack_B(b_q_weight, b_type.id) -if hasattr(torch.ops.vllm_total, "permute_cols"): - - @register_fake("vllm_total::permute_cols") - def _permute_cols_fake(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: - return torch.empty_like(a) - - def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor: return torch.ops.vllm_total.permute_cols(a, perm) @@ -961,34 +761,6 @@ def topk_softmax( torch.ops.vllm_moe.topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output) -if supports_moe_ops and hasattr(torch.ops.vllm_moe, "marlin_gemm_moe"): - - @register_fake("vllm_moe::marlin_gemm_moe") - def marlin_gemm_moe_fake( - a: torch.Tensor, - b_q_weights: torch.Tensor, - sorted_ids: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - b_scales: torch.Tensor, - b_zero_points: torch.Tensor, - g_idx: torch.Tensor, - perm: torch.Tensor, - workspace: torch.Tensor, - b_q_type, - size_m: torch.SymInt, - size_n: torch.SymInt, - size_k: torch.SymInt, - is_k_full: bool, - num_experts: int, - topk: int, - moe_block_size: int, - replicate_input: bool, - apply_weights: bool, - ) -> torch.Tensor: - return torch.empty((size_m, topk, size_n), dtype=a.dtype, device=a.device) - - def reshape_and_cache( key: torch.Tensor, value: torch.Tensor, diff --git a/lightllm/server/router/req_queue/continues_batch/impl.py b/lightllm/server/router/req_queue/continues_batch/impl.py index 07cc83379..3c6b2d38d 100644 --- a/lightllm/server/router/req_queue/continues_batch/impl.py +++ b/lightllm/server/router/req_queue/continues_batch/impl.py @@ -104,12 +104,15 @@ def generate_new_batch(self, current_batch: Batch): def _calcu_batch_token_load_batch_not_none(self, current_batch: Batch): is_busy = self.is_busy() self._init_cache_list(current_batch, is_busy) - self.cache_len_list.sort(key=lambda x: -x[1]) - left_out_len_array = np.array([e[1] for e in self.cache_len_list]) - has_run_len_array = np.array([e[0] for e in self.cache_len_list]) - cum_run_len_array = np.cumsum(has_run_len_array) - size_array = np.arange(1, len(self.cache_len_list) + 1, 1) - need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + if len(self.cache_len_list) != 0: + self.cache_len_list.sort(key=lambda x: -x[1]) + left_out_len_array = np.array([e[1] for e in self.cache_len_list]) + has_run_len_array = np.array([e[0] for e in self.cache_len_list]) + cum_run_len_array = np.cumsum(has_run_len_array) + size_array = np.arange(1, len(self.cache_len_list) + 1, 1) + need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() + else: + need_max_token_num = 0 return ( need_max_token_num, (need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index))
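
The two JSON files above are pre-tuned launch tables for the two-stage MLA decode attention kernel on NVIDIA H800, one per q_head_num variant (128 and 16). Reading them off the diff, the outer keys look like sequence-length buckets (256 through 8192), the inner keys look like batch sizes, and each entry carries Triton launch parameters (BLOCK_N, BLOCK_Q_HEAD, and per-stage num_warps/num_stages); the updated q_head_num=16 table also drops the old BLOCK_SEQ field. The sketch below is a hypothetical standalone reader, not lightllm's actual config loader: the nearest-bucket selection rule and the function names are assumptions made purely for illustration of how such a table can be queried.

import json
from pathlib import Path

# Hypothetical reader for the autotune tables above -- NOT lightllm's real loader.
# Assumption: outer keys are sequence-length buckets, inner keys are batch sizes,
# and we pick the smallest bucket that covers the requested value (else the largest).

def _nearest_bucket(keys, value: int) -> str:
    buckets = sorted(int(k) for k in keys)
    for b in buckets:
        if value <= b:
            return str(b)
    return str(buckets[-1])

def pick_mla_decode_config(config_path: str, seq_len: int, batch_size: int) -> dict:
    """Return Triton launch params (BLOCK_N, BLOCK_Q_HEAD, stage*_num_warps/stages)."""
    table = json.loads(Path(config_path).read_text())
    seq_key = _nearest_bucket(table.keys(), seq_len)
    batch_key = _nearest_bucket(table[seq_key].keys(), batch_size)
    return table[seq_key][batch_key]

# Example (path elided): a decode step at seq_len=1000, batch_size=16 would map to the
# "1024"/"16" entry of the q_head_num=128 table under the assumed bucket rule.
# cfg = pick_mla_decode_config(".../{...q_head_num=128...}_NVIDIA_H800.json", 1000, 16)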
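
The final hunk guards _calcu_batch_token_load_batch_not_none against an empty cache_len_list: numpy's .max() over a zero-size array raises ValueError, so the estimate is now computed only when at least one request is tracked and falls back to 0 otherwise. Below is a minimal standalone sketch of that guarded computation, assuming (as the existing list comprehensions imply) that each entry is a (has_run_len, left_out_len) pair; it is an illustration of the branch added by the patch, not the class method itself.

import numpy as np

# Standalone sketch of the guarded estimate added in impl.py.
# cache_len_list holds (has_run_len, left_out_len) pairs per running request; with no
# requests, .max() over an empty array would raise ValueError, hence the explicit branch.

def need_max_token_num(cache_len_list):
    if len(cache_len_list) == 0:
        return 0
    # Sort by remaining output length, descending, then estimate the worst-case
    # number of tokens needed if every request runs to completion.
    cache_len_list = sorted(cache_len_list, key=lambda x: -x[1])
    left_out_len_array = np.array([e[1] for e in cache_len_list])
    has_run_len_array = np.array([e[0] for e in cache_len_list])
    cum_run_len_array = np.cumsum(has_run_len_array)
    size_array = np.arange(1, len(cache_len_list) + 1, 1)
    return int((left_out_len_array * size_array + cum_run_len_array).max())

print(need_max_token_num([]))                  # 0 (previously this path raised ValueError)
print(need_max_token_num([(10, 5), (4, 20)]))  # worst-case token demand for two requests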