From d631586ae3063c88653845e10283d3cea06657af Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Fri, 22 Nov 2024 15:50:29 +0800 Subject: [PATCH 1/8] split text in MM-DiT --- xfuser/core/distributed/runtime_state.py | 5 ++ .../long_ctx_attention/hybrid/attn_layer.py | 36 +++++---- .../ring/ring_flash_attn.py | 30 +++++--- .../layers/attention_processor.py | 74 ++++++++++++------- .../model_executor/pipelines/base_pipeline.py | 37 ++++------ .../pipelines/pipeline_cogvideox.py | 43 ++++------- .../model_executor/pipelines/pipeline_flux.py | 21 +++++- .../pipelines/pipeline_stable_diffusion_3.py | 22 +++++- 8 files changed, 159 insertions(+), 109 deletions(-) diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py index d889574..cfb4c79 100644 --- a/xfuser/core/distributed/runtime_state.py +++ b/xfuser/core/distributed/runtime_state.py @@ -93,6 +93,7 @@ class DiTRuntimeState(RuntimeState): pp_patches_token_start_end_idx_global: Optional[List[List[int]]] pp_patches_token_num: Optional[List[int]] max_condition_sequence_length: int + split_text_embed_in_sp: bool def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig): super().__init__(config) @@ -128,11 +129,13 @@ def set_input_parameters( num_inference_steps: Optional[int] = None, seed: Optional[int] = None, max_condition_sequence_length: Optional[int] = None, + split_text_embed_in_sp: bool = True, ): self.input_config.num_inference_steps = ( num_inference_steps or self.input_config.num_inference_steps ) self.max_condition_sequence_length = max_condition_sequence_length + self.split_text_embed_in_sp = split_text_embed_in_sp if self.runtime_config.warmup_steps > self.input_config.num_inference_steps: self.runtime_config.warmup_steps = self.input_config.num_inference_steps if seed is not None and seed != self.input_config.seed: @@ -156,12 +159,14 @@ def set_video_input_parameters( batch_size: Optional[int] = None, num_inference_steps: Optional[int] = None, seed: Optional[int] = None, + split_text_embed_in_sp: bool = True, ): self.input_config.num_inference_steps = ( num_inference_steps or self.input_config.num_inference_steps ) if self.runtime_config.warmup_steps > self.input_config.num_inference_steps: self.runtime_config.warmup_steps = self.input_config.num_inference_steps + self.split_text_embed_in_sp = split_text_embed_in_sp if seed is not None and seed != self.input_config.seed: self.input_config.seed = seed set_random_seed(seed) diff --git a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py index 8a23d62..d4fadf7 100644 --- a/xfuser/core/long_ctx_attention/hybrid/attn_layer.py +++ b/xfuser/core/long_ctx_attention/hybrid/attn_layer.py @@ -85,23 +85,31 @@ def forward( Returns: * output (Tensor): context output """ - supported_joint_strategy = ["none", "front", "rear"] - if joint_strategy not in supported_joint_strategy: - raise ValueError( - f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}" - ) - elif joint_strategy != "none" and joint_tensor_query is None: + is_joint = False + if (joint_tensor_query is not None and + joint_tensor_key is not None and + joint_tensor_value is not None): + supported_joint_strategy = ["front", "rear"] + if joint_strategy not in supported_joint_strategy: + raise ValueError( + f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}" + ) + elif joint_strategy == "rear": + query = torch.cat([query, joint_tensor_query], dim=1) + is_joint = True + else: + query = torch.cat([joint_tensor_query, query], dim=1) + is_joint = True + elif (joint_tensor_query is None and + joint_tensor_key is None and + joint_tensor_value is None): + pass + else: raise ValueError( - f"joint_tensor_query must not be None when joint_strategy is not None" + f"joint_tensor_query, joint_tensor_key, and joint_tensor_value should be None or not None simultaneously." ) - elif joint_strategy == "rear": - query = torch.cat([query, joint_tensor_query], dim=1) - elif joint_strategy == "front": - query = torch.cat([joint_tensor_query, query], dim=1) - else: - pass - if joint_strategy != "none": + if is_joint: ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg) ulysses_rank = torch.distributed.get_rank(self.ulysses_pg) attn_heads_per_ulysses_rank = ( diff --git a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py index bdae65d..7a729d1 100644 --- a/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py +++ b/xfuser/core/long_ctx_attention/ring/ring_flash_attn.py @@ -22,16 +22,22 @@ def xdit_ring_flash_attn_forward( joint_tensor_value=None, joint_strategy="none", ): - supported_joint_strategy = ["none", "front", "rear"] - if joint_strategy not in supported_joint_strategy: - raise ValueError( - f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}" - ) - elif joint_strategy != "none" and ( - joint_tensor_key is None or joint_tensor_value is None - ): + is_joint = False + if (joint_tensor_key is not None and + joint_tensor_value is not None): + supported_joint_strategy = ["front", "rear"] + if joint_strategy not in supported_joint_strategy: + raise ValueError( + f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}" + ) + else: + is_joint = True + elif (joint_tensor_key is None and + joint_tensor_value is None): + pass + else: raise ValueError( - f"joint_tensor_key & joint_tensor_value must not be None when joint_strategy is not None" + f"joint_tensor_key and joint_tensor_value should be None or not None simultaneously." ) comm = RingComm(process_group) @@ -57,19 +63,19 @@ def xdit_ring_flash_attn_forward( next_v: torch.Tensor = comm.send_recv(v) comm.commit() - if joint_strategy == "rear": + if is_joint and joint_strategy == "rear": if step + 1 == comm.world_size: key = torch.cat([k, joint_tensor_key], dim=1) value = torch.cat([v, joint_tensor_value], dim=1) else: key, value = k, v - elif joint_strategy == "front": + elif is_joint and joint_strategy == "front": if step == 0: key = torch.cat([joint_tensor_key, k], dim=1) value = torch.cat([joint_tensor_value, v], dim=1) else: key, value = k, v - elif joint_strategy == "none": + else: key, value = k, v if not causal or step <= comm.rank: diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index 23c4e55..d37e284 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -453,19 +453,28 @@ def __call__( #! ---------------------------------------- ATTENTION ---------------------------------------- if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: + if get_runtime_state().split_text_embed_in_sp: + query = torch.cat([query, encoder_hidden_states_query_proj], dim=1) + key = torch.cat([key, encoder_hidden_states_key_proj], dim=1) + value = torch.cat([value, encoder_hidden_states_value_proj], dim=1) + + encoder_hidden_states_query_proj = None + encoder_hidden_states_key_proj = None + encoder_hidden_states_value_proj = None + else: + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ) query = query.view(batch_size, -1, attn.heads, head_dim) key = key.view(batch_size, -1, attn.heads, head_dim) value = value.view(batch_size, -1, attn.heads, head_dim) - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ) hidden_states = self.hybrid_seq_parallel_attn( attn, query, @@ -682,15 +691,20 @@ def __call__( query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) - encoder_hidden_states_query_proj, query = query.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - encoder_hidden_states_key_proj, key = key.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) - encoder_hidden_states_value_proj, value = value.split( - [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 - ) + if get_runtime_state().split_text_embed_in_sp: + encoder_hidden_states_query_proj = None + encoder_hidden_states_key_proj = None + encoder_hidden_states_value_proj = None + else: + encoder_hidden_states_query_proj, query = query.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_key_proj, key = key.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) + encoder_hidden_states_value_proj, value = value.split( + [num_encoder_hidden_states_tokens, num_query_tokens], dim=1 + ) hidden_states = self.hybrid_seq_parallel_attn( attn, query, @@ -1043,19 +1057,25 @@ def __call__( #! ---------------------------------------- ATTENTION ---------------------------------------- if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: - encoder_query = query[:, :, :text_seq_length, :] - query = query[:, :, text_seq_length:, :] - encoder_key = key[:, :, :text_seq_length, :] - key = key[:, :, text_seq_length:, :] - encoder_value = value[:, :, :text_seq_length, :] - value = value[:, :, text_seq_length:, :] + if get_runtime_state().split_text_embed_in_sp: + encoder_query = None + encoder_key = None + encoder_value = None + else: + encoder_query = query[:, :, :text_seq_length, :] + query = query[:, :, text_seq_length:, :] + encoder_key = key[:, :, :text_seq_length, :] + key = key[:, :, text_seq_length:, :] + encoder_value = value[:, :, :text_seq_length, :] + value = value[:, :, text_seq_length:, :] + + encoder_query = encoder_query.transpose(1, 2) + encoder_key = encoder_key.transpose(1, 2) + encoder_value = encoder_value.transpose(1, 2) query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) - encoder_query = encoder_query.transpose(1, 2) - encoder_key = encoder_key.transpose(1, 2) - encoder_value = encoder_value.transpose(1, 2) hidden_states = self.hybrid_seq_parallel_attn( attn, diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py index 1028e7c..5ca54f3 100644 --- a/xfuser/model_executor/pipelines/base_pipeline.py +++ b/xfuser/model_executor/pipelines/base_pipeline.py @@ -28,7 +28,6 @@ get_runtime_state, initialize_runtime_state, is_dp_last_group, - get_sequence_parallel_rank, ) from xfuser.core.fast_attention import ( get_fast_attn_enable, @@ -408,39 +407,33 @@ def _init_async_pipeline( def _process_cfg_split_batch( self, - concat_group_0_negative: torch.Tensor, - concat_group_0: torch.Tensor, - concat_group_1_negative: torch.Tensor, - concat_group_1: torch.Tensor, + negative_embeds: torch.Tensor, + embeds: torch.Tensor, + negative_embdes_mask: torch.Tensor = None, + embeds_mask: torch.Tensor = None, ): if get_classifier_free_guidance_world_size() == 1: - concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0) - concat_group_1 = torch.cat([concat_group_1_negative, concat_group_1], dim=0) + embeds = torch.cat([negative_embeds, embeds], dim=0) elif get_classifier_free_guidance_rank() == 0: - concat_group_0 = concat_group_0_negative - concat_group_1 = concat_group_1_negative + embeds = negative_embeds elif get_classifier_free_guidance_rank() == 1: - concat_group_0 = concat_group_0 - concat_group_1 = concat_group_1 + embeds = embeds else: raise ValueError("Invalid classifier free guidance rank") - return concat_group_0, concat_group_1 - def _process_cfg_split_batch_latte( - self, - concat_group_0: torch.Tensor, - concat_group_0_negative: torch.Tensor, - ): + if negative_embdes_mask is None: + return embeds + if get_classifier_free_guidance_world_size() == 1: - concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0) + embeds_mask = torch.cat([negative_embdes_mask, embeds_mask], dim=0) elif get_classifier_free_guidance_rank() == 0: - concat_group_0 = concat_group_0_negative + embeds_mask = negative_embdes_mask elif get_classifier_free_guidance_rank() == 1: - concat_group_0 = concat_group_0 + embeds_mask = embeds_mask else: raise ValueError("Invalid classifier free guidance rank") - return concat_group_0 - + return embeds, embeds_mask + def is_dp_last_group(self): """Return True if in the last data parallel group, False otherwise. Also include parallel vae situation. diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py index bb4795c..98631d5 100644 --- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py +++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py @@ -16,21 +16,13 @@ from xfuser.config import EngineConfig from xfuser.core.distributed import ( - get_data_parallel_world_size, - get_sequence_parallel_world_size, get_pipeline_parallel_world_size, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, get_classifier_free_guidance_world_size, - get_classifier_free_guidance_rank, - get_pipeline_parallel_rank, - get_pp_group, - get_world_group, get_cfg_group, get_sp_group, get_runtime_state, - initialize_runtime_state, - get_data_parallel_rank, - is_pipeline_first_stage, - is_pipeline_last_stage, is_dp_last_group, ) @@ -206,6 +198,7 @@ def __call__( num_frames=num_frames, batch_size=batch_size, num_inference_steps=num_inference_steps, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, ) # 3. Encode input prompt @@ -219,8 +212,8 @@ def __call__( max_sequence_length=max_sequence_length, device=device, ) - prompt_embeds = self._process_cfg_split_batch_latte( - prompt_embeds, negative_prompt_embeds + prompt_embeds = self._process_cfg_split_batch( + negative_prompt_embeds, prompt_embeds ) # 4. Prepare timesteps @@ -266,8 +259,8 @@ def __call__( num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) p_t = self.transformer.config.patch_size_t or 1 - latents, image_rotary_emb = self._init_sync_pipeline( - latents, image_rotary_emb, + latents, prompt_embeds, image_rotary_emb = self._init_sync_pipeline( + latents, prompt_embeds, image_rotary_emb, (latents.size(1) + p_t - 1) // p_t ) with self.progress_bar(total=num_inference_steps) as progress_bar: @@ -301,18 +294,7 @@ def __call__( # perform guidance if use_dynamic_cfg: self._guidance_scale = 1 + guidance_scale * ( - ( - 1 - - math.cos( - math.pi - * ( - (num_inference_steps - t.item()) - / num_inference_steps - ) - ** 5.0 - ) - ) - / 2 + (1 - math.cos(math.pi * ((num_inference_steps - t.item()) / num_inference_steps) ** 5.0)) / 2 ) if do_classifier_free_guidance: if get_classifier_free_guidance_world_size() == 1: @@ -381,10 +363,17 @@ def __call__( def _init_sync_pipeline( self, latents: torch.Tensor, + prompt_embeds: torch.Tensor, image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, latents_frames: Optional[int] = None, ): latents = super()._init_video_sync_pipeline(latents) + + if get_runtime_state().split_text_embed_in_sp: + assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \ + f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" + prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + if image_rotary_emb is not None: assert latents_frames is not None d = image_rotary_emb[0].shape[-1] @@ -412,7 +401,7 @@ def _init_sync_pipeline( dim=0, ), ) - return latents, image_rotary_emb + return latents, prompt_embeds, image_rotary_emb @property def interrupt(self): diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py index a1c8048..2ac3239 100644 --- a/xfuser/model_executor/pipelines/pipeline_flux.py +++ b/xfuser/model_executor/pipelines/pipeline_flux.py @@ -28,11 +28,11 @@ get_runtime_state, get_pp_group, get_sequence_parallel_world_size, + get_sequence_parallel_rank, get_sp_group, is_pipeline_first_stage, is_pipeline_last_stage, is_dp_last_group, - get_world_group, ) from .base_pipeline import xFuserPipelineBaseWrapper from .register import xFuserPipelineWrapperRegister @@ -226,6 +226,7 @@ def __call__( batch_size=batch_size, num_inference_steps=num_inference_steps, max_condition_sequence_length=max_sequence_length, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, ) #! ---------------------------------------- ADDED ABOVE ---------------------------------------- @@ -381,7 +382,8 @@ def vae_decode(latents): return None def _init_sync_pipeline( - self, latents: torch.Tensor, latent_image_ids: torch.Tensor + self, latents: torch.Tensor, latent_image_ids: torch.Tensor, + prompt_embeds: torch.Tensor, text_ids: torch.Tensor ): get_runtime_state().set_patched_mode(patch_mode=False) @@ -395,7 +397,18 @@ def _init_sync_pipeline( for start_idx, end_idx in get_runtime_state().pp_patches_token_start_end_idx_global ] latent_image_ids = torch.cat(latent_image_ids_list, dim=-2) - return latents, latent_image_ids + + if get_runtime_state().split_text_embed_in_sp: + assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \ + f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" + prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + + if get_runtime_state().split_text_embed_in_sp: + assert text_ids.shape[-2] % get_sequence_parallel_world_size() == 0, \ + f"the length of text sequence {text_ids.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" + text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + + return latents, latent_image_ids, prompt_embeds, text_ids # synchronized compute the whole feature map in each pp stage def _sync_pipeline( @@ -413,7 +426,7 @@ def _sync_pipeline( callback_on_step_end_tensor_inputs: List[str] = ["latents"], sync_only: bool = False, ): - latents, latent_image_ids = self._init_sync_pipeline(latents, latent_image_ids) + latents, latent_image_ids, prompt_embeds, text_ids = self._init_sync_pipeline(latents, latent_image_ids, prompt_embeds, text_ids) for i, t in enumerate(timesteps): if self.interrupt: continue diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py index a706acf..3b04d45 100644 --- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py +++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py @@ -33,12 +33,11 @@ get_runtime_state, get_cfg_group, get_classifier_free_guidance_world_size, - get_pipeline_parallel_rank, get_pp_group, get_sequence_parallel_world_size, + get_sequence_parallel_rank, get_sp_group, is_dp_last_group, - get_world_group, ) from .base_pipeline import xFuserPipelineBaseWrapper from .register import xFuserPipelineWrapperRegister @@ -273,6 +272,7 @@ def __call__( width=width, batch_size=batch_size, num_inference_steps=num_inference_steps, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, ) #! ---------------------------------------- ADDED ABOVE ---------------------------------------- @@ -412,6 +412,22 @@ def vae_decode(latents): else: return None + def _init_sync_pipeline(self, latents: torch.Tensor, prompt_embeds: torch.Tensor): + get_runtime_state().set_patched_mode(patch_mode=False) + + latents_list = [ + latents[:, :, start_idx:end_idx, :] + for start_idx, end_idx in get_runtime_state().pp_patches_start_end_idx_global + ] + latents = torch.cat(latents_list, dim=-2) + + if get_runtime_state().split_text_embed_in_sp: + assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \ + f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" + prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + + return latents, prompt_embeds + # synchronized compute the whole feature map in each pp stage def _sync_pipeline( self, @@ -425,7 +441,7 @@ def _sync_pipeline( callback_on_step_end_tensor_inputs: List[str] = ["latents"], sync_only: bool = False, ): - latents = self._init_sync_pipeline(latents) + latents, prompt_embeds = self._init_sync_pipeline(latents, prompt_embeds) for i, t in enumerate(timesteps): if self.interrupt: continue From d48b78bcc3e73af3b70afe3d2ab5346567e22d16 Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Fri, 22 Nov 2024 15:54:04 +0800 Subject: [PATCH 2/8] check the use_parallel_vae flag for CogVideo --- examples/cogvideox_example.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/cogvideox_example.py b/examples/cogvideox_example.py index c488527..fe1a051 100644 --- a/examples/cogvideox_example.py +++ b/examples/cogvideox_example.py @@ -22,6 +22,8 @@ def main(): engine_config, input_config = engine_args.create_config() local_rank = get_world_group().local_rank + + assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for CogVideo" pipe = xFuserCogVideoXPipeline.from_pretrained( pretrained_model_name_or_path=engine_config.model_config.model, From d35e495d8f3fa8876b0305675b83505c0890fa1d Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Wed, 27 Nov 2024 16:44:25 +0800 Subject: [PATCH 3/8] fix dimensions in all_gather --- xfuser/core/distributed/group_coordinator.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xfuser/core/distributed/group_coordinator.py b/xfuser/core/distributed/group_coordinator.py index 4012e02..0402756 100644 --- a/xfuser/core/distributed/group_coordinator.py +++ b/xfuser/core/distributed/group_coordinator.py @@ -223,15 +223,18 @@ def all_gather( # Convert negative dim to positive. dim += input_.dim() # Allocate output tensor. - input_size = input_.size() + input_size = list(input_.size()) + input_size[0] *= world_size output_tensor = torch.empty( - (world_size,) + input_size, dtype=input_.dtype, device=input_.device + input_size, dtype=input_.dtype, device=input_.device ) # All-gather. torch.distributed.all_gather_into_tensor( output_tensor, input_, group=self.device_group ) if dim != 0: + input_size[0] //= world_size + output_tensor = output_tensor.reshape([world_size, ] + input_size) output_tensor = output_tensor.movedim(0, dim) if separate_tensors: From faafcd19fd1fcde491a579c09b0c90933422dade Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Wed, 27 Nov 2024 17:04:54 +0800 Subject: [PATCH 4/8] optimizations on H100 --- examples/flux_usp_example.py | 168 ++++++++++++++++++ .../layers/attention_processor_usp.py | 95 ++++++++++ xfuser/model_executor/layers/usp.py | 98 ++++++++++ 3 files changed, 361 insertions(+) create mode 100644 examples/flux_usp_example.py create mode 100644 xfuser/model_executor/layers/attention_processor_usp.py create mode 100644 xfuser/model_executor/layers/usp.py diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py new file mode 100644 index 0000000..9742d1e --- /dev/null +++ b/examples/flux_usp_example.py @@ -0,0 +1,168 @@ +import functools +from typing import List, Optional, Tuple, Union + +import logging +import time +import torch +import torch.distributed +from diffusers import DiffusionPipeline, FluxPipeline + +from xfuser import xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_data_parallel_world_size, + get_data_parallel_rank, + get_runtime_state, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + is_dp_last_group, + initialize_runtime_state, + get_pipeline_parallel_world_size, +) + +from xfuser.model_executor.layers.attention_processor_usp import xFuserFluxAttnProcessor2_0USP + +def parallelize_transformer(pipe: DiffusionPipeline): + transformer = pipe.transformer + original_forward = transformer.forward + + @functools.wraps(transformer.__class__.forward) + def new_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + *args, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + controlnet_block_samples: Optional[List[torch.Tensor]] = None, + controlnet_single_block_samples: Optional[List[torch.Tensor]] = None, + **kwargs, + ): + if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]: + timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] + encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + + for block in transformer.transformer_blocks + transformer.single_transformer_blocks: + block.attn.processor = xFuserFluxAttnProcessor2_0USP() + + output = original_forward( + hidden_states, + encoder_hidden_states, + *args, + timestep=timestep, + img_ids=img_ids, + txt_ids=txt_ids, + **kwargs, + ) + + return_dict = not isinstance(output, tuple) + sample = output[0] + sample = get_sp_group().all_gather(sample, dim=-2) + sample = get_cfg_group().all_gather(sample, dim=0) + if return_dict: + return output.__class__(sample, *output[1:]) + return (sample, *output[1:]) + + new_forward = new_forward.__get__(transformer) + transformer.forward = new_forward + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + engine_config, input_config = engine_args.create_config() + engine_config.runtime_config.dtype = torch.bfloat16 + local_rank = get_world_group().local_rank + + pipe = FluxPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + torch_dtype=torch.bfloat16, + ) + + if args.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} sequential CPU offload enabled") + else: + pipe = pipe.to(f"cuda:{local_rank}") + + parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + initialize_runtime_state(pipe, engine_config) + get_runtime_state().set_input_parameters( + height=input_config.height, + width=input_config.width, + batch_size=1, + num_inference_steps=input_config.num_inference_steps, + max_condition_sequence_length=512, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, + ) + + parallelize_transformer(pipe) + + if engine_config.runtime_config.use_torch_compile: + torch._inductor.config.reorder_for_compute_comm_overlap = True + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") + + # warmup + output = pipe( + height=input_config.height, + width=input_config.width, + prompt=input_config.prompt, + num_inference_steps=1, + output_type=input_config.output_type, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + ).images + + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + + output = pipe( + height=input_config.height, + width=input_config.width, + prompt=input_config.prompt, + num_inference_steps=input_config.num_inference_steps, + output_type=input_config.output_type, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + ) + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" + ) + if input_config.output_type == "pil": + dp_group_index = get_data_parallel_rank() + num_dp_groups = get_data_parallel_world_size() + dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups + if is_dp_last_group(): + for i, image in enumerate(output.images): + image_rank = dp_group_index * dp_batch_size + i + image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png" + image.save(f"./results/{image_name}") + print(f"image {i} saved to ./results/{image_name}") + + if get_world_group().rank == get_world_group().world_size - 1: + print( + f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB" + ) + get_runtime_state().destory_distributed_env() + + +if __name__ == "__main__": + main() diff --git a/xfuser/model_executor/layers/attention_processor_usp.py b/xfuser/model_executor/layers/attention_processor_usp.py new file mode 100644 index 0000000..516799e --- /dev/null +++ b/xfuser/model_executor/layers/attention_processor_usp.py @@ -0,0 +1,95 @@ +from typing import Optional + +import torch +import torch.distributed +from diffusers.models.attention import Attention +from .attention_processor import Attention + +from diffusers.models.embeddings import apply_rotary_emb + +from xfuser.model_executor.layers.usp import USP + + +class xFuserFluxAttnProcessor2_0USP: + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + # `sample` projections. + query = attn.to_q(hidden_states) + key = attn.to_k(hidden_states) + value = attn.to_v(hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + if image_rotary_emb is not None: + + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + hidden_states = USP( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states \ No newline at end of file diff --git a/xfuser/model_executor/layers/usp.py b/xfuser/model_executor/layers/usp.py new file mode 100644 index 0000000..92deb58 --- /dev/null +++ b/xfuser/model_executor/layers/usp.py @@ -0,0 +1,98 @@ +import torch +from torch.nn import functional as F +from torch.distributed.tensor.experimental._attention import _templated_ring_attention +aten = torch.ops.aten + +import torch.distributed._functional_collectives as ft_c +import torch.distributed as dist + +from yunchang.globals import PROCESS_GROUP +from yunchang.comm.all_to_all import SeqAllToAll4D +from xfuser.core.distributed import ( + get_sequence_parallel_world_size, + get_ulysses_parallel_world_size, + get_ring_parallel_world_size, +) + +def ring_attn(query, key, value, dropout_p=0.0, is_causal=False): + out, *_ = _templated_ring_attention( + PROCESS_GROUP.RING_PG, + aten._scaled_dot_product_flash_attention, + query, + key, + value, + dropout_p=dropout_p, + is_causal=is_causal + ) + return out + +def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: + """ + When tracing the code, the result tensor is not an AsyncCollectiveTensor, + so we cannot call ``wait()``. + """ + if isinstance(tensor, ft_c.AsyncCollectiveTensor): + return tensor.wait() + return tensor + +def _sdpa_all_to_all_single(x): + x_shape = x.shape + x = x.flatten() + x = ft_c.all_to_all_single(x, output_split_sizes=None, input_split_sizes=None, group=PROCESS_GROUP.ULYSSES_PG) + x = _maybe_wait(x) + x = x.reshape(x_shape) + return x + + +def _ft_c_input_all_to_all(x): + world_size = get_ulysses_parallel_world_size() + if world_size <= 1: + return x + + assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim) + b, h, s, d = x.shape + assert h % world_size == 0, "h must be divisible by world_size, got {} and {}".format(h, world_size) + + x = x.permute(1, 0, 2, 3).contiguous() + x = _sdpa_all_to_all_single(x) + x = x.reshape(world_size, h // world_size, b, -1, d).permute(2, 1, 0, 3, 4).reshape(b, h // world_size, -1, d) + return x + + +def _ft_c_output_all_to_all(x): + world_size = get_ulysses_parallel_world_size() + if world_size <= 1: + return x + + assert x.ndim == 4, "x must have 4 dimensions, got {}".format(x.ndim) + b, h, s, d = x.shape + assert s % world_size == 0, "s must be divisible by world_size, got {} and {}".format(s, world_size) + + x = x.permute(2, 0, 1, 3).contiguous() + x = _sdpa_all_to_all_single(x) + x = x.reshape(world_size, s // world_size, b, -1, d).permute(2, 0, 3, 1, 4).reshape(b, -1, s // world_size, d) + return x + + +def USP(query, key, value, dropout_p=0.0, is_causal=False): + if get_sequence_parallel_world_size() == 1: + out = F.scaled_dot_product_attention( + query, key, value, dropout_p=dropout_p, is_causal=is_causal + ) + elif get_ulysses_parallel_world_size() == 1: + out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal) + elif get_ulysses_parallel_world_size() > 1: + query = _ft_c_input_all_to_all(query) + key = _ft_c_input_all_to_all(key) + value = _ft_c_input_all_to_all(value) + + if get_ring_parallel_world_size() == 1: + out = F.scaled_dot_product_attention( + query, key, value, dropout_p=dropout_p, is_causal=is_causal + ) + else: + out = ring_attn(query, key, value, dropout_p=dropout_p, is_causal=is_causal) + + out = _ft_c_output_all_to_all(out) + + return out \ No newline at end of file From a41b7c12186d21f27e03c012247774ea4a470f37 Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Thu, 28 Nov 2024 15:25:10 +0800 Subject: [PATCH 5/8] support optimized USP in Flux --- examples/flux_usp_example.py | 7 ++++--- .../model_executor/layers/attention_processor.py | 15 ++++++++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index 9742d1e..b10a81f 100644 --- a/examples/flux_usp_example.py +++ b/examples/flux_usp_example.py @@ -1,5 +1,8 @@ +# Flux inference with USP +# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py + import functools -from typing import List, Optional, Tuple, Union +from typing import List, Optional import logging import time @@ -40,8 +43,6 @@ def new_forward( timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, - controlnet_block_samples: Optional[List[torch.Tensor]] = None, - controlnet_single_block_samples: Optional[List[torch.Tensor]] = None, **kwargs, ): if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]: diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py index d37e284..79f5ad0 100644 --- a/xfuser/model_executor/layers/attention_processor.py +++ b/xfuser/model_executor/layers/attention_processor.py @@ -19,8 +19,7 @@ from xfuser.core.distributed import ( get_sequence_parallel_world_size, - get_sequence_parallel_rank, - get_sp_group, + get_pipeline_parallel_world_size ) from xfuser.core.fast_attention import ( xFuserFastAttention, @@ -34,6 +33,9 @@ from xfuser.logger import init_logger from xfuser.envs import PACKAGES_CHECKER +if torch.__version__ >= '2.5.0': + from xfuser.model_executor.layers.usp import USP + logger = init_logger(__name__) env_info = PACKAGES_CHECKER.get_packages_info() @@ -687,7 +689,14 @@ def __call__( #! ---------------------------------------- KV CACHE ---------------------------------------- #! ---------------------------------------- ATTENTION ---------------------------------------- - if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: + if get_pipeline_parallel_world_size() == 1 and torch.__version__ >= '2.5.0' and get_runtime_state().split_text_embed_in_sp: + hidden_states = USP( + query, key, value, dropout_p=0.0, is_causal=False + ) + hidden_states = hidden_states.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1: query = query.transpose(1, 2) key = key.transpose(1, 2) value = value.transpose(1, 2) From 55b871135248410d6cf57280176081ff97510d17 Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Thu, 28 Nov 2024 15:48:02 +0800 Subject: [PATCH 6/8] do not split text if undivisible by sp_degree --- .../model_executor/pipelines/pipeline_cogvideox.py | 7 ++++--- xfuser/model_executor/pipelines/pipeline_flux.py | 14 ++++++++------ .../pipelines/pipeline_stable_diffusion_3.py | 7 ++++--- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py index 98631d5..770306f 100644 --- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py +++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py @@ -370,9 +370,10 @@ def _init_sync_pipeline( latents = super()._init_video_sync_pipeline(latents) if get_runtime_state().split_text_embed_in_sp: - assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \ - f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" - prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0: + prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + else: + get_runtime_state().split_text_embed_in_sp = False if image_rotary_emb is not None: assert latents_frames is not None diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py index 2ac3239..4255edb 100644 --- a/xfuser/model_executor/pipelines/pipeline_flux.py +++ b/xfuser/model_executor/pipelines/pipeline_flux.py @@ -399,14 +399,16 @@ def _init_sync_pipeline( latent_image_ids = torch.cat(latent_image_ids_list, dim=-2) if get_runtime_state().split_text_embed_in_sp: - assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \ - f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" - prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0: + prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + else: + get_runtime_state().split_text_embed_in_sp = False if get_runtime_state().split_text_embed_in_sp: - assert text_ids.shape[-2] % get_sequence_parallel_world_size() == 0, \ - f"the length of text sequence {text_ids.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" - text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + if text_ids.shape[-2] % get_sequence_parallel_world_size() == 0: + text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + else: + get_runtime_state().split_text_embed_in_sp = False return latents, latent_image_ids, prompt_embeds, text_ids diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py index 3b04d45..22130c9 100644 --- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py +++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py @@ -422,9 +422,10 @@ def _init_sync_pipeline(self, latents: torch.Tensor, prompt_embeds: torch.Tensor latents = torch.cat(latents_list, dim=-2) if get_runtime_state().split_text_embed_in_sp: - assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \ - f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}" - prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0: + prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()] + else: + get_runtime_state().split_text_embed_in_sp = False return latents, prompt_embeds From 726f402f0dddc6e345f590f3974a0c06cafb50eb Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Thu, 28 Nov 2024 15:58:04 +0800 Subject: [PATCH 7/8] polish optimized USP --- examples/flux_usp_example.py | 4 +- .../layers/attention_processor_usp.py | 95 ------------------- 2 files changed, 2 insertions(+), 97 deletions(-) delete mode 100644 xfuser/model_executor/layers/attention_processor_usp.py diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py index b10a81f..f01b63e 100644 --- a/examples/flux_usp_example.py +++ b/examples/flux_usp_example.py @@ -28,7 +28,7 @@ get_pipeline_parallel_world_size, ) -from xfuser.model_executor.layers.attention_processor_usp import xFuserFluxAttnProcessor2_0USP +from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0 def parallelize_transformer(pipe: DiffusionPipeline): transformer = pipe.transformer @@ -55,7 +55,7 @@ def new_forward( txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] for block in transformer.transformer_blocks + transformer.single_transformer_blocks: - block.attn.processor = xFuserFluxAttnProcessor2_0USP() + block.attn.processor = xFuserFluxAttnProcessor2_0() output = original_forward( hidden_states, diff --git a/xfuser/model_executor/layers/attention_processor_usp.py b/xfuser/model_executor/layers/attention_processor_usp.py deleted file mode 100644 index 516799e..0000000 --- a/xfuser/model_executor/layers/attention_processor_usp.py +++ /dev/null @@ -1,95 +0,0 @@ -from typing import Optional - -import torch -import torch.distributed -from diffusers.models.attention import Attention -from .attention_processor import Attention - -from diffusers.models.embeddings import apply_rotary_emb - -from xfuser.model_executor.layers.usp import USP - - -class xFuserFluxAttnProcessor2_0USP: - def __call__( - self, - attn: Attention, - hidden_states: torch.FloatTensor, - encoder_hidden_states: torch.FloatTensor = None, - attention_mask: Optional[torch.FloatTensor] = None, - image_rotary_emb: Optional[torch.Tensor] = None, - *args, - **kwargs, - ) -> torch.FloatTensor: - batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - - # `sample` projections. - query = attn.to_q(hidden_states) - key = attn.to_k(hidden_states) - value = attn.to_v(hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - if attn.norm_q is not None: - query = attn.norm_q(query) - if attn.norm_k is not None: - key = attn.norm_k(key) - - # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` - if encoder_hidden_states is not None: - # `context` projections. - encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) - encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) - encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) - - encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( - batch_size, -1, attn.heads, head_dim - ).transpose(1, 2) - - if attn.norm_added_q is not None: - encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) - if attn.norm_added_k is not None: - encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) - - # attention - query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) - key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) - value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) - - if image_rotary_emb is not None: - - query = apply_rotary_emb(query, image_rotary_emb) - key = apply_rotary_emb(key, image_rotary_emb) - - hidden_states = USP( - query, key, value, dropout_p=0.0, is_causal=False - ) - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - if encoder_hidden_states is not None: - encoder_hidden_states, hidden_states = ( - hidden_states[:, : encoder_hidden_states.shape[1]], - hidden_states[:, encoder_hidden_states.shape[1] :], - ) - - # linear proj - hidden_states = attn.to_out[0](hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - encoder_hidden_states = attn.to_add_out(encoder_hidden_states) - - return hidden_states, encoder_hidden_states - else: - return hidden_states \ No newline at end of file From 61b4b907bfb60c269f1169e032eac5edb552c7ab Mon Sep 17 00:00:00 2001 From: Xibo Sun Date: Thu, 28 Nov 2024 16:03:12 +0800 Subject: [PATCH 8/8] update diffusers versio in setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index beb2c20..d6fa12f 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ def get_cuda_version(): install_requires=[ "torch>=2.1.0", "accelerate>=0.33.0", - "diffusers>=0.31", # NOTE: diffusers>=0.31.0 is necessary for CogVideoX and Flux + "diffusers@git+https://github.com/huggingface/diffusers", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux "transformers>=4.39.1", "sentencepiece>=0.1.99", "beautifulsoup4>=4.12.3",