Skip to content

Commit

Permalink
better cuda graph mla decode kernel. (#694)
Browse files Browse the repository at this point in the history
Co-authored-by: wangzaijun <[email protected]>
  • Loading branch information
hiworldwzj and wangzaijun authored Jan 3, 2025
1 parent d5edea4 commit e2a39e4
Show file tree
Hide file tree
Showing 23 changed files with 448 additions and 430 deletions.
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"256": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 2}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 5}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 5}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 5}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 5}}}
{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}, "256": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 6, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}, "256": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 6, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 3}}}
1 change: 0 additions & 1 deletion lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from lightllm.common.basemodel.cuda_graph import CudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.utils.log_utils import init_logger
from lightllm.common.basemodel.infer_lock import g_infer_state_lock

logger = init_logger(__name__)

Expand Down
1 change: 1 addition & 0 deletions lightllm/models/deepseek2/infer_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
# 只有 decode 阶段使用 ppl 的优化算子才会有这个管理变量
if not self.is_prefill:
self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
self.total_token_num_tensor = torch.sum(self.b_seq_len)

if self.enable_dp:
rank = dist.get_rank()
Expand Down
6 changes: 4 additions & 2 deletions lightllm/models/deepseek2/model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
import json
import torch
from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.models.deepseek2.splitfuse_infer_struct import DeepSeekv2SplitFuseInferStateInfo
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights

from lightllm.models.llama.model import LlamaTpPartModel
Expand All @@ -24,6 +23,9 @@ class Deepseek2TpPartModel(LlamaTpPartModel):
# infer state class
infer_state_class = Deepseek2InferStateInfo

# split fuse state class
splitfuse_infer_state_class = DeepSeekv2SplitFuseInferStateInfo

def __init__(self, kvargs):
super().__init__(kvargs)
return
Expand Down
35 changes: 35 additions & 0 deletions lightllm/models/deepseek2/splitfuse_infer_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch
import numpy as np
from lightllm.common.basemodel import SplitFuseInferStateInfo
from .infer_struct import Deepseek2InferStateInfo


class DeepSeekv2SplitFuseInferStateInfo(SplitFuseInferStateInfo):

inner_infer_state_class = Deepseek2InferStateInfo

def __init__(self):
super().__init__()
self.position_cos = None
self.position_sin = None

def init_some_extra_state(self, model, input_ids: torch.Tensor):
position_ids = []
if self.decode_req_num != 0:
position_ids.append((self.decode_b_seq_len - 1).cpu().numpy())
if self.prefill_req_num != 0:
b_seq_len_numpy = self.prefill_b_seq_len.cpu().numpy()
b_ready_cache_len_numpy = self.prefill_b_split_ready_cache_len.cpu().numpy()
position_ids.extend(
[np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))]
)

position_ids = torch.from_numpy(np.concatenate(position_ids, axis=0)).cuda().view(-1)
self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
return

def create_inner_decode_infer_status(self):
infer_state = super().create_inner_decode_infer_status()
infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len)
return
52 changes: 43 additions & 9 deletions lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List
from lightllm.utils.log_utils import init_logger
from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig
from lightllm.utils.device_utils import get_device_sm_count

logger = init_logger(__name__)

Expand Down Expand Up @@ -43,36 +44,69 @@ def gqa_token_decode_attention_flash_decoding(
out_dtype=torch.bfloat16,
)

BLOCK_SEQ = run_config["BLOCK_SEQ"]

from .gqa_flash_decoding_stage1 import flash_decode_stage1
from .gqa_flash_decoding_stage2 import flash_decode_stage2

o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out

mid_o = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank], dtype=torch.float32, device="cuda"
mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda")
mid_o_batch_start_index = alloc_tensor_func(
[
batch_size,
],
dtype=torch.int64,
device="cuda",
)
mid_o_logexpsum = alloc_tensor_func(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"

mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda")
mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda")

vsm_count = flash_decode_stage1(
infer_state.total_token_num_tensor,
mid_o_block_seq,
mid_o_batch_start_index,
q_nope.view(calcu_shape1),
q_rope.view(calcu_shape2),
kv_nope,
kv_rope,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
mid_o,
mid_o_logexpsum,
softmax_scale,
get_sm_count=True,
**run_config
)

mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda")
mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda")

flash_decode_stage1(
infer_state.total_token_num_tensor,
mid_o_block_seq,
mid_o_batch_start_index,
q_nope.view(calcu_shape1),
q_rope.view(calcu_shape2),
kv_nope,
kv_rope,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ,
softmax_scale,
get_sm_count=False,
**run_config
)

flash_decode_stage2(
mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ, **run_config
mid_o_block_seq,
mid_o_batch_start_index,
mid_o,
mid_o_logexpsum,
infer_state.b_seq_len,
o_tensor.view(calcu_shape1),
**run_config
)
return o_tensor
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def try_to_get_best_config(
return config
else:
config = {
"BLOCK_SEQ": 64,
"BLOCK_N": 16,
"BLOCK_Q_HEAD": 16,
"stage1_num_warps": 4,
Expand Down
Loading

0 comments on commit e2a39e4

Please sign in to comment.