diff --git a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_A800-SXM4-80GB.json b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_A800-SXM4-80GB.json index 15502895f..9fe18f792 100644 --- a/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_A800-SXM4-80GB.json +++ b/lightllm/common/all_kernel_configs/mla_decode_attentnion/{out_dtype=torch.bfloat16,q_head_dim=512,q_head_num=16,q_rope_dim=64}_NVIDIA_A800-SXM4-80GB.json @@ -1 +1 @@ -{"256": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, 
"BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 2}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 5}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 5}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 5}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 5}}} \ No newline at end of file +{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, 
"16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}, "256": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 6, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}, "256": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 6, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, 
"stage2_num_warps": 1, "stage2_num_stages": 3}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 3}}} \ No newline at end of file diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index f0f383203..081fba6d3 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -18,7 +18,6 @@ from lightllm.common.basemodel.cuda_graph import CudaGraph from lightllm.common.quantization import Quantcfg from lightllm.utils.log_utils import init_logger -from lightllm.common.basemodel.infer_lock import g_infer_state_lock logger = init_logger(__name__) diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py index 90954c60d..697f319fc 100644 --- a/lightllm/models/deepseek2/infer_struct.py +++ b/lightllm/models/deepseek2/infer_struct.py @@ -16,6 +16,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): # 只有 decode 阶段使用 ppl 的优化算子才会有这个管理变量 if not self.is_prefill: self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0) + self.total_token_num_tensor = torch.sum(self.b_seq_len) if self.enable_dp: rank = dist.get_rank() diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index f0d643950..c09acfb30 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -1,9 +1,8 @@ -import os -import json import torch from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo +from lightllm.models.deepseek2.splitfuse_infer_struct import DeepSeekv2SplitFuseInferStateInfo from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.models.llama.model import LlamaTpPartModel @@ -24,6 +23,9 @@ class Deepseek2TpPartModel(LlamaTpPartModel): # infer state class infer_state_class = 
Deepseek2InferStateInfo + # split fuse state class + splitfuse_infer_state_class = DeepSeekv2SplitFuseInferStateInfo + def __init__(self, kvargs): super().__init__(kvargs) return diff --git a/lightllm/models/deepseek2/splitfuse_infer_struct.py b/lightllm/models/deepseek2/splitfuse_infer_struct.py new file mode 100644 index 000000000..2d2ecc71c --- /dev/null +++ b/lightllm/models/deepseek2/splitfuse_infer_struct.py @@ -0,0 +1,35 @@ +import torch +import numpy as np +from lightllm.common.basemodel import SplitFuseInferStateInfo +from .infer_struct import Deepseek2InferStateInfo + + +class DeepSeekv2SplitFuseInferStateInfo(SplitFuseInferStateInfo): + + inner_infer_state_class = Deepseek2InferStateInfo + + def __init__(self): + super().__init__() + self.position_cos = None + self.position_sin = None + + def init_some_extra_state(self, model, input_ids: torch.Tensor): + position_ids = [] + if self.decode_req_num != 0: + position_ids.append((self.decode_b_seq_len - 1).cpu().numpy()) + if self.prefill_req_num != 0: + b_seq_len_numpy = self.prefill_b_seq_len.cpu().numpy() + b_ready_cache_len_numpy = self.prefill_b_split_ready_cache_len.cpu().numpy() + position_ids.extend( + [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))] + ) + + position_ids = torch.from_numpy(np.concatenate(position_ids, axis=0)).cuda().view(-1) + self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) + self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) + return + + def create_inner_decode_infer_status(self): + infer_state = super().create_inner_decode_infer_status() + infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len) + return diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py index c4226905d..8b688ddd3 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py @@ -4,6 +4,7 @@ from typing import List from lightllm.utils.log_utils import init_logger from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig +from lightllm.utils.device_utils import get_device_sm_count logger = init_logger(__name__) @@ -43,21 +44,48 @@ def gqa_token_decode_attention_flash_decoding( out_dtype=torch.bfloat16, ) - BLOCK_SEQ = run_config["BLOCK_SEQ"] - from .gqa_flash_decoding_stage1 import flash_decode_stage1 from .gqa_flash_decoding_stage2 import flash_decode_stage2 o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank], dtype=torch.float32, device="cuda" + mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda") + mid_o_batch_start_index = alloc_tensor_func( + [ + batch_size, + ], + dtype=torch.int64, + device="cuda", ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + + mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda") + mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda") + + vsm_count = flash_decode_stage1( + infer_state.total_token_num_tensor, + mid_o_block_seq, + mid_o_batch_start_index, + q_nope.view(calcu_shape1), + q_rope.view(calcu_shape2), + kv_nope, + kv_rope, + 
infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + mid_o, + mid_o_logexpsum, + softmax_scale, + get_sm_count=True, + **run_config ) + mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda") + mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda") + flash_decode_stage1( + infer_state.total_token_num_tensor, + mid_o_block_seq, + mid_o_batch_start_index, q_nope.view(calcu_shape1), q_rope.view(calcu_shape2), kv_nope, @@ -65,14 +93,20 @@ def gqa_token_decode_attention_flash_decoding( infer_state.req_manager.req_to_token_indexs, infer_state.b_req_idx, infer_state.b_seq_len, - infer_state.max_len_in_batch, mid_o, mid_o_logexpsum, - BLOCK_SEQ, softmax_scale, + get_sm_count=False, **run_config ) + flash_decode_stage2( - mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ, **run_config + mid_o_block_seq, + mid_o_batch_start_index, + mid_o, + mid_o_logexpsum, + infer_state.b_seq_len, + o_tensor.view(calcu_shape1), + **run_config ) return o_tensor diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py index 183680b98..be99ca9bf 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py @@ -38,7 +38,6 @@ def try_to_get_best_config( return config else: config = { - "BLOCK_SEQ": 64, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py index 0c5a252a2..9028c6e85 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py @@ -1,11 +1,11 @@ -import os import torch import triton import triton.language as tl +from lightllm.utils.device_utils import calcu_kernel_best_vsm_count @triton.jit -def _fwd_kernel_flash_decode_stage1( +def _fwd_kernel_flash_decode_stage1_padding( Q_nope, Q_rope, KV_nope, @@ -14,8 +14,8 @@ def _fwd_kernel_flash_decode_stage1( Req_to_tokens, B_req_idx, B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [batch, head, seq_block_num] + Mid_O, # [head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [head, seq_block_num] stride_req_to_tokens_b, stride_req_to_tokens_s, stride_q_bs, @@ -30,223 +30,152 @@ def _fwd_kernel_flash_decode_stage1( stride_kv_rope_bs, stride_kv_rope_h, stride_kv_rope_d, - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, + total_token_ptr, + block_size_ptr, + batch_start_index_ptr, + num_sm, + head_group_num, + head_num, + batch_size, Q_HEAD_NUM: tl.constexpr, - BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_ROPE_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + NUM_STAGES: tl.constexpr, + NEED_HEAD_MASK: tl.constexpr, ): - seq_start_block = tl.program_id(0) - cur_q_head = tl.program_id(1) - cur_batch = tl.program_id(2) + # cur_kv_head = 0 + sm_id = tl.program_id(0).to(tl.int64) + grid_id = sm_id + out_batch_start_index = tl.cast(0, tl.int64) + total_token_num = tl.load(total_token_ptr, eviction_policy="evict_last") - cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) - cur_q_head_range = cur_q_head * Q_HEAD_NUM + cur_q_head_offs + block_seq = 
tl.cast(total_token_num / num_sm / 4, dtype=tl.int32) + 1 + block_seq = tl.cdiv(block_seq, BLOCK_N) * BLOCK_N + if grid_id == 0: + tl.store(block_size_ptr, block_seq) + cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) offs_d = tl.arange(0, BLOCK_DMODEL) offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :] - off_rope_q = cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :] - q = tl.load(Q_nope + off_q) - q_rope = tl.load(Q_rope + off_rope_q) - block_n_size = ( - tl.where( - cur_batch_end_index - cur_batch_start_index <= 0, - 0, - cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, - ) - // BLOCK_N - ) + for cur_batch in range(batch_size): + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch, eviction_policy="evict_last") + cur_block_num = tl.cdiv(cur_batch_seq_len, block_seq) * head_group_num + cur_batch_req_idx = tl.load(B_req_idx + cur_batch, eviction_policy="evict_last") + req_to_tokens_ptr = Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") - acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] - kv = tl.load(KV_nope + off_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) - att_value = tl.dot(q, kv) - off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] - rope_kv = tl.load(KV_rope + off_rope_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) - att_value += tl.dot(q_rope, rope_kv) + while sm_id < cur_block_num: + loop_head_group_index = sm_id % head_group_num + loop_seq_block_index = sm_id // head_group_num - att_value *= sm_scale - att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf")) + cur_q_head_range = loop_head_group_index * Q_HEAD_NUM + cur_q_head_offs + if NEED_HEAD_MASK: + head_mask = cur_q_head_range < head_num - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) + cur_batch_start_index = block_seq * loop_seq_block_index + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + block_seq) - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(kv.dtype), tl.trans(kv)) + off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :] + off_rope_q = ( + cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :] + ) + if NEED_HEAD_MASK: + q = tl.load( + Q_nope + off_q, + mask=head_mask[:, None], + other=0.0, + ) + q_rope = tl.load( + Q_rope + off_rope_q, + mask=head_mask[:, None], + other=0.0, + ) + else: + q = tl.load(Q_nope + off_q) + q_rope = tl.load(Q_rope + off_rope_q) - sum_exp = sum_exp * logic_scale + 
tl.sum(exp_logic, axis=1) - max_logic = new_max_logic + block_n_size = tl.cdiv(cur_batch_end_index - cur_batch_start_index, BLOCK_N) - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = ( - cur_batch * stride_mid_ob - + cur_q_head_range[:, None] * stride_mid_oh - + seq_start_block * stride_mid_os - + offs_d[None, :] - ) - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - ) - return + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) + max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") + acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) + for start_n in tl.range(0, block_n_size, 1, num_stages=NUM_STAGES): + offs_n_new = start_n * BLOCK_N + offs_n + seq_n_mask = offs_n_new < cur_batch_end_index + kv_loc = tl.load( + req_to_tokens_ptr + offs_n_new, + mask=seq_n_mask, + other=0, + ) + off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] + kv = tl.load(KV_nope + off_kv, mask=seq_n_mask[None, :], other=0.0) + att_value = tl.dot(q, kv) + off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] + rope_kv = tl.load(KV_rope + off_rope_kv, mask=seq_n_mask[None, :], other=0.0) + att_value += tl.dot(q_rope, rope_kv) + att_value *= sm_scale + att_value = tl.where(seq_n_mask[None, :], att_value, float("-inf")) -@triton.jit -def _fwd_kernel_flash_decode_stage1_padding( - Q_nope, - Q_rope, - KV_nope, - KV_rope, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - Mid_O, # [batch, head, seq_block_num, head_dim] - Mid_O_LogExpSum, # [batch, head, seq_block_num] - stride_req_to_tokens_b, - stride_req_to_tokens_s, - stride_q_bs, - stride_q_h, - stride_q_d, - stride_q_rope_bs, - stride_q_rope_h, - stride_q_rope_d, - stride_kv_bs, - stride_kv_h, - stride_kv_d, - stride_kv_rope_bs, - stride_kv_rope_h, - stride_kv_rope_d, - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_mid_od, - stride_mid_o_eb, - stride_mid_o_eh, - stride_mid_o_es, - gqa_group_size, - Q_HEAD_NUM: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_ROPE_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): - seq_start_block = tl.program_id(0) - cur_kv_head = tl.program_id(1) - cur_batch = tl.program_id(2) + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) - cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) - cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(kv.dtype), tl.trans(kv)) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_rope_d = tl.arange(0, BLOCK_ROPE_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic - off_q = cur_batch * stride_q_bs + cur_q_head_range[:, None] * stride_q_h + offs_d[None, :] - off_rope_q = cur_batch * stride_q_rope_bs + cur_q_head_range[:, None] * stride_q_rope_h + offs_rope_d[None, :] - q = tl.load(Q_nope + off_q, 
mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0) - q_rope = tl.load( - Q_rope + off_rope_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0 - ) - block_n_size = ( - tl.where( - cur_batch_end_index - cur_batch_start_index <= 0, - 0, - cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, - ) - // BLOCK_N - ) - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - max_logic = tl.zeros([Q_HEAD_NUM], dtype=tl.float32) - float("inf") - acc = tl.zeros([Q_HEAD_NUM, BLOCK_DMODEL], dtype=tl.float32) - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, - other=0, - ) - off_kv = kv_loc[None, :] * stride_kv_bs + offs_d[:, None] - kv = tl.load(KV_nope + off_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) - att_value = tl.dot(q, kv) - off_rope_kv = kv_loc[None, :] * stride_kv_rope_bs + offs_rope_d[:, None] - rope_kv = tl.load(KV_rope + off_rope_kv, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0) - att_value += tl.dot(q_rope, rope_kv) - - att_value *= sm_scale - att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf")) + off_mid_o = ( + cur_q_head_range[:, None] * stride_mid_oh + + (out_batch_start_index + loop_seq_block_index) * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_q_head_range * stride_mid_o_eh + out_batch_start_index + loop_seq_block_index + if NEED_HEAD_MASK: + tl.store( + Mid_O + off_mid_o, + acc / sum_exp[:, None], + mask=head_mask[:, None], + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + max_logic + tl.log(sum_exp), + mask=head_mask, + ) + else: + tl.store( + Mid_O + off_mid_o, + acc / sum_exp[:, None], + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + max_logic + tl.log(sum_exp), + ) + sm_id += num_sm - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) + if grid_id == 0: + tl.store(batch_start_index_ptr + cur_batch, out_batch_start_index) - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(kv.dtype), tl.trans(kv)) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) - max_logic = new_max_logic - - need_store = tl.where(block_n_size == 0, 0, 1) - for _ in range(0, need_store, 1): - off_mid_o = ( - cur_batch * stride_mid_ob - + cur_q_head_range[:, None] * stride_mid_oh - + seq_start_block * stride_mid_os - + offs_d[None, :] - ) - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size, - ) + out_batch_start_index += cur_block_num // head_group_num + sm_id -= cur_block_num return @torch.no_grad() def flash_decode_stage1( + total_token_num_tensor: torch.Tensor, + out_block_seq: torch.Tensor, + batch_start_index: torch.Tensor, q_nope, q_rope, kv_nope, @@ -254,27 +183,18 @@ def flash_decode_stage1( Req_to_tokens, B_req_idx, B_Seqlen, - max_len_in_batch, mid_out, mid_out_logsumexp, - block_seq, softmax_scale, + 
get_sm_count: bool = False, **run_config, ): if run_config: - BLOCK_SEQ = run_config["BLOCK_SEQ"] + Q_HEAD_NUM = run_config["BLOCK_Q_HEAD"] BLOCK_N = run_config["BLOCK_N"] - BLOCK_Q_HEAD = run_config["BLOCK_Q_HEAD"] num_warps = run_config["stage1_num_warps"] num_stages = run_config["stage1_num_stages"] - else: - BLOCK_SEQ = block_seq - BLOCK_N = 16 - BLOCK_Q_HEAD = 16 - num_warps = 4 - num_stages = 2 - assert BLOCK_SEQ % BLOCK_N == 0 # shape constraints q_nope_dim = q_nope.shape[-1] q_rope_dim = q_rope.shape[-1] @@ -283,95 +203,89 @@ def flash_decode_stage1( assert q_rope_dim == kv_rope.shape[-1] assert q_nope_dim in {16, 32, 64, 128, 256, 512} assert q_rope_dim in {16, 32, 64, 128, 256} + assert kv_nope.shape[1] == 1 + + batch_size, q_head_num = B_req_idx.shape[0], q_nope.shape[1] + head_group_num = triton.cdiv(q_head_num, Q_HEAD_NUM) + NEED_HEAD_MASK = (q_head_num % Q_HEAD_NUM) != 0 + + kernel = _fwd_kernel_flash_decode_stage1_padding.warmup( + q_nope, + q_rope, + kv_nope, + kv_rope, + softmax_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + *Req_to_tokens.stride(), + *q_nope.stride(), + *q_rope.stride(), + *kv_nope.stride(), + *kv_rope.stride(), + *mid_out.stride(), + *mid_out_logsumexp.stride(), + total_token_num_tensor, + out_block_seq, + batch_start_index, + num_sm=1, + head_group_num=head_group_num, + head_num=q_head_num, + batch_size=batch_size, + Q_HEAD_NUM=Q_HEAD_NUM, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK_N, + NEED_HEAD_MASK=NEED_HEAD_MASK, + NUM_STAGES=num_stages, + num_warps=num_warps, + num_stages=1, + grid=(1,), + ) + + kernel._init_handles() + num_sm = calcu_kernel_best_vsm_count(kernel, num_warps=num_warps) + grid = (num_sm,) + if get_sm_count: + return num_sm + + assert num_sm * 4 + batch_size <= mid_out.shape[1] + + _fwd_kernel_flash_decode_stage1_padding[grid]( + q_nope, + q_rope, + kv_nope, + kv_rope, + softmax_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + *Req_to_tokens.stride(), + *q_nope.stride(), + *q_rope.stride(), + *kv_nope.stride(), + *kv_rope.stride(), + *mid_out.stride(), + *mid_out_logsumexp.stride(), + total_token_num_tensor, + out_block_seq, + batch_start_index, + num_sm=num_sm, + head_group_num=head_group_num, + head_num=q_head_num, + batch_size=batch_size, + Q_HEAD_NUM=Q_HEAD_NUM, + BLOCK_DMODEL=q_nope_dim, + BLOCK_ROPE_DMODEL=q_rope_dim, + BLOCK_N=BLOCK_N, + NEED_HEAD_MASK=NEED_HEAD_MASK, + NUM_STAGES=num_stages, + num_warps=num_warps, + num_stages=1, + ) - sm_scale = softmax_scale # 计算scale系数 - batch, q_head_num = B_req_idx.shape[0], q_nope.shape[1] - if q_head_num % BLOCK_Q_HEAD == 0: - grid = (triton.cdiv(max_len_in_batch, BLOCK_SEQ), q_head_num // BLOCK_Q_HEAD, batch) - _fwd_kernel_flash_decode_stage1[grid]( - q_nope, - q_rope, - kv_nope, - kv_rope, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - q_nope.stride(0), - q_nope.stride(1), - q_nope.stride(2), - q_rope.stride(0), - q_rope.stride(1), - q_rope.stride(2), - kv_nope.stride(0), - kv_nope.stride(1), - kv_nope.stride(2), - kv_rope.stride(0), - kv_rope.stride(1), - kv_rope.stride(2), - mid_out.stride(0), - mid_out.stride(1), - mid_out.stride(2), - mid_out.stride(3), - mid_out_logsumexp.stride(0), - mid_out_logsumexp.stride(1), - mid_out_logsumexp.stride(2), - Q_HEAD_NUM=BLOCK_Q_HEAD, - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=q_nope_dim, - BLOCK_ROPE_DMODEL=q_rope_dim, - BLOCK_N=BLOCK_N, - 
num_warps=num_warps, - num_stages=num_stages, - ) - else: - assert q_head_num < BLOCK_Q_HEAD - kv_head_num = kv_nope.shape[1] - grid = (triton.cdiv(max_len_in_batch, BLOCK_SEQ), kv_head_num, batch) - gqa_group_size = q_head_num // kv_head_num - _fwd_kernel_flash_decode_stage1_padding[grid]( - q_nope, - q_rope, - kv_nope, - kv_rope, - sm_scale, - Req_to_tokens, - B_req_idx, - B_Seqlen, - mid_out, - mid_out_logsumexp, - Req_to_tokens.stride(0), - Req_to_tokens.stride(1), - q_nope.stride(0), - q_nope.stride(1), - q_nope.stride(2), - q_rope.stride(0), - q_rope.stride(1), - q_rope.stride(2), - kv_nope.stride(0), - kv_nope.stride(1), - kv_nope.stride(2), - kv_rope.stride(0), - kv_rope.stride(1), - kv_rope.stride(2), - mid_out.stride(0), - mid_out.stride(1), - mid_out.stride(2), - mid_out.stride(3), - mid_out_logsumexp.stride(0), - mid_out_logsumexp.stride(1), - mid_out_logsumexp.stride(2), - gqa_group_size, - Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)), - BLOCK_SEQ=BLOCK_SEQ, - BLOCK_DMODEL=q_nope_dim, - BLOCK_ROPE_DMODEL=q_rope_dim, - BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - ) return diff --git a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py index 376dc493b..5b5e7b747 100644 --- a/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py +++ b/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py @@ -6,21 +6,20 @@ @triton.jit def _fwd_kernel_flash_decode_stage2( + block_seq_ptr, + batch_start_index, B_Seqlen, Mid_O, # [batch, head, seq_block_num, head_dim] Mid_O_LogExpSum, # [batch, head, seq_block_num] Out, # [batch, head, head_dim] - stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od, - stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es, stride_obs, stride_oh, stride_od, - BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, NUM_STAGES: tl.constexpr, ): @@ -29,15 +28,16 @@ def _fwd_kernel_flash_decode_stage2( offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_start_index = tl.load(batch_start_index + cur_batch) + block_seq = tl.load(block_seq_ptr) - block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ - + block_n_size = tl.cdiv(cur_batch_seq_len, block_seq) sum_exp = 0.0 max_logic = -float("inf") acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + offs_v = cur_head * stride_mid_oh + cur_batch_start_index * stride_mid_os + offs_d + offs_logic = cur_head * stride_mid_o_eh + cur_batch_start_index for block_seq_n in tl.range(0, block_n_size, 1, num_stages=NUM_STAGES): tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) @@ -55,40 +55,37 @@ def _fwd_kernel_flash_decode_stage2( @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq, **run_config): +def flash_decode_stage2( + out_block_seq: torch.Tensor, + batch_start_index: torch.Tensor, + mid_out, + mid_out_logexpsum, + B_Seqlen, + Out, + **run_config +): if run_config: - BLOCK_SEQ = run_config["BLOCK_SEQ"] num_warps = run_config["stage2_num_warps"] num_stages = run_config["stage2_num_stages"] - else: - BLOCK_SEQ = block_seq - num_warps = 4 - num_stages = 2 Lk = mid_out.shape[-1] assert Lk in {16, 32, 64, 128, 256, 512} - batch, head_num = 
mid_out.shape[0], mid_out.shape[1] + batch, head_num = batch_start_index.shape[0], mid_out.shape[0] grid = (head_num, batch) _fwd_kernel_flash_decode_stage2[grid]( + out_block_seq, + batch_start_index, B_Seqlen, mid_out, mid_out_logexpsum, Out, - mid_out.stride(0), - mid_out.stride(1), - mid_out.stride(2), - mid_out.stride(3), - mid_out_logexpsum.stride(0), - mid_out_logexpsum.stride(1), - mid_out_logexpsum.stride(2), - Out.stride(0), - Out.stride(1), - Out.stride(2), - BLOCK_SEQ=BLOCK_SEQ, + *mid_out.stride(), + *mid_out_logexpsum.stride(), + *Out.stride(), BLOCK_DMODEL=Lk, NUM_STAGES=num_stages, num_warps=num_warps, - num_stages=num_stages, + num_stages=1, ) return diff --git a/lightllm/models/llama/infer_struct.py b/lightllm/models/llama/infer_struct.py index cbd7ad6ec..2f64a3621 100644 --- a/lightllm/models/llama/infer_struct.py +++ b/lightllm/models/llama/infer_struct.py @@ -9,7 +9,6 @@ def __init__(self): super().__init__() self.position_cos = None self.position_sin = None - self.other_kv_index = None def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: @@ -29,6 +28,4 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = self.b_seq_len - 1 self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) - self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() - # b_loc[0, max_len_in_batch - 1].item() return diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index bc8ab44fb..b70b26aaa 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -362,7 +362,6 @@ def _token_decode_attention_normal(self, q, infer_state: LlamaInferStateInfo, la infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.other_kv_index, ) return o_tensor diff --git a/lightllm/models/llama/splitfuse_infer_struct.py b/lightllm/models/llama/splitfuse_infer_struct.py index 45420d0ec..41af22663 100644 --- a/lightllm/models/llama/splitfuse_infer_struct.py +++ b/lightllm/models/llama/splitfuse_infer_struct.py @@ -1,7 +1,6 @@ import torch import numpy as np from lightllm.common.basemodel import SplitFuseInferStateInfo -from lightllm.common.req_manager import ReqManager from .infer_struct import LlamaInferStateInfo @@ -13,7 +12,6 @@ def __init__(self): super().__init__() self.position_cos = None self.position_sin = None - self.other_kv_index = None def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = [] @@ -29,14 +27,4 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = torch.from_numpy(np.concatenate(position_ids, axis=0)).cuda().view(-1) self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) - - if self.decode_req_num != 0: - self.other_kv_index = self.req_manager.req_to_token_indexs[self.decode_b_req_idx[0], 0].item() - elif self.prefill_req_num != 0: - self.other_kv_index = self.req_manager.req_to_token_indexs[self.prefill_b_req_idx[0], 0].item() - return - - def create_inner_decode_infer_status(self): - infer_status = super().create_inner_decode_infer_status() - 
infer_status.other_kv_index = self.other_kv_index return diff --git a/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py index 9c13fcab3..b93964989 100644 --- a/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +++ b/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py @@ -6,13 +6,24 @@ @triton.jit def _fwd_kernel( - Logics, V, Out, - Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - stride_logic_h, stride_logic_bs, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_req_to_token_b, stride_req_to_token_s, - other_kv_index, # 避免读取到nan的数据 + Logics, + V, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + stride_logic_h, + stride_logic_bs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_token_b, + stride_req_to_token_s, + other_kv_index, # 避免读取到nan的数据 kv_group_num, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, @@ -38,13 +49,18 @@ def _fwd_kernel( for start_n in range(0, cur_batch_seq_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b + - (start_n + offs_n) * stride_req_to_token_s, - mask=(start_n + offs_n) < cur_batch_seq_len, other=other_kv_index) + v_index = tl.load( + Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b + (start_n + offs_n) * stride_req_to_token_s, + mask=(start_n + offs_n) < cur_batch_seq_len, + other=other_kv_index, + ) + + qk = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, + mask=start_n + offs_n < cur_batch_seq_len, + other=float("-inf"), + ) - qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=start_n + offs_n < cur_batch_seq_len, other=float("-inf")) - n_e_max = tl.maximum(tl.max(qk, 0), e_max) old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max) @@ -61,7 +77,7 @@ def _fwd_kernel( @torch.no_grad() -def token_softmax_reducev_fwd(logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len, other_kv_index): +def token_softmax_reducev_fwd(logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len): BLOCK = 64 batch, head = b_seq_len.shape[0], logics.shape[0] grid = (batch, head) @@ -69,16 +85,28 @@ def token_softmax_reducev_fwd(logics, v, o, req_to_tokens, b_req_idx, b_start_lo num_warps = 1 _fwd_kernel[grid]( - logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len, - logics.stride(0), logics.stride(1), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - req_to_tokens.stride(0), req_to_tokens.stride(1), - other_kv_index, + logics, + v, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + logics.stride(0), + logics.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_tokens.stride(0), + req_to_tokens.stride(1), + 0, kv_group_num, BLOCK_DMODEL=v.shape[-1], BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=3 + num_stages=3, ) - return \ No newline at end of file + return diff --git a/lightllm/models/mistral/infer_struct.py b/lightllm/models/mistral/infer_struct.py index 81a70bc9a..fe8a91b55 100644 --- a/lightllm/models/mistral/infer_struct.py +++ b/lightllm/models/mistral/infer_struct.py @@ -1,7 +1,6 @@ import torch import numpy as np from lightllm.models.llama.infer_struct import 
LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd @@ -32,8 +31,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = self.b_seq_len - 1 self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) - self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() - # b_loc[0, max_len_in_batch - 1].item() # [SYM] still reserve all kv cache self.b_att_seq_len = torch.zeros_like(self.b_seq_len) diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py index e3e011f25..e84f8b0aa 100755 --- a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py @@ -1,21 +1,10 @@ import torch -import torch.functional as F -import torch.distributed as dist -import numpy as np -from typing import Tuple -import triton -from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer from lightllm.models.mistral.infer_struct import MistralInferStateInfo from lightllm.models.mistral.triton_kernel.context_flashattention_nopad import context_attention_fwd from lightllm.models.mistral.triton_kernel.token_attention_nopad_att1 import token_att_fwd -from lightllm.models.mistral.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2 -from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd -from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd - -from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv class MistralTransformerLayerInfer(LlamaTransformerLayerInfer): @@ -87,7 +76,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, infer_state.b_seq_len, infer_state.b_att_start_loc, infer_state.b_att_seq_len, - infer_state.other_kv_index, infer_state.sliding_window, ) return o_tensor diff --git a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py index c37013f18..bf9928f98 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py +++ b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py @@ -1,7 +1,6 @@ import torch import triton import triton.language as tl -import torch.nn.functional as F @triton.jit @@ -94,7 +93,6 @@ def token_softmax_reducev_fwd( b_seq_len, b_att_start_loc, b_att_seq_len, - other_kv_index, sliding_window, ): BLOCK = 64 @@ -123,7 +121,7 @@ def token_softmax_reducev_fwd( o.stride(2), req_to_tokens.stride(0), req_to_tokens.stride(1), - other_kv_index, + 0, kv_group_num, sliding_window, BLOCK_DMODEL=v.shape[-1], diff --git a/lightllm/models/mixtral/infer_struct.py b/lightllm/models/mixtral/infer_struct.py index 426b28c5a..cfb5dcaf4 100644 --- a/lightllm/models/mixtral/infer_struct.py +++ b/lightllm/models/mixtral/infer_struct.py @@ -1,7 +1,5 @@ import torch import numpy as np -from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager from 
lightllm.models.mistral.infer_struct import MistralInferStateInfo from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd from lightllm.utils.log_utils import init_logger @@ -36,8 +34,6 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = self.b_seq_len - 1 self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) - self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() - # b_loc[0, max_len_in_batch - 1].item() # [SYM] still reserve all kv cache self.b_att_seq_len = torch.zeros_like(self.b_seq_len) diff --git a/lightllm/models/qwen/infer_struct.py b/lightllm/models/qwen/infer_struct.py index bd61d1326..5e7d0200b 100644 --- a/lightllm/models/qwen/infer_struct.py +++ b/lightllm/models/qwen/infer_struct.py @@ -2,31 +2,44 @@ import numpy as np from lightllm.models.llama.infer_struct import LlamaInferStateInfo + class QwenInferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() self.position_cos = None self.position_sin = None - self.other_kv_index = None self.logn_values = None - def init_some_extra_state(self, model, input_ids : torch.Tensor): + def init_some_extra_state(self, model, input_ids: torch.Tensor): use_dynamic_ntk = model.config.get("use_dynamic_ntk", False) if not use_dynamic_ntk: super().init_some_extra_state(model, input_ids) return - + if self.is_prefill: - b_start_loc_numpy = self.b_start_loc.cpu().numpy() + b_start_loc_numpy = self.b_start_loc.cpu().numpy() b_seq_len_numpy = self.b_seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + position_ids = torch.from_numpy( + np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0) + ).cuda() self.position_sin = [] self.position_cos = [] - infer_ntk_id = torch.clamp(torch.ceil(torch.log2(self.b_seq_len / model.config.get("seq_length", 2048)) + 1), 0, model.max_ntk_alpha).long() + infer_ntk_id = torch.clamp( + torch.ceil(torch.log2(self.b_seq_len / model.config.get("seq_length", 2048)) + 1), + 0, + model.max_ntk_alpha, + ).long() for i in range(len(infer_ntk_id)): - self.position_sin.append(model._sin_cached[infer_ntk_id[i]][position_ids[b_start_loc_numpy[i]: b_start_loc_numpy[i] + b_seq_len_numpy[i]]]) - self.position_cos.append(model._cos_cached[infer_ntk_id[i]][position_ids[b_start_loc_numpy[i]: b_start_loc_numpy[i] + b_seq_len_numpy[i]]]) + self.position_sin.append( + model._sin_cached[infer_ntk_id[i]][ + position_ids[b_start_loc_numpy[i] : b_start_loc_numpy[i] + b_seq_len_numpy[i]] + ] + ) + self.position_cos.append( + model._cos_cached[infer_ntk_id[i]][ + position_ids[b_start_loc_numpy[i] : b_start_loc_numpy[i] + b_seq_len_numpy[i]] + ] + ) self.position_sin = torch.cat(self.position_sin, dim=0) self.position_cos = torch.cat(self.position_cos, dim=0) @@ -34,12 +47,15 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor): self.logn_values = torch.index_select(model.logn_tensor, 0, position_ids).view(-1) position_ids = None else: - infer_ntk_id = torch.clamp(torch.ceil(torch.log2(self.b_seq_len / model.config.get("seq_length", 2048)) + 1), 0, model.max_ntk_alpha).long() + infer_ntk_id = torch.clamp( + torch.ceil(torch.log2(self.b_seq_len / model.config.get("seq_length", 2048)) + 1), + 
0, + model.max_ntk_alpha, + ).long() position_ids = (self.b_seq_len - 1).long() self.position_cos = model._cos_cached[infer_ntk_id, position_ids].view(position_ids.shape[0], -1) self.position_sin = model._sin_cached[infer_ntk_id, position_ids].view(position_ids.shape[0], -1) if model.logn_tensor is not None: self.logn_values = torch.index_select(model.logn_tensor, 0, position_ids).view(-1) - self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() position_ids = None return diff --git a/lightllm/models/starcoder/infer_struct.py b/lightllm/models/starcoder/infer_struct.py index 613dd3b16..3b8fd9fff 100644 --- a/lightllm/models/starcoder/infer_struct.py +++ b/lightllm/models/starcoder/infer_struct.py @@ -20,5 +20,4 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): ).cuda() else: self.position_ids = self.b_seq_len - 1 - self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() return diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index f4e65e851..5442ec2b5 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -111,7 +111,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, infer_state.b_seq_len, infer_state.b_att_start_loc, infer_state.b_att_seq_len, - infer_state.other_kv_index, infer_state.sliding_window, ) return o_tensor diff --git a/lightllm/models/vit/infer_struct.py b/lightllm/models/vit/infer_struct.py index 35a8e68bc..4a22f7f12 100644 --- a/lightllm/models/vit/infer_struct.py +++ b/lightllm/models/vit/infer_struct.py @@ -9,7 +9,6 @@ def __init__(self): super().__init__() self.position_cos = None self.position_sin = None - self.other_kv_index = None def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: @@ -28,5 +27,4 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): position_ids = self.b_seq_len - 1 self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) - self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() return diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 51cc10174..b9d4b3abf 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -55,6 +55,21 @@ def get_device_warp_size(): return properties["warpSize"] +def calcu_kernel_best_vsm_count(kernel, num_warps): + n_regs = kernel.n_regs + size_smem = kernel.metadata.shared + + sm_count = get_device_sm_count() + max_regs = get_device_sm_regs_num() + shared_mem_max = get_device_sm_shared_mem_num() + warp_size = get_device_warp_size() + + occupancy = max_regs // (n_regs * warp_size * num_warps) + occupancy = min(occupancy, shared_mem_max // size_smem) + num_sm = sm_count * occupancy + return num_sm + + @lru_cache(maxsize=None) def get_current_device_name(): import torch diff --git a/test/kernel/deepseekv2_gqa_decode_tuning.py b/test/kernel/deepseekv2_gqa_decode_tuning.py index b39d23904..ed08c4ce9 100644 --- a/test/kernel/deepseekv2_gqa_decode_tuning.py +++ b/test/kernel/deepseekv2_gqa_decode_tuning.py @@ -5,6 +5,7 @@ from typing import List from lightllm.utils.log_utils import init_logger from 
lightllm.models.deepseek2.triton_kernel.gqa_flash_decoding import gqa_token_decode_attention_flash_decoding +from lightllm.utils.watchdog_utils import Watchdog logger = init_logger(__name__) @@ -49,6 +50,7 @@ def test_decode_attentions( ).cuda() infer_state.b_req_idx = torch.arange(0, infer_state.batch_size, step=1, dtype=torch.int32).cuda() infer_state.b_seq_len = torch.full((infer_state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda() + infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len) input_tuples = [] for _ in range(test_count): @@ -137,6 +139,9 @@ def worker( test_configs, queue, ): + dog = Watchdog(timeout=10) + dog.start() + try: for index in range(len(test_configs)): tuning_config = test_configs[index] @@ -150,9 +155,10 @@ def worker( test_count=test_count, **tuning_config, ) + dog.heartbeat() queue.put(cost_time) # Put result in queue except Exception as ex: - logger.error(str(ex)) + logger.error(str(ex) + f"config {tuning_config}") import sys sys.exit(-1) @@ -161,28 +167,41 @@ def worker( def get_test_configs(split_id, split_count): index = 0 - for block_seq in [32, 64, 128, 256]: - for block_n in [16, 32, 64, 128, 256]: - for block_q_head in [16, 32, 64]: - for stage1_num_warps in [1, 2, 4, 8, 16]: - for stage1_num_stages in [1, 2, 3, 4, 5]: - for stage2_num_warps in [1, 2, 4, 8, 16]: - for stage2_num_stages in [1, 2, 3, 4, 5]: - if block_seq % block_n == 0: - t_config = { - "BLOCK_SEQ": block_seq, - "BLOCK_N": block_n, - "BLOCK_Q_HEAD": block_q_head, - "stage1_num_warps": stage1_num_warps, - "stage1_num_stages": stage1_num_stages, - "stage2_num_warps": stage2_num_warps, - "stage2_num_stages": stage2_num_stages, - } - if index % split_count == split_id: - yield t_config - index += 1 - else: - index += 1 + for block_n in [16, 32]: + for block_q_head in [ + 16, + ]: + for stage1_num_warps in [2, 4, 8, 16]: + for stage1_num_stages in [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 12, + 15, + ]: + for stage2_num_warps in [1, 2, 4]: + for stage2_num_stages in [ + 1, + 3, + ]: + t_config = { + "BLOCK_N": block_n, + "BLOCK_Q_HEAD": block_q_head, + "stage1_num_warps": stage1_num_warps, + "stage1_num_stages": stage1_num_stages, + "stage2_num_warps": stage2_num_warps, + "stage2_num_stages": stage2_num_stages, + } + if index % split_count == split_id: + yield t_config + index += 1 + else: + index += 1 def tuning_configs( @@ -233,7 +252,7 @@ def tuning_configs( del test_configs[0:1] except: logger.info(f"cur best {best_config}, {best_cost_time}") - del test_configs[0:16] + del test_configs[0:1] break while len(test_configs) != 0: @@ -265,7 +284,7 @@ def tuning_configs( del test_configs[0:1] except: logger.info(f"cur best {best_config}, {best_cost_time}") - del test_configs[0:16] + del test_configs[0:1] break logger.info(f"{best_config} best cost: {best_cost_time}") @@ -298,7 +317,7 @@ def tuning_configs( "kv_rope_shape": [None, 1, q_rope_dim], "test_seq_len": seq_len, "dtype": torch.bfloat16, - "test_count": 20, + "test_count": 40, }, ) store_json_ans[seq_len][batch_size] = ans
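
The grid sizing for the rewritten stage-1 kernel comes from the new calcu_kernel_best_vsm_count helper in lightllm/utils/device_utils.py: the compiled kernel's register and shared-memory footprint bounds how many program instances can be resident on one SM, and the launch grid is sized to sm_count * occupancy persistent programs. The sketch below restates that arithmetic as a standalone function; the device numbers (108 SMs, 65536 registers and 164 KB of shared memory per SM, warp size 32) are illustrative A100/A800-class values standing in for the runtime queries, not values taken from the patch.

# Illustrative sketch of the occupancy bound computed by calcu_kernel_best_vsm_count.
# The device numbers below are assumed; the patch reads the real ones via
# get_device_sm_count(), get_device_sm_regs_num(), get_device_sm_shared_mem_num(),
# and get_device_warp_size().
def best_vsm_count_sketch(
    n_regs,                 # registers per thread of the compiled kernel (kernel.n_regs)
    size_smem,              # shared memory per program in bytes (kernel.metadata.shared)
    num_warps,
    sm_count=108,           # assumed: SMs on the device
    max_regs=65536,         # assumed: registers per SM
    shared_mem_max=167936,  # assumed: shared memory per SM in bytes (164 KB)
    warp_size=32,
):
    # programs per SM, limited by the register file ...
    occupancy = max_regs // (n_regs * warp_size * num_warps)
    # ... and by shared memory
    occupancy = min(occupancy, shared_mem_max // size_smem)
    # total resident programs; this becomes the stage-1 launch grid
    return sm_count * occupancy


# e.g. 80 registers/thread, 50 KB shared memory, 4 warps -> 108 * min(6, 3) = 324 programs
print(best_vsm_count_sketch(n_regs=80, size_smem=50 * 1024, num_warps=4))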
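
Instead of a tuned BLOCK_SEQ, stage 1 now derives its sequence block size on device: block_seq is roughly total_token_num / (num_sm * 4), rounded up to a multiple of BLOCK_N, so the whole batch decomposes into at most num_sm * 4 + batch_size sequence blocks. That bound is what lets gqa_flash_decoding.py allocate mid_o and mid_o_logexpsum with vsm_count * 4 + batch_size rows and assert the same limit inside flash_decode_stage1. Below is a host-side restatement of that sizing; the helper name and inputs are made up for illustration, and per-head-group bookkeeping is omitted, which matches the shipped configs where BLOCK_Q_HEAD equals q_head_num = 16.

import math

# Hypothetical host-side sketch of the device-side sizing done in
# _fwd_kernel_flash_decode_stage1_padding (not part of the patch).
def plan_mid_buffer_rows(b_seq_len, num_sm, block_n):
    total_token_num = sum(b_seq_len)
    # block_seq = int(total / num_sm / 4) + 1, rounded up to a multiple of BLOCK_N
    block_seq = int(total_token_num / num_sm / 4) + 1
    block_seq = math.ceil(block_seq / block_n) * block_n
    # each request contributes ceil(seq_len / block_seq) partial-result rows
    rows_needed = sum(math.ceil(s / block_seq) for s in b_seq_len)
    return block_seq, rows_needed


# made-up example: 64 requests on a 108-SM part, BLOCK_N = 32
b_seq_len = [1024] * 32 + [4096] * 32
block_seq, rows = plan_mid_buffer_rows(b_seq_len, num_sm=108, block_n=32)
# ceil(s / block_seq) <= s / block_seq + 1, and block_seq >= total / (num_sm * 4), so
# rows <= num_sm * 4 + batch_size, the row count allocated for mid_o / mid_o_logexpsum.
assert rows <= 108 * 4 + len(b_seq_len)
print(block_seq, rows)  # 384 448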
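
Stage 2 still combines the per-block partial results with a running log-sum-exp; what changed is that each (head, batch) program now walks a variable number of blocks, cdiv(seq_len, block_seq) starting at batch_start_index, rather than a fixed max_len_in_batch // BLOCK_SEQ grid. The NumPy sketch below mirrors that merge loop and checks it against a direct softmax. It assumes, as stage 1 stores them, that each mid_o row holds the block-local softmax average of V (acc / sum_exp) and each mid_o_logexpsum entry holds max_logic + log(sum_exp); the scores and values are made up.

import numpy as np

# Hypothetical standalone sketch mirroring _fwd_kernel_flash_decode_stage2
# (not part of the patch).
def merge_blocks(mid_o, mid_o_lse):
    acc = np.zeros(mid_o.shape[1])
    sum_exp, max_logic = 0.0, -np.inf
    for tv, tlogic in zip(mid_o, mid_o_lse):
        new_max = max(tlogic, max_logic)
        scale = np.exp(max_logic - new_max)   # rescale what was accumulated so far
        exp_logic = np.exp(tlogic - new_max)  # weight of this block
        acc = acc * scale + exp_logic * tv
        sum_exp = sum_exp * scale + exp_logic
        max_logic = new_max
    return acc / sum_exp


# check against a direct softmax over made-up scores/values for one head
rng = np.random.default_rng(0)
scores, V = rng.normal(size=96), rng.normal(size=(96, 8))
weights = np.exp(scores - scores.max())
ref = (weights / weights.sum()) @ V

block = 32
mid_o, mid_lse = [], []
for s in range(0, 96, block):
    sc = scores[s:s + block]
    e = np.exp(sc - sc.max())
    mid_o.append((e @ V[s:s + block]) / e.sum())  # acc / sum_exp, as stage 1 stores it
    mid_lse.append(sc.max() + np.log(e.sum()))    # max_logic + log(sum_exp)

assert np.allclose(merge_blocks(np.array(mid_o), np.array(mid_lse)), ref)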