diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index f72a72df0b1b..67e6e92d1d36 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -95,19 +95,20 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
                 policy=policy,
                 target_key=attn_cls,
             )
-            if self.pipeline_stage_manager is None:
-                self.append_or_create_method_replacement(
-                    description={
-                        "forward": get_llama_flash_attention_model_forward(
-                            self.shard_config,
-                            sp_mode=sp_mode,
-                            sp_size=sp_size,
-                            sp_group=sp_group,
-                        ),
-                    },
-                    policy=policy,
-                    target_key=LlamaModel,
-                )
+
+        if self.pipeline_stage_manager is None:
+            self.append_or_create_method_replacement(
+                description={
+                    "forward": get_llama_flash_attention_model_forward(
+                        self.shard_config,
+                        sp_mode=sp_mode,
+                        sp_size=sp_size,
+                        sp_group=sp_group,
+                    ),
+                },
+                policy=policy,
+                target_key=LlamaModel,
+            )
 
         if self.shard_config.enable_tensor_parallelism:
             assert (