feat: gqa flash decode able to run
PannenetsF committed Jan 7, 2025
1 parent 6ede09e commit f7cd225
Showing 5 changed files with 675 additions and 1 deletion.
22 changes: 22 additions & 0 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -96,6 +96,11 @@ def _bind_attention(self):
                LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding, self
            )
            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
        elif "triton_gqa_flashdecoding_vsm" in self.mode:
            self._token_attention_kernel = partial(
                LlamaTransformerLayerInfer._token_decode_attention_gqa_flashdecoding_vsm, self
            )
            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
        else:
            self._token_attention_kernel = partial(LlamaTransformerLayerInfer._token_decode_attention_normal, self)
            self._copy_kv_to_mem_cache = partial(LlamaTransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
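
For context, the branch above follows the existing dispatch pattern in `_bind_attention`: each mode string in `self.mode` selects a token-decode attention kernel and a KV-cache copy routine, bound with `functools.partial` so the unbound class method is frozen onto the current layer instance. Below is a minimal standalone sketch of that pattern with toy names (not lightllm's actual classes); if the usual launch flow applies, the new kernel would be requested by passing this mode string when starting the server (e.g. something like `--mode triton_gqa_flashdecoding_vsm`), though that flag usage is an assumption here, not part of this commit.

from functools import partial

class ToyLayerInfer:
    """Toy stand-in illustrating the mode-based kernel binding pattern."""

    def __init__(self, mode):
        self.mode = mode  # e.g. ["triton_gqa_flashdecoding_vsm"]
        self._bind_attention()

    def _attn_normal(self, q):
        return f"normal({q})"

    def _attn_gqa_flashdecoding_vsm(self, q):
        return f"gqa_flashdecoding_vsm({q})"

    def _bind_attention(self):
        # Same pattern as the diff: pick the kernel by mode membership,
        # then freeze `self` as the first argument with functools.partial.
        if "triton_gqa_flashdecoding_vsm" in self.mode:
            self._token_attention_kernel = partial(ToyLayerInfer._attn_gqa_flashdecoding_vsm, self)
        else:
            self._token_attention_kernel = partial(ToyLayerInfer._attn_normal, self)

layer = ToyLayerInfer(["triton_gqa_flashdecoding_vsm"])
print(layer._token_attention_kernel("q"))  # -> gqa_flashdecoding_vsm(q)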
@@ -587,3 +592,20 @@ def _token_decode_attention_ppl_int4kv_flashdecoding(
            out=out,
            alloc_tensor_func=self.alloc_tensor,
        )

    def _token_decode_attention_gqa_flashdecoding_vsm(
        self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None
    ):
        from lightllm.models.llama.triton_kernel.gqa_flash_decoding_vsm import (
            gqa_token_decode_attention_flash_decoding_vsm,
        )

        cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
        cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
            :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
        ]
        q_shape = (infer_state.batch_size, self.tp_q_head_num_, self.head_dim_)
        return gqa_token_decode_attention_flash_decoding_vsm(
            q.view(q_shape),
            cache_k,
            cache_v,
            infer_state,
            out=out,
            alloc_tensor_func=self.alloc_tensor,
        )
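
The two `kv_buffer` slices above rely on the per-layer cache storing the K heads and the V heads stacked along dimension 1, i.e. a tensor of shape (num_cache_tokens, tp_k_head_num_ + tp_v_head_num_, head_dim_), with K occupying the first tp_k_head_num_ slots. A small sketch of that slicing, with shapes chosen purely for illustration:

import torch

# Hypothetical sizes, for illustration only.
num_cache_tokens, k_heads, v_heads, head_dim = 1024, 8, 8, 128

# One layer's KV buffer: K heads and V heads stacked on dim 1.
kv_buffer = torch.empty(num_cache_tokens, k_heads + v_heads, head_dim)

cache_k = kv_buffer[:, 0:k_heads, :]                     # (1024, 8, 128)
cache_v = kv_buffer[:, k_heads : k_heads + v_heads, :]   # (1024, 8, 128)
print(cache_k.shape, cache_v.shape)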