From d5edea49a73b3b1bdc55235684f96769f78f4949 Mon Sep 17 00:00:00 2001 From: shihaobai <42648726+shihaobai@users.noreply.github.com> Date: Thu, 2 Jan 2025 18:56:24 +0800 Subject: [PATCH] add custom allgather(into tensor) (#692) Co-authored-by: baishihao Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> --- lightllm/common/basemodel/cuda_graph.py | 4 +- lightllm/common/vllm_kernel/_custom_ops.py | 10 + lightllm/common/vllm_kernel/_ops.py | 27 ++ lightllm/distributed/communication_op.py | 109 +++++--- lightllm/distributed/custom_all_gather.py | 249 ++++++++++++++++++ .../layer_infer/transformer_layer_infer.py | 33 ++- .../model_infer/mode_backend/base_backend.py | 5 +- test/model/model_infer_vit.py | 4 +- .../model/test_settings/model_infer_batchs.py | 5 +- 9 files changed, 390 insertions(+), 56 deletions(-) create mode 100644 lightllm/distributed/custom_all_gather.py diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 5163a260d..85475c3ee 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -1,7 +1,7 @@ import os import torch from lightllm.utils.log_utils import init_logger -from lightllm.distributed import lightllm_capture_graph +from lightllm.distributed import custom_comm_ops logger = init_logger(__name__) @@ -31,7 +31,7 @@ def capture_decode(self, decode_func, input_ids, infer_state): torch.cuda.synchronize() decode_func(input_ids, infer_state) torch.cuda.synchronize() - with lightllm_capture_graph(): + with custom_comm_ops.lightllm_capture_graph(): with torch.cuda.graph(graph_obj, pool=self.mempool): predict_logics = decode_func(input_ids, infer_state) self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics) diff --git a/lightllm/common/vllm_kernel/_custom_ops.py b/lightllm/common/vllm_kernel/_custom_ops.py index 5a0fbbab9..7de42c3c6 100644 --- a/lightllm/common/vllm_kernel/_custom_ops.py +++ b/lightllm/common/vllm_kernel/_custom_ops.py @@ -49,3 +49,13 @@ def select_experts( except ImportError: logger.error("vllm or lightllm_kernel is not installed, you can't use custom ops") + +try: + from lightllm.common.vllm_kernel._ops import init_custom_gather_ar + from lightllm.common.vllm_kernel._ops import all_gather + from lightllm.common.vllm_kernel._ops import allgather_dispose + from lightllm.common.vllm_kernel._ops import allgather_register_buffer + from lightllm.common.vllm_kernel._ops import allgather_get_graph_buffer_ipc_meta + from lightllm.common.vllm_kernel._ops import allgather_register_graph_buffers +except ImportError: + logger.error("lightllm_kernel is not installed, you can't use custom allgather") diff --git a/lightllm/common/vllm_kernel/_ops.py b/lightllm/common/vllm_kernel/_ops.py index 38cb850c6..d8c8bd4b6 100644 --- a/lightllm/common/vllm_kernel/_ops.py +++ b/lightllm/common/vllm_kernel/_ops.py @@ -1069,6 +1069,33 @@ def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List torch.ops.vllm_total_custom_ar.register_graph_buffers(fa, handles, offsets) +# custom ar +def init_custom_gather_ar( + ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor, rank: int, full_nvlink: bool +) -> int: + return torch.ops.vllm_total_custom_ar.init_custom_gather_ar(ipc_tensors, rank_data, rank, full_nvlink) + + +def all_gather(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, reg_buffer_sz_bytes: int) -> None: + torch.ops.vllm_total_custom_ar.all_gather(fa, inp, out, reg_buffer, reg_buffer_sz_bytes) + + +def allgather_dispose(fa: int) -> None: + torch.ops.vllm_total_custom_ar.allgather_dispose(fa) + + +def allgather_register_buffer(fa: int, ipc_tensors: List[int]) -> None: + return torch.ops.vllm_total_custom_ar.allgather_register_buffer(fa, ipc_tensors) + + +def allgather_get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]: + return torch.ops.vllm_total_custom_ar.allgather_get_graph_buffer_ipc_meta(fa) + + +def allgather_register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None: + torch.ops.vllm_total_custom_ar.allgather_register_graph_buffers(fa, handles, offsets) + + # temporary fix for https://github.com/vllm-project/vllm/issues/5456 # TODO: remove this in v0.6.0 names_and_values = globals() diff --git a/lightllm/distributed/communication_op.py b/lightllm/distributed/communication_op.py index 8683bc498..a6c584fd7 100644 --- a/lightllm/distributed/communication_op.py +++ b/lightllm/distributed/communication_op.py @@ -27,6 +27,8 @@ from functools import partial original_all_reduce = torch.distributed.all_reduce +original_all_gather_into_tensor = torch.distributed.all_gather_into_tensor + from contextlib import nullcontext, contextmanager logger = init_logger(__name__) @@ -40,43 +42,80 @@ HAS_VLLM = False logger.info("vllm or lightllm_kernel is not installed, you can't use custom allreduce") -vllm_reduce = None +try: + HAS_LIGHTLLM_KERNEL = True + from .custom_all_gather import CustomAllgather + logger.info("using custom allgather") +except: + HAS_LIGHTLLM_KERNEL = False + logger.info("lightllm_kernel is not installed, you can't use custom allgather") -@contextmanager -def lightllm_capture_graph(): - if vllm_reduce is not None: - with vllm_reduce.capture(): - yield - else: - yield - pass - - -def set_custom_reduce(): - global vllm_reduce - global device_group - ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [ - "ON", - "TRUE", - "1", - ] - world_size = dist.get_world_size() - ranks = list(range(world_size)) - # new_group prevent stuck of torch origin all_reduce with cudagraph - device_group = torch.distributed.new_group(ranks, backend="nccl") - if ENABLE_VLLM_REDUCE and HAS_VLLM: - cpu_group = torch.distributed.new_group(ranks, backend="gloo") - vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device()) - logger.info("Enable VLLM ALLReduce.") - - def _all_reduce_closure(input_, op=ReduceOp.SUM, group=device_group, async_op=False): - if op != ReduceOp.SUM or async_op: - original_all_reduce(input_, op, group, async_op) + +class CustomCommunicationOp: + def __init__(self): + self.vllm_reduce = None + self.custom_gather = None + self.device_group = None + + @contextmanager + def lightllm_capture_graph(self): + if self.vllm_reduce is not None: + with self.vllm_reduce.capture(): + if self.custom_gather is not None: + with self.custom_gather.capture(): + yield + else: + yield else: - if vllm_reduce is not None and vllm_reduce.should_custom_ar(input_): - input_.data = vllm_reduce.custom_all_reduce(input_) - else: + yield + + def set_custom_reduce(self): + ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in ["ON", "TRUE", "1"] + world_size = dist.get_world_size() + ranks = list(range(world_size)) + + # 创建新的 NCCL 组以防止原始 all_reduce 与 cudagraph 卡住 + if self.device_group is None: + self.device_group = dist.new_group(ranks, backend="nccl") + + if ENABLE_VLLM_REDUCE and HAS_VLLM: + cpu_group = dist.new_group(ranks, backend="gloo") + self.vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device()) + logger.info("Enable VLLM ALLReduce.") + + def _all_reduce_closure(input_, op=ReduceOp.SUM, group=self.device_group, async_op=False): + if op != ReduceOp.SUM or async_op: original_all_reduce(input_, op, group, async_op) + else: + if self.vllm_reduce is not None and self.vllm_reduce.should_custom_ar(input_): + input_.data = self.vllm_reduce.custom_all_reduce(input_) + else: + original_all_reduce(input_, op, group, async_op) + + dist.all_reduce = _all_reduce_closure + + def set_custom_gather(self): + ENABLE_CUSTOM_GATHER = os.getenv("ENABLE_CUSTOM_GATHER", "False").upper() in ["ON", "TRUE", "1"] + world_size = dist.get_world_size() + ranks = list(range(world_size)) + if self.device_group is None: + self.device_group = dist.new_group(ranks, backend="nccl") + if ENABLE_CUSTOM_GATHER and HAS_LIGHTLLM_KERNEL: + cpu_group = dist.new_group(ranks, backend="gloo") + self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device()) + logger.info("Enable Custom ALLGather.") + + def _all_gather_closure(output_, input_, group=self.device_group, async_op=False): + if async_op: + original_all_gather_into_tensor(output_, input_, group, async_op) + else: + if self.custom_gather is not None and self.custom_gather.should_custom_ar(input_): + self.custom_gather.custom_all_gather(output_, input_) + else: + original_all_gather_into_tensor(output_, input_, group, async_op) + + dist.all_gather_into_tensor = _all_gather_closure + - dist.all_reduce = _all_reduce_closure +custom_comm_ops = CustomCommunicationOp() diff --git a/lightllm/distributed/custom_all_gather.py b/lightllm/distributed/custom_all_gather.py new file mode 100644 index 000000000..108014fb2 --- /dev/null +++ b/lightllm/distributed/custom_all_gather.py @@ -0,0 +1,249 @@ +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/distributed/device_communicators/custom_all_gather.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import ctypes +from contextlib import contextmanager +from typing import List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +from lightllm.common.vllm_kernel import _custom_ops as ops +from lightllm.common.cuda_wrapper import CudaRTLibrary +from lightllm.utils.log_utils import init_logger +from lightllm.utils.vllm_utils import is_full_nvlink +from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager + +ops.meta_size() +custom_ar = True + +logger = init_logger(__name__) + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or ( + inp.storage().nbytes() - inp.storage_offset() * inp.element_size() == inp.numel() * inp.element_size() + ) + + +class CustomAllgather: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allgather size + def __init__(self, group: ProcessGroup, device: Union[int, str, torch.device], max_size=8192 * 1024 * 10) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllgather to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allgather library + # e.g. in a non-cuda environment + return + self.group = group + assert dist.get_backend(group) != dist.Backend.NCCL, "CustomAllgather should be attached to a non-NCCL group." + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allgather for single GPU case. + return + + if world_size not in CustomAllgather._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allgather is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_gather=True explicitly.", + world_size, + str(CustomAllgather._SUPPORTED_WORLD_SIZES), + ) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = None + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + device_ids = list(range(torch._C._cuda_getDeviceCount())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], dtype=torch.int, device="cpu") + gather_list = [torch.tensor([0], dtype=torch.int, device="cpu") for _ in range(world_size)] + dist.all_gather(gather_list, tensor, group=self.group) + # physical_device_ids = [t.item() for t in gather_list] + + full_nvlink = is_full_nvlink() + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allgather is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_gather=True explicitly." + ) + return + + self.disabled = False + # Buffers memory are owned by this Python class and passed to C++. + # Meta data is for synchronization + self.meta_ptrs = self.create_shared_buffer(ops.meta_size(), group=group) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allgather is performed + self.buffer_ptrs = self.create_shared_buffer(max_size, group=group) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, dtype=torch.uint8, device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + self.full_nvlink = full_nvlink + self._ptr = ops.init_custom_gather_ar(self.meta_ptrs, self.rank_data, rank, self.full_nvlink) + ops.allgather_register_buffer(self._ptr, self.buffer_ptrs) + + @staticmethod + def create_shared_buffer(size_in_bytes: int, group: Optional[ProcessGroup] = None) -> List[int]: + """ + Creates a shared buffer and returns a list of pointers + representing the buffer on all processes in the group. + """ + lib = CudaRTLibrary() + pointer = lib.cudaMalloc(size_in_bytes) + handle = lib.cudaIpcGetMemHandle(pointer) + world_size = dist.get_world_size(group=group) + rank = dist.get_rank(group=group) + handles = [None] * world_size + dist.all_gather_object(handles, handle, group=group) + + pointers: List[int] = [] + for i, h in enumerate(handles): + if i == rank: + pointers.append(pointer.value) # type: ignore + else: + pointers.append(lib.cudaIpcOpenMemHandle(h).value) # type: ignore + + return pointers + + @staticmethod + def free_shared_buffer(pointers: List[int], group: Optional[ProcessGroup] = None) -> None: + rank = dist.get_rank(group=group) + lib = CudaRTLibrary() + lib.cudaFree(ctypes.c_void_p(pointers[rank])) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def register_graph_buffers(self): + handle, offset = ops.allgather_get_graph_buffer_ipc_meta(self._ptr) + # We cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))] + all_data[self.rank] = [handle, offset] + ranks = sorted(dist.get_process_group_ranks(group=self.group)) + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], src=rank, group=self.group, device="cpu") + # Unpack list of tuples to tuple of lists. + handles = [d[0] for d in all_data] # type: ignore + offsets = [d[1] for d in all_data] # type: ignore + ops.allgather_register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allgather requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + def all_gather(self, out: torch.Tensor, inp: torch.Tensor, registered: bool = False): + """Performs an out-of-place all gather. + + If registered is True, this assumes inp's pointer is already + IPC-registered. Otherwise, inp is first copied into a pre-registered + buffer. + """ + if registered: + ops.all_gather(self._ptr, inp, out, 0, 0) + else: + ops.all_gather(self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size) + return out + + def custom_all_gather(self, output: torch.Tensor, input: torch.Tensor) -> Optional[torch.Tensor]: + """The main allgather API that provides support for cuda graph.""" + # When custom allgather is disabled, this will be None. + if self.disabled or not self.should_custom_ar(input): + return + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + self.all_gather(output, input, registered=True) + return + else: + # If warm up, mimic the allocation pattern since custom + # allgather is out-of-place. + return + else: + # Note: outside of cuda graph context, custom allgather incurs a + # cost of cudaMemcpy, which should be small (<=1% of overall + # latency) compared to the performance gain of using custom kernels + self.all_gather(output, input, registered=False) + return + + def close(self): + if not self.disabled and self._ptr: + ops.allgather_dispose(self._ptr) + self._ptr = 0 + self.free_shared_buffer(self.meta_ptrs) + self.free_shared_buffer(self.buffer_ptrs) + + def __del__(self): + self.close() diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 2b2933a27..f8eaa0865 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -257,11 +257,8 @@ def _ffn_dp( hidden_states = self.alloc_tensor( [infer_state.all_token_num, hidden_dim], dtype=tp_hidden_states.dtype, device=tp_hidden_states.device ) - dist.all_gather( - [ - hidden_states[infer_state.all_start_idx[i] : infer_state.all_end_idx[i], :] - for i in range(self.world_size_) - ], + dist.all_gather_into_tensor( + hidden_states, tp_hidden_states, group=None, async_op=False, @@ -314,15 +311,23 @@ def _moe_ffn_dtp( hidden_states = self.alloc_tensor( [infer_state.all_token_num, hidden_dim], dtype=tp_hidden_states.dtype, device=tp_hidden_states.device ) - dist.all_gather( - [ - hidden_states[infer_state.all_start_idx[i] : infer_state.all_end_idx[i], :] - for i in range(self.world_size_) - ], - tp_hidden_states, - group=None, - async_op=False, - ) + if infer_state.is_prefill: + dist.all_gather( + [ + hidden_states[infer_state.all_start_idx[i] : infer_state.all_end_idx[i], :] + for i in range(self.world_size_) + ], + tp_hidden_states, + group=None, + async_op=False, + ) + else: + dist.all_gather_into_tensor( + hidden_states, + tp_hidden_states, + group=None, + async_op=False, + ) if self.n_shared_experts is not None: shared_output = LlamaTransformerLayerInfer._ffn(self, hidden_states, infer_state, layer_weight) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 44dddb9cc..dc0afeff1 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -96,9 +96,10 @@ def init_model(self, kvargs): dist.all_reduce(_a) del _a - from lightllm.distributed import set_custom_reduce + from lightllm.distributed import custom_comm_ops - set_custom_reduce() + custom_comm_ops.set_custom_reduce() + custom_comm_ops.set_custom_gather() # 为 p d 分离模式添加的全局锁管理,用于做一些同步操作。 一定需要在 # init_process_group 之后调用 diff --git a/test/model/model_infer_vit.py b/test/model/model_infer_vit.py index 65d0a7915..04d817f87 100644 --- a/test/model/model_infer_vit.py +++ b/test/model/model_infer_vit.py @@ -39,7 +39,9 @@ def tppart_model_infer(model_kvargs): torch.cuda.set_device(rank_id) os.environ["CURRENT_DEVICE_ID"] = str(rank_id) dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size) - set_custom_reduce() + from lightllm.distributed import custom_comm_ops + + custom_comm_ops.set_custom_reduce() dist.barrier() torch.cuda.empty_cache() model_part = VisionTransformer(model_kvargs) diff --git a/test/model/test_settings/model_infer_batchs.py b/test/model/test_settings/model_infer_batchs.py index b5937bd48..efc3d5331 100644 --- a/test/model/test_settings/model_infer_batchs.py +++ b/test/model/test_settings/model_infer_batchs.py @@ -69,7 +69,6 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output return import torch - from lightllm.distributed import set_custom_reduce import torch.distributed as dist rank_id = model_kvargs["tp_rank"] @@ -77,7 +76,9 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output torch.cuda.set_device(rank_id) dist.init_process_group("nccl", init_method="tcp://127.0.0.1:28765", rank=rank_id, world_size=world_size) - set_custom_reduce() + from lightllm.distributed import custom_comm_ops + + custom_comm_ops.set_custom_reduce() dist.barrier() torch.cuda.empty_cache()