From 7ad0d08081202170d8b5b8e4bc484ee830c7e942 Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Mon, 8 Jul 2024 10:50:53 +0800
Subject: [PATCH 1/7] tp recompute

---
 configs/7B_internlm2.py                  |  3 +-
 configs/7B_sft.py                        |  1 +
 internlm/core/parallel/comm/tensor.py    | 41 ++++++++++++++++--------
 internlm/initialize/launch.py            |  3 ++
 internlm/model/builder.py                |  3 ++
 internlm/model/modeling_internlm.py      | 27 +++++++++++++++-
 internlm/model/modeling_internlm2.py     | 29 ++++++++++++++++-
 internlm/model/modules/linear.py         | 17 +++++++---
 internlm/model/modules/mlp.py            |  4 +--
 internlm/solver/activation_checkpoint.py | 24 ++++++++++++++
 10 files changed, 129 insertions(+), 23 deletions(-)

diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py
index a69896ce..8085cdf0 100644
--- a/configs/7B_internlm2.py
+++ b/configs/7B_internlm2.py
@@ -1,5 +1,5 @@
 JOB_NAME = "7b_internlm2_train"
-model_type="INTERNLM2_PUBLIC"
+model_type = "INTERNLM2_PUBLIC"
 DO_ALERT = False
 
 VOCAB_SIZE = 92544
@@ -128,6 +128,7 @@ use_fp32_norm = False
 model = dict(
     checkpoint=False,
+    # checkpoint_tp_no_comm=True, # whether to use the TP recomputation communication optimization
     num_chunks=1,
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index eba87bcd..f3e28221 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -141,6 +141,7 @@ use_fp32_norm = False
 model = dict(
     checkpoint=False,  # The proportion of layers for activation aheckpointing, the optional value are True/False/[0-1]
+    # checkpoint_tp_no_comm=True, # whether to use the TP recomputation communication optimization
     num_attention_heads=NUM_ATTENTION_HEAD,
     embed_split_hidden=True,
     vocab_size=VOCAB_SIZE,
diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py
index ca8c1900..e33172bb 100644
--- a/internlm/core/parallel/comm/tensor.py
+++ b/internlm/core/parallel/comm/tensor.py
@@ -66,7 +66,7 @@ def input_hook(
 
     @abstractmethod
     def grad_output_hook(
-        self, grad_output: torch.Tensor, async_op: bool = False
+        self, grad_output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False
     ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         communication for grad_output when backward.
@@ -81,7 +81,9 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
         pass
 
     @abstractmethod
-    def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
+    def output_hook(
+        self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False
+    ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         communication for output when forward.
         """
@@ -116,7 +118,10 @@ def input_hook(
         return _input, DUMMY_HANDLE_CONST
 
     def grad_output_hook(
-        self, grad_output: torch.Tensor, async_op: bool = False  # pylint: disable=W0613
+        self,
+        grad_output: torch.Tensor,
+        async_op: bool = False,
+        recompute_forward: bool = False,  # pylint: disable=W0613
     ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         tensor parallel should do nothing for grad_output.
@@ -132,11 +137,13 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
 
         return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op)
 
-    def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
+    def output_hook(
+        self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False
+    ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         all reduce output only for row parallel linear when forward.
         """
-        if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
+        if recompute_forward or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
             return output, DUMMY_HANDLE_CONST
 
         return all_reduce_raw(output, process_group=self._process_group, async_op=async_op)
@@ -182,12 +189,12 @@ def input_hook(
         return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)
 
     def grad_output_hook(
-        self, grad_output: torch.Tensor, async_op: bool = False
+        self, grad_output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False
     ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         all gather grad_output only for row parallel linear when backward.
         """
-        if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
+        if recompute_forward or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
             return grad_output, DUMMY_HANDLE_CONST
 
         return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM)
@@ -203,11 +210,13 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T
             grad_input, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM
         )
 
-    def output_hook(self, output: torch.Tensor, async_op: bool = False) -> Tuple[torch.Tensor, AsyncCommHandle]:
+    def output_hook(
+        self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False
+    ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         reduce scatter output only for row parallel linear when forward.
         """
-        if dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
+        if recompute_forward or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN:
             return output, DUMMY_HANDLE_CONST
 
         return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM)
@@ -225,7 +234,10 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True)
         self._retain_out_sharded = retain_out_sharded
 
     def grad_output_hook(
-        self, grad_output: torch.Tensor, async_op: bool = False  # pylint: disable=W0613
+        self,
+        grad_output: torch.Tensor,
+        async_op: bool = False,
+        recompute_forward: bool = False,  # pylint: disable=W0613
     ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         split grad_output if retain_out_sharded is False.
@@ -236,7 +248,7 @@ def grad_output_hook(
         return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST
 
     def output_hook(
-        self, output: torch.Tensor, async_op: bool = False  # pylint: disable=W0613
+        self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False  # pylint: disable=W0613
     ) -> Tuple[torch.Tensor, AsyncCommHandle]:
         """
         all gather output for head layer if retain_out_sharded is False.
@@ -266,7 +278,10 @@ def __init__( # rewrite grad_output communication hook def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, + grad_output: torch.Tensor, + async_op: bool = False, + recompute_forward: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. @@ -278,7 +293,7 @@ def grad_output_hook( # rewrite ouput communication hook def output_hook( - self, output: torch.Tensor, async_op: bool = False # pylint: disable=W0613 + self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index a07c3e76..34b38e63 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -295,10 +295,13 @@ def args_sanity_check(): ] if "checkpoint" in model: + if "checkpoint_tp_no_comm" not in model: + gpc.config.model._add_item("checkpoint_tp_no_comm", True) if model.checkpoint is True: model.checkpoint = 1 elif model.checkpoint is False: model.checkpoint = 0 + model.checkpoint_tp_no_comm = False else: assert ( model.checkpoint >= 0 and model.checkpoint <= 1 diff --git a/internlm/model/builder.py b/internlm/model/builder.py index 2b10406b..8bf45326 100644 --- a/internlm/model/builder.py +++ b/internlm/model/builder.py @@ -21,6 +21,9 @@ def create_model(model_type, *args, **kwargs) -> Union[nn.Module, List[nn.Module kwargs["checkpoint"] = float(kwargs.get("checkpoint", False)) kwargs["device"] = get_current_device() + if "checkpoint_tp_no_comm" in kwargs: + kwargs.pop("checkpoint_tp_no_comm") + model_buidler = model_initializer.get_module(module_name=model_type) if not gpc.is_using_parallel_mode(ParallelMode.PIPELINE): diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 5994e15d..5d498b7d 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -10,6 +10,7 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module +from internlm.core.parallel.comm.tensor import _GATHER_DIM from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import new_linear @@ -24,6 +25,7 @@ ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel logger = get_logger(__file__) @@ -179,6 +181,8 @@ def _forward(self, hidden_states, *args, **kwargs): cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ + recompute_forward = args[4] if len(args) > 4 else False + args = args[:4] def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) @@ -211,7 +215,28 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, recompute_forward=recompute_forward) + + # pad residual + if recompute_forward and is_using_sequence_parallel() 
and not is_using_isp(): + requires_grad = residual.requires_grad + pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + pad_after = ( + gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1 + ) * residual.shape[_GATHER_DIM] + + pad_before_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + pad_after_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + + residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad) return hidden_states + residual diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index c3b89412..6f7fe108 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -7,6 +7,7 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.tensor import _GATHER_DIM from internlm.initialize.initialize_tensor import ( normal_, scaled_init_method_normal, @@ -24,6 +25,7 @@ ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel logger = get_logger(__file__) @@ -216,6 +218,8 @@ def _forward(self, hidden_states, residual, *args, **kwargs): cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ + recompute_forward = args[4] if len(args) > 4 else False + args = args[:4] if self.prenorm: def _dropout_and_norm_attn(_residual, _hidden_states): @@ -255,7 +259,30 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.feed_forward(hidden_states) + hidden_states = self.feed_forward(hidden_states, recompute_forward=recompute_forward) + + # pad residual + if recompute_forward and is_using_sequence_parallel() and not is_using_isp(): + requires_grad = residual.requires_grad + pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + pad_after = ( + gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1 + ) * residual.shape[_GATHER_DIM] + + pad_before_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + pad_after_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + + residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_( + requires_grad + ) return hidden_states + residual else: diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 0d8c4bf8..2a2f1f94 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -45,10 +45,12 @@ def forward( bias: Optional[torch.Tensor], communicator: TPCommunicator, return_residual=False, + recompute_forward=False, ): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.communicator = communicator + ctx.recompute_forward = recompute_forward if 
torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) @@ -77,7 +79,7 @@ def forward( # parallel strategy-specific communication callback 2. # see more details in the communicator for different parallel strategies. - output, _ = communicator.output_hook(output, async_op=False) + output, _ = communicator.output_hook(output, async_op=False, recompute_forward=recompute_forward) saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x ctx.save_for_backward(saved_x, weight) @@ -91,7 +93,9 @@ def backward(ctx, grad_output, *args): # parallel strategy-specific communication callback 3. # see more details in the communicator for different parallel strategies. - grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False) + grad_output, _ = communicator.grad_output_hook( + grad_output, recompute_forward=ctx.recompute_forward, async_op=False + ) grad_output = grad_output.contiguous() if ctx.return_residual: @@ -264,6 +268,7 @@ def fused_dense_func( module: Optional[nn.Module] = None, bias: Optional[torch.Tensor] = None, return_residual: bool = False, + recompute_forward=False, ): if communicator.communication_mode() == "wp": return WPFusedDenseFunc.apply( @@ -281,6 +286,7 @@ def fused_dense_func( bias, communicator, return_residual, + recompute_forward, ) @@ -343,16 +349,16 @@ def __init__( else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 + def forward(self, input: torch.Tensor, recompute_forward=False) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." - return fused_dense_func( input, self.weight, communicator=self._communicator, module=self, bias=self.bias, + recompute_forward=recompute_forward, ) @@ -465,7 +471,7 @@ def __init__( self.first_eval_flag = True self.tmp_weight = None - def forward(self, input): # pylint: disable=W0622 + def forward(self, input, recompute_forward=False): # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." 
@@ -496,6 +502,7 @@ def forward(self, input): # pylint: disable=W0622 communicator=self._communicator, module=self, bias=self.bias, + recompute_forward=recompute_forward, ) diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 897e1363..d6e89d97 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -91,14 +91,14 @@ def __init__( self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype) - def forward(self, x): + def forward(self, x, recompute_forward=False): if not self.mlp_layer_fusion: w1_o = self.w1(x) w3_o = self.w3(x) else: fussed_out = self.fused_w1_w3(x) w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) - out = self.w2(Silu(w1_o, w3_o)) + out = self.w2(Silu(w1_o, w3_o), recompute_forward=recompute_forward) return out diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 87055771..154252be 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -7,6 +7,8 @@ from torch.utils.checkpoint import check_backward_validity, detach_variable from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc from internlm.core.context.random import ( get_current_mode, get_states, @@ -14,6 +16,8 @@ set_seed_states, sync_states, ) +from internlm.core.parallel.comm.tensor import _GATHER_DIM, all_gather_raw +from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel from ..utils.common import get_current_device @@ -122,7 +126,22 @@ def backward(ctx, *args): # Fill in inputs with appropriate saved tensors. 
for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] + + # recompute_forward + recompute_forward = False + if gpc.config.model.checkpoint_tp_no_comm: + recompute_forward = True + inputs.append(True) + detached_inputs = detach_variable(tuple(inputs)) + + handle = None + if recompute_forward and is_using_sequence_parallel() and not is_using_isp(): + grad_output = args[0] + grad_output, handle = all_gather_raw( + grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM + ) + if ctx.had_autocast_in_fwd: with torch.enable_grad(), internlm_accelerator.amp.autocast(): outputs = ctx.run_function(*detached_inputs) @@ -130,6 +149,11 @@ def backward(ctx, *args): with torch.enable_grad(): outputs = ctx.run_function(*detached_inputs) + if handle: + handle.wait() + args = list(args) + args[0] = grad_output + if isinstance(outputs, torch.Tensor): outputs = (outputs,) # recover the rng states From 00c2481f1bd47fb58198c522359fcf7713aceff9 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Mon, 8 Jul 2024 15:31:54 +0800 Subject: [PATCH 2/7] change name and fix bug --- internlm/core/parallel/comm/tensor.py | 26 ++++++++++++------------ internlm/initialize/launch.py | 4 ++++ internlm/model/modeling_internlm.py | 6 +++--- internlm/model/modeling_internlm2.py | 6 +++--- internlm/model/modules/linear.py | 20 +++++++++--------- internlm/model/modules/mlp.py | 4 ++-- internlm/solver/activation_checkpoint.py | 10 ++++----- 7 files changed, 40 insertions(+), 36 deletions(-) diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index e33172bb..ac883063 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -66,7 +66,7 @@ def input_hook( @abstractmethod def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False + self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ communication for grad_output when backward. @@ -82,7 +82,7 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T @abstractmethod def output_hook( - self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False + self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ communication for output when forward. @@ -121,7 +121,7 @@ def grad_output_hook( self, grad_output: torch.Tensor, async_op: bool = False, - recompute_forward: bool = False, # pylint: disable=W0613 + no_communication: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ tensor parallel should do nothing for grad_output. @@ -138,12 +138,12 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op) def output_hook( - self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False + self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all reduce output only for row parallel linear when forward. 
""" - if recompute_forward or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: return output, DUMMY_HANDLE_CONST return all_reduce_raw(output, process_group=self._process_group, async_op=async_op) @@ -189,12 +189,12 @@ def input_hook( return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False + self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather grad_output only for row parallel linear when backward. """ - if recompute_forward or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: return grad_output, DUMMY_HANDLE_CONST return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) @@ -211,12 +211,12 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T ) def output_hook( - self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False + self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ reduce scatter output only for row parallel linear when forward. """ - if recompute_forward or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: return output, DUMMY_HANDLE_CONST return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM) @@ -237,7 +237,7 @@ def grad_output_hook( self, grad_output: torch.Tensor, async_op: bool = False, - recompute_forward: bool = False, # pylint: disable=W0613 + no_communication: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. @@ -248,7 +248,7 @@ def grad_output_hook( return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST def output_hook( - self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False # pylint: disable=W0613 + self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. @@ -281,7 +281,7 @@ def grad_output_hook( self, grad_output: torch.Tensor, async_op: bool = False, - recompute_forward: bool = False, # pylint: disable=W0613 + no_communication: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. @@ -293,7 +293,7 @@ def grad_output_hook( # rewrite ouput communication hook def output_hook( - self, output: torch.Tensor, async_op: bool = False, recompute_forward: bool = False # pylint: disable=W0613 + self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. 
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 34b38e63..36689ffe 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -414,6 +414,10 @@ def args_sanity_check(): gpc.config.parallel["pipeline"].get("interleaved_overlap", False) is True ), "only support interleaved pipeline scheduler with overlap" + # when not use tp or sp, checkpoint_tp_no_comm should always be False + if gpc.config.parallel["tensor"]["size"] <= 1 and getattr(gpc.config.model, "checkpoint_tp_no_comm", False): + gpc.config.model.checkpoint_tp_no_comm = False + # monitoring default config monitor_default_config = { "alert_address": None, # compatible with old alert config diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 5d498b7d..300e33db 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -181,7 +181,7 @@ def _forward(self, hidden_states, *args, **kwargs): cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - recompute_forward = args[4] if len(args) > 4 else False + no_communication = args[4] if len(args) > 4 else False args = args[:4] def _dropout_and_norm_attn(_hidden_states): @@ -215,10 +215,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.mlp(hidden_states, recompute_forward=recompute_forward) + hidden_states = self.mlp(hidden_states, no_communication=no_communication) # pad residual - if recompute_forward and is_using_sequence_parallel() and not is_using_isp(): + if no_communication and is_using_sequence_parallel() and not is_using_isp(): requires_grad = residual.requires_grad pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] pad_after = ( diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 6f7fe108..5cafeb74 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -218,7 +218,7 @@ def _forward(self, hidden_states, residual, *args, **kwargs): cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - recompute_forward = args[4] if len(args) > 4 else False + no_communication = args[4] if len(args) > 4 else False args = args[:4] if self.prenorm: @@ -259,10 +259,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - hidden_states = self.feed_forward(hidden_states, recompute_forward=recompute_forward) + hidden_states = self.feed_forward(hidden_states, no_communication=no_communication) # pad residual - if recompute_forward and is_using_sequence_parallel() and not is_using_isp(): + if no_communication and is_using_sequence_parallel() and not is_using_isp(): requires_grad = residual.requires_grad pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] pad_after = ( diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 2a2f1f94..17e1b290 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -45,12 +45,12 @@ def forward( bias: Optional[torch.Tensor], communicator: TPCommunicator, return_residual=False, - recompute_forward=False, + no_communication=False, ): ctx.compute_weight_gradient = 
weight.requires_grad ctx.return_residual = return_residual ctx.communicator = communicator - ctx.recompute_forward = recompute_forward + ctx.no_communication = no_communication if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) @@ -79,7 +79,7 @@ def forward( # parallel strategy-specific communication callback 2. # see more details in the communicator for different parallel strategies. - output, _ = communicator.output_hook(output, async_op=False, recompute_forward=recompute_forward) + output, _ = communicator.output_hook(output, async_op=False, no_communication=no_communication) saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x ctx.save_for_backward(saved_x, weight) @@ -94,7 +94,7 @@ def backward(ctx, grad_output, *args): # parallel strategy-specific communication callback 3. # see more details in the communicator for different parallel strategies. grad_output, _ = communicator.grad_output_hook( - grad_output, recompute_forward=ctx.recompute_forward, async_op=False + grad_output, no_communication=ctx.no_communication, async_op=False ) grad_output = grad_output.contiguous() @@ -268,7 +268,7 @@ def fused_dense_func( module: Optional[nn.Module] = None, bias: Optional[torch.Tensor] = None, return_residual: bool = False, - recompute_forward=False, + no_communication=False, ): if communicator.communication_mode() == "wp": return WPFusedDenseFunc.apply( @@ -286,7 +286,7 @@ def fused_dense_func( bias, communicator, return_residual, - recompute_forward, + no_communication, ) @@ -349,7 +349,7 @@ def __init__( else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: torch.Tensor, recompute_forward=False) -> torch.Tensor: # pylint: disable=W0622 + def forward(self, input: torch.Tensor, no_communication=False) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." return fused_dense_func( @@ -358,7 +358,7 @@ def forward(self, input: torch.Tensor, recompute_forward=False) -> torch.Tensor: communicator=self._communicator, module=self, bias=self.bias, - recompute_forward=recompute_forward, + no_communication=no_communication, ) @@ -471,7 +471,7 @@ def __init__( self.first_eval_flag = True self.tmp_weight = None - def forward(self, input, recompute_forward=False): # pylint: disable=W0622 + def forward(self, input, no_communication=False): # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." 
@@ -502,7 +502,7 @@ def forward(self, input, recompute_forward=False): # pylint: disable=W0622 communicator=self._communicator, module=self, bias=self.bias, - recompute_forward=recompute_forward, + no_communication=no_communication, ) diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index d6e89d97..f3feee42 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -91,14 +91,14 @@ def __init__( self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype) - def forward(self, x, recompute_forward=False): + def forward(self, x, no_communication=False): if not self.mlp_layer_fusion: w1_o = self.w1(x) w3_o = self.w3(x) else: fussed_out = self.fused_w1_w3(x) w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) - out = self.w2(Silu(w1_o, w3_o), recompute_forward=recompute_forward) + out = self.w2(Silu(w1_o, w3_o), no_communication=no_communication) return out diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 154252be..f852bd38 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -127,16 +127,16 @@ def backward(ctx, *args): for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] - # recompute_forward - recompute_forward = False - if gpc.config.model.checkpoint_tp_no_comm: - recompute_forward = True + # no_communication + no_communication = False + if getattr(gpc.config.model, "checkpoint_tp_no_comm", False): + no_communication = True inputs.append(True) detached_inputs = detach_variable(tuple(inputs)) handle = None - if recompute_forward and is_using_sequence_parallel() and not is_using_isp(): + if no_communication and is_using_sequence_parallel() and not is_using_isp(): grad_output = args[0] grad_output, handle = all_gather_raw( grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM From f8002fbc85bb2af2824b7c5539436b23465e4345 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 9 Jul 2024 15:22:37 +0800 Subject: [PATCH 3/7] fix comment --- internlm/core/context/parallel_context.py | 1 + internlm/initialize/launch.py | 4 ++- internlm/model/modeling_internlm.py | 29 +++++--------------- internlm/model/modeling_internlm2.py | 32 +++++------------------ internlm/model/utils.py | 27 +++++++++++++++++++ internlm/solver/activation_checkpoint.py | 22 +++++++++------- 6 files changed, 56 insertions(+), 59 deletions(-) diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 6b23fdae..f411f863 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -159,6 +159,7 @@ def __init__(self): self.virtual_pipeline_parallel_rank = None self._expert_parallel_group_names = [] self.is_evaluating = False + self.recompute_forward_no_comm = False @property def config(self): diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 36689ffe..712e0006 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -415,7 +415,9 @@ def args_sanity_check(): ), "only support interleaved pipeline scheduler with overlap" # when not use tp or sp, checkpoint_tp_no_comm should always be False - if gpc.config.parallel["tensor"]["size"] <= 1 and getattr(gpc.config.model, "checkpoint_tp_no_comm", False): + if (gpc.config.parallel["tensor"]["mode"] == 
"isp" or gpc.config.parallel["tensor"]["size"] <= 1) and getattr( + gpc.config.model, "checkpoint_tp_no_comm", False + ): gpc.config.model.checkpoint_tp_no_comm = False # monitoring default config diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 300e33db..311ae2b3 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -10,7 +10,6 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.core.naive_amp import set_output_attr_to_module -from internlm.core.parallel.comm.tensor import _GATHER_DIM from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import new_linear @@ -22,10 +21,11 @@ convert_attn_kwargs_to_args, internlm1_mha_pre_load_convert, internlm1_mha_save_convert, + padding_residual, ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel +from internlm.utils.parallel import is_using_sequence_parallel logger = get_logger(__file__) @@ -181,8 +181,6 @@ def _forward(self, hidden_states, *args, **kwargs): cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - no_communication = args[4] if len(args) > 4 else False - args = args[:4] def _dropout_and_norm_attn(_hidden_states): _dropped = self.dropout1(_hidden_states) @@ -215,28 +213,13 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) + no_communication = gpc.recompute_forward_no_comm + hidden_states = self.mlp(hidden_states, no_communication=no_communication) # pad residual - if no_communication and is_using_sequence_parallel() and not is_using_isp(): - requires_grad = residual.requires_grad - pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] - pad_after = ( - gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1 - ) * residual.shape[_GATHER_DIM] - - pad_before_tensor = torch.zeros( - (*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]), - dtype=residual.dtype, - device=residual.device, - ) - pad_after_tensor = torch.zeros( - (*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]), - dtype=residual.dtype, - device=residual.device, - ) - - residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad) + if no_communication and is_using_sequence_parallel(): + residual = padding_residual(residual) return hidden_states + residual diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index 5cafeb74..d44837e0 100644 --- a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -7,7 +7,6 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc -from internlm.core.parallel.comm.tensor import _GATHER_DIM from internlm.initialize.initialize_tensor import ( normal_, scaled_init_method_normal, @@ -22,10 +21,11 @@ from internlm.model.utils import ( convert_attn_args_to_kwargs, convert_attn_kwargs_to_args, + padding_residual, ) from internlm.solver.activation_checkpoint import 
activation_checkpoint from internlm.utils.logger import get_logger -from internlm.utils.parallel import is_using_isp, is_using_sequence_parallel +from internlm.utils.parallel import is_using_sequence_parallel logger = get_logger(__file__) @@ -218,8 +218,6 @@ def _forward(self, hidden_states, residual, *args, **kwargs): cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 indexes: the length of index is same as hidden states, which stand for the current position """ - no_communication = args[4] if len(args) > 4 else False - args = args[:4] if self.prenorm: def _dropout_and_norm_attn(_residual, _hidden_states): @@ -259,30 +257,14 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) + + no_communication = gpc.recompute_forward_no_comm + hidden_states = self.feed_forward(hidden_states, no_communication=no_communication) # pad residual - if no_communication and is_using_sequence_parallel() and not is_using_isp(): - requires_grad = residual.requires_grad - pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] - pad_after = ( - gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1 - ) * residual.shape[_GATHER_DIM] - - pad_before_tensor = torch.zeros( - (*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]), - dtype=residual.dtype, - device=residual.device, - ) - pad_after_tensor = torch.zeros( - (*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]), - dtype=residual.dtype, - device=residual.device, - ) - - residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_( - requires_grad - ) + if no_communication and is_using_sequence_parallel(): + residual = padding_residual(residual) return hidden_states + residual else: diff --git a/internlm/model/utils.py b/internlm/model/utils.py index c2311007..019e9af6 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -1,5 +1,10 @@ from typing import Any, Dict, List +import torch + +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.core.parallel.comm.tensor import _GATHER_DIM from internlm.model.modules.mha import MHA @@ -51,3 +56,25 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]: kwargs["max_seqlen"] = args[3] return kwargs + + +def padding_residual(residual): + requires_grad = residual.requires_grad + pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + pad_after = ( + gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1 + ) * residual.shape[_GATHER_DIM] + + pad_before_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + pad_after_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]), + dtype=residual.dtype, + device=residual.device, + ) + residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad) + + return residual diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index f852bd38..33760b2a 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -17,7 +17,7 @@ sync_states, ) from internlm.core.parallel.comm.tensor import _GATHER_DIM, all_gather_raw -from 
internlm.utils.parallel import is_using_isp, is_using_sequence_parallel +from internlm.utils.parallel import is_using_sequence_parallel from ..utils.common import get_current_device @@ -128,19 +128,18 @@ def backward(ctx, *args): inputs[idx] = tensors[i] # no_communication - no_communication = False - if getattr(gpc.config.model, "checkpoint_tp_no_comm", False): - no_communication = True - inputs.append(True) + no_communication = getattr(gpc.config.model, "checkpoint_tp_no_comm", False) detached_inputs = detach_variable(tuple(inputs)) handle = None - if no_communication and is_using_sequence_parallel() and not is_using_isp(): - grad_output = args[0] - grad_output, handle = all_gather_raw( - grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM - ) + if no_communication: + gpc.recompute_forward_no_comm = True + if is_using_sequence_parallel(): + grad_output = args[0] + grad_output, handle = all_gather_raw( + grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM + ) if ctx.had_autocast_in_fwd: with torch.enable_grad(), internlm_accelerator.amp.autocast(): @@ -149,6 +148,9 @@ def backward(ctx, *args): with torch.enable_grad(): outputs = ctx.run_function(*detached_inputs) + if gpc.recompute_forward_no_comm: + gpc.recompute_forward_no_comm = False + if handle: handle.wait() args = list(args) From 8951edbdfc76c93ca9b289fc336dce5cb9d1818b Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 9 Jul 2024 16:49:22 +0800 Subject: [PATCH 4/7] change --- internlm/initialize/launch.py | 8 ++-- internlm/model/modeling_internlm.py | 6 +-- internlm/model/modeling_internlm2.py | 6 +-- internlm/model/modules/mlp.py | 5 ++- internlm/solver/activation_checkpoint.py | 55 ++++++++++++++---------- 5 files changed, 44 insertions(+), 36 deletions(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index 712e0006..5b499572 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -415,9 +415,11 @@ def args_sanity_check(): ), "only support interleaved pipeline scheduler with overlap" # when not use tp or sp, checkpoint_tp_no_comm should always be False - if (gpc.config.parallel["tensor"]["mode"] == "isp" or gpc.config.parallel["tensor"]["size"] <= 1) and getattr( - gpc.config.model, "checkpoint_tp_no_comm", False - ): + if ( + gpc.config.parallel["tensor"]["mode"] == "isp" + or gpc.config.parallel["tensor"]["size"] <= 1 + or gpc.config.model_type not in ["INTERNLM", "INTERNLM2_PUBLIC"] + ) and getattr(gpc.config.model, "checkpoint_tp_no_comm", False): gpc.config.model.checkpoint_tp_no_comm = False # monitoring default config diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py index 311ae2b3..3cf9c14a 100644 --- a/internlm/model/modeling_internlm.py +++ b/internlm/model/modeling_internlm.py @@ -213,12 +213,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - no_communication = gpc.recompute_forward_no_comm - - hidden_states = self.mlp(hidden_states, no_communication=no_communication) + hidden_states = self.mlp(hidden_states) # pad residual - if no_communication and is_using_sequence_parallel(): + if gpc.recompute_forward_no_comm and is_using_sequence_parallel(): residual = padding_residual(residual) return hidden_states + residual diff --git a/internlm/model/modeling_internlm2.py b/internlm/model/modeling_internlm2.py index d44837e0..303125c2 100644 --- 
a/internlm/model/modeling_internlm2.py +++ b/internlm/model/modeling_internlm2.py @@ -258,12 +258,10 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): if self.residual_in_fp32: residual = residual.to(torch.float32) - no_communication = gpc.recompute_forward_no_comm - - hidden_states = self.feed_forward(hidden_states, no_communication=no_communication) + hidden_states = self.feed_forward(hidden_states) # pad residual - if no_communication and is_using_sequence_parallel(): + if gpc.recompute_forward_no_comm and is_using_sequence_parallel(): residual = padding_residual(residual) return hidden_states + residual diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index f3feee42..ae02081a 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -6,6 +6,7 @@ import torch from torch import nn +from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import Silu from internlm.utils.logger import get_logger @@ -91,14 +92,14 @@ def __init__( self.w2 = new_linear("w2", hidden_features, out_features, bias, device=device, dtype=dtype) self.w3 = new_linear("w3", in_features, hidden_features, bias, device=device, dtype=dtype) - def forward(self, x, no_communication=False): + def forward(self, x): if not self.mlp_layer_fusion: w1_o = self.w1(x) w3_o = self.w3(x) else: fussed_out = self.fused_w1_w3(x) w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) - out = self.w2(Silu(w1_o, w3_o), no_communication=no_communication) + out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.recompute_forward_no_comm) return out diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index 33760b2a..c95b637a 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -2,6 +2,7 @@ # -*- encoding: utf-8 -*- import weakref +from contextlib import contextmanager import torch from torch.utils.checkpoint import check_backward_validity, detach_variable @@ -41,6 +42,30 @@ def copy_to_device(obj, device): return obj +@contextmanager +def recompute_forward_context(args, no_communication): + handle = None + try: + # Set True when entering the context + if no_communication: + gpc.recompute_forward_no_comm = True + if is_using_sequence_parallel(): + # overlap all_gather + grad_output = args[0] + grad_output, handle = all_gather_raw( + grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM + ) + yield + finally: + # Set False when exiting the context + gpc.recompute_forward_no_comm = False + + if handle: + handle.wait() + args = list(args) + args[0] = grad_output + + class CheckpointFunction(torch.autograd.Function): """ Checkpoint Function @@ -132,29 +157,13 @@ def backward(ctx, *args): detached_inputs = detach_variable(tuple(inputs)) - handle = None - if no_communication: - gpc.recompute_forward_no_comm = True - if is_using_sequence_parallel(): - grad_output = args[0] - grad_output, handle = all_gather_raw( - grad_output, process_group=gpc.get_group(ParallelMode.TENSOR), async_op=True, gather_dim=_GATHER_DIM - ) - - if ctx.had_autocast_in_fwd: - with torch.enable_grad(), internlm_accelerator.amp.autocast(): - outputs = ctx.run_function(*detached_inputs) - else: - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - if gpc.recompute_forward_no_comm: - gpc.recompute_forward_no_comm = False - - if handle: - 
handle.wait() - args = list(args) - args[0] = grad_output + with recompute_forward_context(args, no_communication): + if ctx.had_autocast_in_fwd: + with torch.enable_grad(), internlm_accelerator.amp.autocast(): + outputs = ctx.run_function(*detached_inputs) + else: + with torch.enable_grad(): + outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): outputs = (outputs,) From 4b41bb511665c35d0cde73604ddc5f590cf38ecc Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 9 Jul 2024 19:34:08 +0800 Subject: [PATCH 5/7] fix bug --- internlm/solver/activation_checkpoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index c95b637a..e31b3fb5 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -62,7 +62,6 @@ def recompute_forward_context(args, no_communication): if handle: handle.wait() - args = list(args) args[0] = grad_output @@ -157,6 +156,7 @@ def backward(ctx, *args): detached_inputs = detach_variable(tuple(inputs)) + args = list(args) with recompute_forward_context(args, no_communication): if ctx.had_autocast_in_fwd: with torch.enable_grad(), internlm_accelerator.amp.autocast(): From 2fe28913d299665c3d7adc3ca46358760c2a765d Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Thu, 11 Jul 2024 17:17:46 +0800 Subject: [PATCH 6/7] improve --- internlm/model/modules/linear.py | 13 ++++++++++--- internlm/model/modules/mlp.py | 3 +-- internlm/model/utils.py | 22 ++++++++-------------- internlm/solver/activation_checkpoint.py | 2 +- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 17e1b290..b1af42ad 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -349,9 +349,12 @@ def __init__( else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - def forward(self, input: torch.Tensor, no_communication=False) -> torch.Tensor: # pylint: disable=W0622 + self.last_block_layer = False + + def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." + no_communication = bool(gpc.recompute_forward_no_comm and self.last_block_layer) return fused_dense_func( input, self.weight, @@ -414,6 +417,7 @@ def __init__( multiple_of: int = 1, device: torch.device = None, dtype: torch.dtype = None, + layer_name: str = "default", ) -> None: if in_features % multiple_of: raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") @@ -430,6 +434,9 @@ def __init__( split_mode="row", ) + if layer_name == "w2": + self.last_block_layer = True + class ScaleColumnParallelLinear(ParallelLinearWithCommExt): """ @@ -471,7 +478,7 @@ def __init__( self.first_eval_flag = True self.tmp_weight = None - def forward(self, input, no_communication=False): # pylint: disable=W0622 + def forward(self, input): # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." 
@@ -502,7 +509,6 @@ def forward(self, input, no_communication=False): # pylint: disable=W0622 communicator=self._communicator, module=self, bias=self.bias, - no_communication=no_communication, ) @@ -603,6 +609,7 @@ def new_linear( multiple_of, device, dtype, + layer_name=name, ) else: err_msg = ( diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index ae02081a..897e1363 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -6,7 +6,6 @@ import torch from torch import nn -from internlm.core.context.parallel_context import global_context as gpc from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import Silu from internlm.utils.logger import get_logger @@ -99,7 +98,7 @@ def forward(self, x): else: fussed_out = self.fused_w1_w3(x) w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) - out = self.w2(Silu(w1_o, w3_o), no_communication=gpc.recompute_forward_no_comm) + out = self.w2(Silu(w1_o, w3_o)) return out diff --git a/internlm/model/utils.py b/internlm/model/utils.py index 019e9af6..ba92d457 100644 --- a/internlm/model/utils.py +++ b/internlm/model/utils.py @@ -4,7 +4,6 @@ from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc -from internlm.core.parallel.comm.tensor import _GATHER_DIM from internlm.model.modules.mha import MHA @@ -60,21 +59,16 @@ def convert_attn_args_to_kwargs(args, kwargs) -> Dict[str, Any]: def padding_residual(residual): requires_grad = residual.requires_grad - pad_before = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] - pad_after = ( - gpc.get_world_size(ParallelMode.TENSOR) - gpc.get_local_rank(ParallelMode.TENSOR) - 1 - ) * residual.shape[_GATHER_DIM] - - pad_before_tensor = torch.zeros( - (*residual.shape[:_GATHER_DIM], pad_before, *residual.shape[_GATHER_DIM + 1 :]), - dtype=residual.dtype, - device=residual.device, - ) - pad_after_tensor = torch.zeros( - (*residual.shape[:_GATHER_DIM], pad_after, *residual.shape[_GATHER_DIM + 1 :]), + _GATHER_DIM = 1 + total_size = gpc.get_world_size(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + zero_padding_tensor = torch.zeros( + (*residual.shape[:_GATHER_DIM], total_size, *residual.shape[_GATHER_DIM + 1 :]), dtype=residual.dtype, device=residual.device, ) - residual = torch.cat([pad_before_tensor, residual, pad_after_tensor], dim=1).requires_grad_(requires_grad) + start_idx = gpc.get_local_rank(ParallelMode.TENSOR) * residual.shape[_GATHER_DIM] + end_idx = start_idx + residual.shape[_GATHER_DIM] + zero_padding_tensor[:, start_idx:end_idx, :] = residual + residual = zero_padding_tensor.requires_grad_(requires_grad) return residual diff --git a/internlm/solver/activation_checkpoint.py b/internlm/solver/activation_checkpoint.py index e31b3fb5..b46e91fb 100644 --- a/internlm/solver/activation_checkpoint.py +++ b/internlm/solver/activation_checkpoint.py @@ -151,7 +151,7 @@ def backward(ctx, *args): for i, idx in enumerate(tensor_indices): inputs[idx] = tensors[i] - # no_communication + # when checkpoint_tp_no_comm==True, we use TP recomputation communication optimization no_communication = getattr(gpc.config.model, "checkpoint_tp_no_comm", False) detached_inputs = detach_variable(tuple(inputs)) From 537d4705bf6ee093d06bb3c9650816b3117f7e33 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 12 Jul 2024 12:14:38 +0800 Subject: [PATCH 7/7] fix comment --- internlm/core/parallel/comm/tensor.py | 68 ++++++++++++++++++++------- 
internlm/model/modules/linear.py | 24 +++------- internlm/train/pipeline.py | 21 +++++++++ 3 files changed, 77 insertions(+), 36 deletions(-) diff --git a/internlm/core/parallel/comm/tensor.py b/internlm/core/parallel/comm/tensor.py index ac883063..b4203875 100644 --- a/internlm/core/parallel/comm/tensor.py +++ b/internlm/core/parallel/comm/tensor.py @@ -66,7 +66,9 @@ def input_hook( @abstractmethod def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False + self, + grad_output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ communication for grad_output when backward. @@ -82,7 +84,9 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T @abstractmethod def output_hook( - self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False + self, + output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ communication for output when forward. @@ -95,13 +99,14 @@ class TensorParallelCommunicator(TPCommunicator): tensor parallel communicator for linear """ - def __init__(self, process_group: dist.ProcessGroup, role: LinearRole) -> None: + def __init__(self, process_group: dist.ProcessGroup, role: LinearRole, last_block_layer=False) -> None: assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" self._process_group = process_group self._role = role self._save_total_input = False + self.last_block_layer = last_block_layer def save_total_input(self) -> bool: return self._save_total_input @@ -120,8 +125,7 @@ def input_hook( def grad_output_hook( self, grad_output: torch.Tensor, - async_op: bool = False, - no_communication: bool = False, # pylint: disable=W0613 + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ tensor parallel should do nothing for grad_output. @@ -138,12 +142,18 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T return all_reduce_raw(grad_input, process_group=self._process_group, async_op=async_op) def output_hook( - self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False + self, + output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all reduce output only for row parallel linear when forward. 
""" - if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if ( + (self.last_block_layer and gpc.recompute_forward_no_comm) + or dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.COLUMN + ): return output, DUMMY_HANDLE_CONST return all_reduce_raw(output, process_group=self._process_group, async_op=async_op) @@ -155,7 +165,11 @@ class SequenceParallelCommunicator(TPCommunicator): """ def __init__( - self, process_group: dist.ProcessGroup, role: LinearRole, save_total_input_as_activation: bool = False + self, + process_group: dist.ProcessGroup, + role: LinearRole, + save_total_input_as_activation: bool = False, + last_block_layer=False, ) -> None: assert role in (LinearRole.COLUMN, LinearRole.ROW), f"Unknown linear role: {role}" @@ -163,6 +177,8 @@ def __init__( self._role = role self._save_total_input = save_total_input_as_activation + self.last_block_layer = last_block_layer + self.no_communication = False def save_total_input(self) -> bool: return self._save_total_input @@ -189,12 +205,19 @@ def input_hook( return all_gather_raw(_input, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) def grad_output_hook( - self, grad_output: torch.Tensor, async_op: bool = False, no_communication: bool = False + self, + grad_output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather grad_output only for row parallel linear when backward. """ - if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + if ( + (self.last_block_layer and self.no_communication) + or dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.COLUMN + ): + self.no_communication = False return grad_output, DUMMY_HANDLE_CONST return all_gather_raw(grad_output, process_group=self._process_group, async_op=async_op, gather_dim=_GATHER_DIM) @@ -211,12 +234,19 @@ def grad_input_hook(self, grad_input: torch.Tensor, async_op: bool = False) -> T ) def output_hook( - self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False + self, + output: torch.Tensor, + async_op: bool = False, ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ reduce scatter output only for row parallel linear when forward. """ - if no_communication or dist.get_world_size(self._process_group) <= 1 or self._role == LinearRole.COLUMN: + self.no_communication = gpc.recompute_forward_no_comm + if ( + (self.last_block_layer and self.no_communication) + or dist.get_world_size(self._process_group) <= 1 + or self._role == LinearRole.COLUMN + ): return output, DUMMY_HANDLE_CONST return reduce_scatter_raw(output, process_group=self._process_group, async_op=async_op, reduce_dim=_REDUCE_DIM) @@ -236,8 +266,7 @@ def __init__(self, parallel_mode: ParallelMode, retain_out_sharded: bool = True) def grad_output_hook( self, grad_output: torch.Tensor, - async_op: bool = False, - no_communication: bool = False, # pylint: disable=W0613 + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. 
@@ -248,7 +277,9 @@ def grad_output_hook( return _split(grad_output, parallel_mode=self._parallel_mode, dim=-1), DUMMY_HANDLE_CONST def output_hook( - self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613 + self, + output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. @@ -280,8 +311,7 @@ def __init__( def grad_output_hook( self, grad_output: torch.Tensor, - async_op: bool = False, - no_communication: bool = False, # pylint: disable=W0613 + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ split grad_output if retain_out_sharded is False. @@ -293,7 +323,9 @@ def grad_output_hook( # rewrite ouput communication hook def output_hook( - self, output: torch.Tensor, async_op: bool = False, no_communication: bool = False # pylint: disable=W0613 + self, + output: torch.Tensor, + async_op: bool = False, # pylint: disable=W0613 ) -> Tuple[torch.Tensor, AsyncCommHandle]: """ all gather output for head layer if retain_out_sharded is False. diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index b1af42ad..77c482c9 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -45,12 +45,10 @@ def forward( bias: Optional[torch.Tensor], communicator: TPCommunicator, return_residual=False, - no_communication=False, ): ctx.compute_weight_gradient = weight.requires_grad ctx.return_residual = return_residual ctx.communicator = communicator - ctx.no_communication = no_communication if torch.is_autocast_enabled(): x = x.to(dtype=torch.get_autocast_gpu_dtype()) @@ -79,7 +77,7 @@ def forward( # parallel strategy-specific communication callback 2. # see more details in the communicator for different parallel strategies. - output, _ = communicator.output_hook(output, async_op=False, no_communication=no_communication) + output, _ = communicator.output_hook(output, async_op=False) saved_x = None if ctx.compute_weight_gradient is False else total_x if communicator.save_total_input() else x ctx.save_for_backward(saved_x, weight) @@ -93,9 +91,7 @@ def backward(ctx, grad_output, *args): # parallel strategy-specific communication callback 3. # see more details in the communicator for different parallel strategies. - grad_output, _ = communicator.grad_output_hook( - grad_output, no_communication=ctx.no_communication, async_op=False - ) + grad_output, _ = communicator.grad_output_hook(grad_output, async_op=False) grad_output = grad_output.contiguous() if ctx.return_residual: @@ -268,7 +264,6 @@ def fused_dense_func( module: Optional[nn.Module] = None, bias: Optional[torch.Tensor] = None, return_residual: bool = False, - no_communication=False, ): if communicator.communication_mode() == "wp": return WPFusedDenseFunc.apply( @@ -286,7 +281,6 @@ def fused_dense_func( bias, communicator, return_residual, - no_communication, ) @@ -349,19 +343,15 @@ def __init__( else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - self.last_block_layer = False - def forward(self, input: torch.Tensor) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." 
- no_communication = bool(gpc.recompute_forward_no_comm and self.last_block_layer) return fused_dense_func( input, self.weight, communicator=self._communicator, module=self, bias=self.bias, - no_communication=no_communication, ) @@ -417,7 +407,6 @@ def __init__( multiple_of: int = 1, device: torch.device = None, dtype: torch.dtype = None, - layer_name: str = "default", ) -> None: if in_features % multiple_of: raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}") @@ -434,9 +423,6 @@ def __init__( split_mode="row", ) - if layer_name == "w2": - self.last_block_layer = True - class ScaleColumnParallelLinear(ParallelLinearWithCommExt): """ @@ -602,15 +588,17 @@ def new_linear( dtype, ) elif split_mode == "row": - return RowParallelLinear( + linear = RowParallelLinear( in_features, out_features, bias, multiple_of, device, dtype, - layer_name=name, ) + if name == "w2": + setattr(linear, "last_block_layer", True) + return linear else: err_msg = ( f"Parallel strategies for linear is unsupported, which is named as {name}.\n" diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index bdafdedb..7b13d0c8 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -277,6 +277,15 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): ) _head_communicator = HeadTensorParallelCommunicator(ParallelMode.TENSOR, _retain_out_sharded) _embedding_communicator = EmbbedingTensorParallelCommunicator(ParallelMode.TENSOR) + + # for tp recompute communication optimization, sign last block layer + for row_parallel_linear in _submodule_filter(model, RowParallelLinear): + if getattr(row_parallel_linear, "last_block_layer", False): + row_parallel_linear.register_communicator( + TensorParallelCommunicator( + process_group=gpc.get_group(ParallelMode.TENSOR), role=LinearRole.ROW, last_block_layer=True + ) + ) # sequence parallel if gpc.config.parallel.tensor.mode in ("msp", "fsp"): save_total_input_as_activation = gpc.config.parallel.tensor.mode == "msp" @@ -296,6 +305,18 @@ def initialize_parallel_communicator(model: Union[nn.Module, nn.ModuleList]): ) ) + # for tp recompute communication optimization, sign last block layer + for row_parallel_linear in _submodule_filter(model, RowParallelLinear): + if getattr(row_parallel_linear, "last_block_layer", False): + row_parallel_linear.register_communicator( + SequenceParallelCommunicator( + gpc.get_group(ParallelMode.TENSOR), + role=LinearRole.ROW, + save_total_input_as_activation=save_total_input_as_activation, + last_block_layer=True, + ) + ) + _head_communicator = HeadSequenceParallelCommunicator( ParallelMode.TENSOR, _retain_out_sharded, save_total_input_as_activation )
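
A minimal, self-contained sketch of the control flow this series wires up, under the following reading of the hunks: new_linear() marks the block's final row-parallel linear ("w2") with last_block_layer, initialize_parallel_communicator() registers a communicator built with last_block_layer=True for exactly those modules, and that communicator skips the forward output reduction while gpc.recompute_forward_no_comm is set during checkpoint recomputation. The names below (ToyCommunicator, ToyRowParallelLinear) are illustrative stand-ins rather than the real InternLM classes, and the matmul and collectives are replaced with placeholders so the snippet runs standalone.

class ToyCommunicator:
    """Stand-in for the row-parallel communicator: the forward output
    reduction is skipped only for the marked last-block-layer linear, and
    only while the forward pass is a checkpoint recomputation."""

    def __init__(self, last_block_layer=False):
        self.last_block_layer = last_block_layer

    def output_hook(self, output, recompute_forward_no_comm):
        if self.last_block_layer and recompute_forward_no_comm:
            # The reduced activation already exists from the original forward,
            # so the recomputed forward can keep the local partial result.
            return output, "communication skipped"
        return output, "all_reduce / reduce_scatter"


class ToyRowParallelLinear:
    def __init__(self, name, communicator):
        # Mirrors new_linear(): only the MLP's "w2", the last row-parallel
        # linear in a block, carries the last_block_layer mark.
        self.last_block_layer = name == "w2"
        self.communicator = communicator

    def forward(self, x, recompute_forward_no_comm):
        partial_out = x  # placeholder for the matmul on the local weight shard
        return self.communicator.output_hook(partial_out, recompute_forward_no_comm)


if __name__ == "__main__":
    w2 = ToyRowParallelLinear("w2", ToyCommunicator(last_block_layer=True))
    wo = ToyRowParallelLinear("wo", ToyCommunicator(last_block_layer=False))

    print(w2.forward(1.0, recompute_forward_no_comm=False))  # normal forward: reduce
    print(w2.forward(1.0, recompute_forward_no_comm=True))   # recomputed forward: skip
    print(wo.forward(1.0, recompute_forward_no_comm=True))   # other layers: always reduce

For msp/fsp, skipping the reduce-scatter leaves the recomputed w2 output at full sequence length, which is presumably what the reworked padding_residual() earlier in the series accounts for: it drops the residual's local shard into a zero tensor of the full gathered size along dim 1 so the two tensors can still be added elementwise.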