From ca94011acb0a1699222431bfafce8720ef9d743d Mon Sep 17 00:00:00 2001
From: Xibo Sun
Date: Thu, 28 Nov 2024 16:35:14 +0800
Subject: [PATCH] Support optimized USP in Flux (#368)

---
 examples/flux_usp_example.py                       | 11 ++++++-----
 setup.py                                           |  2 +-
 .../model_executor/layers/attention_processor.py   | 15 ++++++++++++---
 .../pipelines/pipeline_cogvideox.py                |  7 ++++---
 xfuser/model_executor/pipelines/pipeline_flux.py   | 14 ++++++++------
 .../pipelines/pipeline_stable_diffusion_3.py       |  7 ++++---
 6 files changed, 35 insertions(+), 21 deletions(-)

diff --git a/examples/flux_usp_example.py b/examples/flux_usp_example.py
index 9742d1e..f01b63e 100644
--- a/examples/flux_usp_example.py
+++ b/examples/flux_usp_example.py
@@ -1,5 +1,8 @@
+# Flux inference with USP
+# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py
+
 import functools
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
 import logging
 import time
 
@@ -25,7 +28,7 @@
     get_pipeline_parallel_world_size,
 )
 
-from xfuser.model_executor.layers.attention_processor_usp import xFuserFluxAttnProcessor2_0USP
+from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0
 
 def parallelize_transformer(pipe: DiffusionPipeline):
     transformer = pipe.transformer
@@ -40,8 +43,6 @@ def new_forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
-        controlnet_block_samples: Optional[List[torch.Tensor]] = None,
-        controlnet_single_block_samples: Optional[List[torch.Tensor]] = None,
         **kwargs,
     ):
         if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
@@ -54,7 +55,7 @@
         txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
 
         for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
-            block.attn.processor = xFuserFluxAttnProcessor2_0USP()
+            block.attn.processor = xFuserFluxAttnProcessor2_0()
 
         output = original_forward(
             hidden_states,
diff --git a/setup.py b/setup.py
index 710e0b8..8823d8c 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,7 @@ def get_cuda_version():
         install_requires=[
             "torch>=2.1.0",
             "accelerate>=0.33.0",
-            "diffusers==0.31", # NOTE: diffusers>=0.31.0 is necessary for CogVideoX and Flux
+            "diffusers@git+https://github.com/huggingface/diffusers", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
             "transformers>=4.39.1",
             "sentencepiece>=0.1.99",
             "beautifulsoup4>=4.12.3",
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index d37e284..79f5ad0 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -19,8 +19,7 @@
 
 from xfuser.core.distributed import (
     get_sequence_parallel_world_size,
-    get_sequence_parallel_rank,
-    get_sp_group,
+    get_pipeline_parallel_world_size
 )
 from xfuser.core.fast_attention import (
     xFuserFastAttention,
@@ -34,6 +33,9 @@
 from xfuser.logger import init_logger
 from xfuser.envs import PACKAGES_CHECKER
 
+if torch.__version__ >= '2.5.0':
+    from xfuser.model_executor.layers.usp import USP
+
 logger = init_logger(__name__)
 
 env_info = PACKAGES_CHECKER.get_packages_info()
@@ -687,7 +689,14 @@
         #! ---------------------------------------- KV CACHE ----------------------------------------
 
         #! ---------------------------------------- ATTENTION ----------------------------------------
-        if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+        if get_pipeline_parallel_world_size() == 1 and torch.__version__ >= '2.5.0' and get_runtime_state().split_text_embed_in_sp:
+            hidden_states = USP(
+                query, key, value, dropout_p=0.0, is_causal=False
+            )
+            hidden_states = hidden_states.transpose(1, 2).reshape(
+                batch_size, -1, attn.heads * head_dim
+            )
+        elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
             query = query.transpose(1, 2)
             key = key.transpose(1, 2)
             value = value.transpose(1, 2)
diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
index 98631d5..770306f 100644
--- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py
+++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py
@@ -370,9 +370,10 @@
         latents = super()._init_video_sync_pipeline(latents)
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
+                prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         if image_rotary_emb is not None:
             assert latents_frames is not None
diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py
index 2ac3239..4255edb 100644
--- a/xfuser/model_executor/pipelines/pipeline_flux.py
+++ b/xfuser/model_executor/pipelines/pipeline_flux.py
@@ -399,14 +399,16 @@
         latent_image_ids = torch.cat(latent_image_ids_list, dim=-2)
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
+                prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert text_ids.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {text_ids.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if text_ids.shape[-2] % get_sequence_parallel_world_size() == 0:
+                text_ids = torch.chunk(text_ids, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         return latents, latent_image_ids, prompt_embeds, text_ids
 
diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
index 3b04d45..22130c9 100644
--- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
+++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
@@ -422,9 +422,10 @@ def _init_sync_pipeline(self, latents: torch.Tensor, prompt_embeds: torch.Tensor
         latents = torch.cat(latents_list, dim=-2)
 
         if get_runtime_state().split_text_embed_in_sp:
-            assert prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0, \
-                f"the length of text sequence {prompt_embeds.shape[-2]} is not divisible by sp_degree {get_sequence_parallel_world_size()}"
-            prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            if prompt_embeds.shape[-2] % get_sequence_parallel_world_size() == 0:
+                prompt_embeds = torch.chunk(prompt_embeds, get_sequence_parallel_world_size(), dim=-2)[get_sequence_parallel_rank()]
+            else:
+                get_runtime_state().split_text_embed_in_sp = False
 
         return latents, prompt_embeds
 