add custom allgather(into tensor) (#692)
Co-authored-by: baishihao <[email protected]>
Co-authored-by: hiworldwzj <[email protected]>
3 people authored Jan 2, 2025
1 parent c7896c9 commit d5edea4
Showing 9 changed files with 390 additions and 56 deletions.
4 changes: 2 additions & 2 deletions lightllm/common/basemodel/cuda_graph.py
@@ -1,7 +1,7 @@
import os
import torch
from lightllm.utils.log_utils import init_logger
from lightllm.distributed import lightllm_capture_graph
from lightllm.distributed import custom_comm_ops

logger = init_logger(__name__)

@@ -31,7 +31,7 @@ def capture_decode(self, decode_func, input_ids, infer_state):
torch.cuda.synchronize()
decode_func(input_ids, infer_state)
torch.cuda.synchronize()
with lightllm_capture_graph():
with custom_comm_ops.lightllm_capture_graph():
with torch.cuda.graph(graph_obj, pool=self.mempool):
predict_logics = decode_func(input_ids, infer_state)
self.graph[batch_size] = (graph_obj, input_ids, infer_state, predict_logics)
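For orientation, here is a minimal sketch of the capture pattern this change supports: the new custom_comm_ops.lightllm_capture_graph() context manager wraps torch.cuda.graph capture so that the custom all-reduce/allgather objects, when enabled, enter their own capture contexts as well. The helper below is illustrative only; decode_func and static_input are hypothetical stand-ins, not code from this repository.

from typing import Callable

import torch
from lightllm.distributed import custom_comm_ops


def capture_decode_graph(decode_func: Callable[[torch.Tensor], torch.Tensor],
                         static_input: torch.Tensor) -> torch.cuda.CUDAGraph:
    # Sketch only: assumes CUDA and the distributed setup are already initialized.
    graph = torch.cuda.CUDAGraph()
    with custom_comm_ops.lightllm_capture_graph():
        with torch.cuda.graph(graph):
            decode_func(static_input)  # outputs are captured into the graph
    return graph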
10 changes: 10 additions & 0 deletions lightllm/common/vllm_kernel/_custom_ops.py
@@ -49,3 +49,13 @@ def select_experts(

except ImportError:
logger.error("vllm or lightllm_kernel is not installed, you can't use custom ops")

try:
from lightllm.common.vllm_kernel._ops import init_custom_gather_ar
from lightllm.common.vllm_kernel._ops import all_gather
from lightllm.common.vllm_kernel._ops import allgather_dispose
from lightllm.common.vllm_kernel._ops import allgather_register_buffer
from lightllm.common.vllm_kernel._ops import allgather_get_graph_buffer_ipc_meta
from lightllm.common.vllm_kernel._ops import allgather_register_graph_buffers
except ImportError:
logger.error("lightllm_kernel is not installed, you can't use custom allgather")
27 changes: 27 additions & 0 deletions lightllm/common/vllm_kernel/_ops.py
@@ -1069,6 +1069,33 @@ def register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List
torch.ops.vllm_total_custom_ar.register_graph_buffers(fa, handles, offsets)


# custom ar
def init_custom_gather_ar(
ipc_tensors: List[torch.Tensor], rank_data: torch.Tensor, rank: int, full_nvlink: bool
) -> int:
return torch.ops.vllm_total_custom_ar.init_custom_gather_ar(ipc_tensors, rank_data, rank, full_nvlink)


def all_gather(fa: int, inp: torch.Tensor, out: torch.Tensor, reg_buffer: int, reg_buffer_sz_bytes: int) -> None:
torch.ops.vllm_total_custom_ar.all_gather(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)


def allgather_dispose(fa: int) -> None:
torch.ops.vllm_total_custom_ar.allgather_dispose(fa)


def allgather_register_buffer(fa: int, ipc_tensors: List[int]) -> None:
return torch.ops.vllm_total_custom_ar.allgather_register_buffer(fa, ipc_tensors)


def allgather_get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
return torch.ops.vllm_total_custom_ar.allgather_get_graph_buffer_ipc_meta(fa)


def allgather_register_graph_buffers(fa: int, handles: List[List[int]], offsets: List[List[int]]) -> None:
torch.ops.vllm_total_custom_ar.allgather_register_graph_buffers(fa, handles, offsets)


# temporary fix for https://github.com/vllm-project/vllm/issues/5456
# TODO: remove this in v0.6.0
names_and_values = globals()
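These thin wrappers follow the naming of vLLM's custom all-reduce bindings, with fa being the opaque handle returned by the init call. Below is a hedged sketch of the handle lifecycle, normally driven by the higher-level CustomAllgather class rather than called by hand; every tensor and buffer argument is supplied by the caller, and the exact buffer-registration semantics are an assumption based on the analogous all-reduce path.

from typing import List

import torch
from lightllm.common.vllm_kernel import _ops as ops


def run_custom_all_gather(
    ipc_tensors: List[torch.Tensor],
    rank_data: torch.Tensor,
    rank: int,
    full_nvlink: bool,
    inp: torch.Tensor,
    out: torch.Tensor,
    reg_buffer: int,
    reg_buffer_sz_bytes: int,
) -> None:
    # Create the native allgather context and get an opaque handle back.
    fa = ops.init_custom_gather_ar(ipc_tensors, rank_data, rank, full_nvlink)
    try:
        # Gather inp from every rank into out, staging through the
        # pre-registered buffer identified by reg_buffer.
        ops.all_gather(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
    finally:
        ops.allgather_dispose(fa)  # always release the native handle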
109 changes: 74 additions & 35 deletions lightllm/distributed/communication_op.py
@@ -27,6 +27,8 @@
from functools import partial

original_all_reduce = torch.distributed.all_reduce
original_all_gather_into_tensor = torch.distributed.all_gather_into_tensor

from contextlib import nullcontext, contextmanager

logger = init_logger(__name__)
@@ -40,43 +42,80 @@
HAS_VLLM = False
logger.info("vllm or lightllm_kernel is not installed, you can't use custom allreduce")

vllm_reduce = None
try:
HAS_LIGHTLLM_KERNEL = True
from .custom_all_gather import CustomAllgather

logger.info("using custom allgather")
except:
HAS_LIGHTLLM_KERNEL = False
logger.info("lightllm_kernel is not installed, you can't use custom allgather")

@contextmanager
def lightllm_capture_graph():
if vllm_reduce is not None:
with vllm_reduce.capture():
yield
else:
yield
pass


def set_custom_reduce():
global vllm_reduce
global device_group
ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in [
"ON",
"TRUE",
"1",
]
world_size = dist.get_world_size()
ranks = list(range(world_size))
# new_group prevent stuck of torch origin all_reduce with cudagraph
device_group = torch.distributed.new_group(ranks, backend="nccl")
if ENABLE_VLLM_REDUCE and HAS_VLLM:
cpu_group = torch.distributed.new_group(ranks, backend="gloo")
vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
logger.info("Enable VLLM ALLReduce.")

def _all_reduce_closure(input_, op=ReduceOp.SUM, group=device_group, async_op=False):
if op != ReduceOp.SUM or async_op:
original_all_reduce(input_, op, group, async_op)

class CustomCommunicationOp:
def __init__(self):
self.vllm_reduce = None
self.custom_gather = None
self.device_group = None

@contextmanager
def lightllm_capture_graph(self):
if self.vllm_reduce is not None:
with self.vllm_reduce.capture():
if self.custom_gather is not None:
with self.custom_gather.capture():
yield
else:
yield
else:
if vllm_reduce is not None and vllm_reduce.should_custom_ar(input_):
input_.data = vllm_reduce.custom_all_reduce(input_)
else:
yield

def set_custom_reduce(self):
ENABLE_VLLM_REDUCE = os.getenv("ENABLE_VLLM_REDUCE", "False").upper() in ["ON", "TRUE", "1"]
world_size = dist.get_world_size()
ranks = list(range(world_size))

# Create a new NCCL group to prevent the original all_reduce from getting stuck with cudagraph
if self.device_group is None:
self.device_group = dist.new_group(ranks, backend="nccl")

if ENABLE_VLLM_REDUCE and HAS_VLLM:
cpu_group = dist.new_group(ranks, backend="gloo")
self.vllm_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
logger.info("Enable VLLM ALLReduce.")

def _all_reduce_closure(input_, op=ReduceOp.SUM, group=self.device_group, async_op=False):
if op != ReduceOp.SUM or async_op:
original_all_reduce(input_, op, group, async_op)
else:
if self.vllm_reduce is not None and self.vllm_reduce.should_custom_ar(input_):
input_.data = self.vllm_reduce.custom_all_reduce(input_)
else:
original_all_reduce(input_, op, group, async_op)

dist.all_reduce = _all_reduce_closure

def set_custom_gather(self):
ENABLE_CUSTOM_GATHER = os.getenv("ENABLE_CUSTOM_GATHER", "False").upper() in ["ON", "TRUE", "1"]
world_size = dist.get_world_size()
ranks = list(range(world_size))
if self.device_group is None:
self.device_group = dist.new_group(ranks, backend="nccl")
if ENABLE_CUSTOM_GATHER and HAS_LIGHTLLM_KERNEL:
cpu_group = dist.new_group(ranks, backend="gloo")
self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
logger.info("Enable Custom ALLGather.")

def _all_gather_closure(output_, input_, group=self.device_group, async_op=False):
if async_op:
original_all_gather_into_tensor(output_, input_, group, async_op)
else:
if self.custom_gather is not None and self.custom_gather.should_custom_ar(input_):
self.custom_gather.custom_all_gather(output_, input_)
else:
original_all_gather_into_tensor(output_, input_, group, async_op)

dist.all_gather_into_tensor = _all_gather_closure


dist.all_reduce = _all_reduce_closure
custom_comm_ops = CustomCommunicationOp()
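Putting the pieces together, a minimal sketch of how the new singleton is expected to be used; the process-group setup, environment handling, and tensor shapes here are assumptions for illustration, not code from the repository.

import os

import torch
import torch.distributed as dist

from lightllm.distributed import custom_comm_ops

# Opt in before the setters read the environment variables.
os.environ["ENABLE_VLLM_REDUCE"] = "1"
os.environ["ENABLE_CUSTOM_GATHER"] = "1"

dist.init_process_group(backend="nccl")  # assumed to be launched via torchrun
custom_comm_ops.set_custom_reduce()      # replaces dist.all_reduce
custom_comm_ops.set_custom_gather()      # replaces dist.all_gather_into_tensor

x = torch.ones(8, 1024, device="cuda")
dist.all_reduce(x)  # routed through CustomAllreduce when eligible
out = torch.empty(dist.get_world_size() * 8, 1024, device="cuda")
dist.all_gather_into_tensor(out, x)  # routed through CustomAllgather when eligible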