diff --git a/examples/hunyuan_video_usp_example.py b/examples/hunyuan_video_usp_example.py
new file mode 100644
index 0000000..1b79632
--- /dev/null
+++ b/examples/hunyuan_video_usp_example.py
@@ -0,0 +1,302 @@
+import functools
+from typing import Any, Dict, Union
+import logging
+import time
+
+import torch
+
+from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
+from diffusers.models.modeling_outputs import Transformer2DModelOutput
+from diffusers.utils import export_to_video
+
+from xfuser import xFuserArgs
+from xfuser.config import FlexibleArgumentParser
+from xfuser.core.distributed import (
+    get_world_group,
+    get_data_parallel_world_size,
+    get_data_parallel_rank,
+    get_runtime_state,
+    get_classifier_free_guidance_world_size,
+    get_classifier_free_guidance_rank,
+    get_cfg_group,
+    get_sequence_parallel_world_size,
+    get_sequence_parallel_rank,
+    get_sp_group,
+    is_dp_last_group,
+    initialize_runtime_state,
+    get_pipeline_parallel_world_size,
+)
+
+from xfuser.model_executor.layers.attention_processor import xFuserHunyuanVideoAttnProcessor2_0
+
+assert xFuserHunyuanVideoAttnProcessor2_0 is not None
+
+
+def parallelize_transformer(pipe: DiffusionPipeline):
+    transformer = pipe.transformer
+
+    @functools.wraps(transformer.__class__.forward)
+    def new_forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        pooled_projections: torch.Tensor,
+        guidance: torch.Tensor = None,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+
+        assert batch_size % get_classifier_free_guidance_world_size() == 0, (
+            f"Cannot split dim 0 of hidden_states ({batch_size}) into "
+            f"{get_classifier_free_guidance_world_size()} parts."
+        )
+
+        p, p_t = self.config.patch_size, self.config.patch_size_t
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        # 1. RoPE
+        image_rotary_emb = self.rope(hidden_states)
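+
+        # The cos/sin tables returned by self.rope cover the full post-patch
+        # token sequence; they are chunked along dim 0 per sequence-parallel
+        # rank further below so each rank only rotates the slice of latent
+        # tokens it owns. For example (a sketch, relying on torch.chunk's
+        # contiguous split): with a sequence-parallel world size of 2 and S
+        # post-patch tokens, rank 0 rotates tokens [0, S/2) and rank 1 tokens
+        # [S/2, S).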
+
+        # 2. Conditional embeddings
+        temb = self.time_text_embed(timestep, guidance, pooled_projections)
+        hidden_states = self.x_embedder(hidden_states)
+        encoder_hidden_states = self.context_embedder(
+            encoder_hidden_states, timestep, encoder_attention_mask
+        )
+
+        encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
+        encoder_hidden_states_indices = torch.arange(
+            encoder_hidden_states.shape[1], device=encoder_hidden_states.device
+        )
+        encoder_hidden_states_indices = encoder_hidden_states_indices[encoder_attention_mask]
+        encoder_hidden_states = encoder_hidden_states[..., encoder_hidden_states_indices, :]
+        if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0:
+            get_runtime_state().split_text_embed_in_sp = False
+        else:
+            get_runtime_state().split_text_embed_in_sp = True
+
+        hidden_states = torch.chunk(
+            hidden_states, get_classifier_free_guidance_world_size(), dim=0
+        )[get_classifier_free_guidance_rank()]
+        hidden_states = torch.chunk(
+            hidden_states, get_sequence_parallel_world_size(), dim=-2
+        )[get_sequence_parallel_rank()]
+        encoder_hidden_states = torch.chunk(
+            encoder_hidden_states, get_classifier_free_guidance_world_size(), dim=0
+        )[get_classifier_free_guidance_rank()]
+        if get_runtime_state().split_text_embed_in_sp:
+            encoder_hidden_states = torch.chunk(
+                encoder_hidden_states, get_sequence_parallel_world_size(), dim=-2
+            )[get_sequence_parallel_rank()]
+
+        freqs_cos, freqs_sin = image_rotary_emb
+
+        def get_rotary_emb_chunk(freqs):
+            freqs = torch.chunk(
+                freqs, get_sequence_parallel_world_size(), dim=0
+            )[get_sequence_parallel_rank()]
+            return freqs
+
+        freqs_cos = get_rotary_emb_chunk(freqs_cos)
+        freqs_sin = get_rotary_emb_chunk(freqs_sin)
+        image_rotary_emb = (freqs_cos, freqs_sin)
+
+        # 4. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+            def create_custom_forward(module, return_dict=None):
+
+                def custom_forward(*inputs):
+                    if return_dict is not None:
+                        return module(*inputs, return_dict=return_dict)
+                    else:
+                        return module(*inputs)
+
+                return custom_forward
+
+            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
+
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    None,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    None,
+                    image_rotary_emb,
+                    **ckpt_kwargs,
+                )
+
+        else:
+            for block in self.transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, None, image_rotary_emb
+                )
+
+            for block in self.single_transformer_blocks:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states, encoder_hidden_states, temb, None, image_rotary_emb
+                )
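+
+        # At this point every rank still holds only its classifier-free-guidance
+        # shard (dim 0) and its slice of the token sequence (dim -2). The
+        # all_gather calls in the output-projection step below reassemble the
+        # full latent on every rank before it is unpatchified.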
+
+        # 5. Output projection
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)
+        hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
+        )
+        hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
+        hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if not return_dict:
+            return (hidden_states,)
+
+        return Transformer2DModelOutput(sample=hidden_states)
+
+    new_forward = new_forward.__get__(transformer)
+    transformer.forward = new_forward
+
+    for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
+        block.attn.processor = xFuserHunyuanVideoAttnProcessor2_0()
+
+
+def main():
+    parser = FlexibleArgumentParser(description="xFuser Arguments")
+    args = xFuserArgs.add_cli_args(parser).parse_args()
+    engine_args = xFuserArgs.from_cli_args(args)
+
+    engine_config, input_config = engine_args.create_config()
+    local_rank = get_world_group().local_rank
+
+    assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
+    assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for HunyuanVideo"
+
+    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+        pretrained_model_name_or_path=engine_config.model_config.model,
+        subfolder="transformer",
+        torch_dtype=torch.bfloat16,
+        revision="refs/pr/18",
+    )
+    pipe = HunyuanVideoPipeline.from_pretrained(
+        pretrained_model_name_or_path=engine_config.model_config.model,
+        transformer=transformer,
+        torch_dtype=torch.float16,
+        revision="refs/pr/18",
+    )
+
+    if args.enable_sequential_cpu_offload:
+        pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
+        logging.info(f"rank {local_rank} sequential CPU offload enabled")
+    elif args.enable_model_cpu_offload:
+        pipe.enable_model_cpu_offload(gpu_id=local_rank)
+        logging.info(f"rank {local_rank} model CPU offload enabled")
+    else:
+        device = torch.device(f"cuda:{local_rank}")
+        pipe = pipe.to(device)
+
+    if args.enable_tiling:
+        pipe.vae.enable_tiling(
+            # Make it runnable on GPUs with 48GB memory
+            tile_sample_min_height=128,
+            tile_sample_stride_height=96,
+            tile_sample_min_width=128,
+            tile_sample_stride_width=96,
+            tile_sample_min_num_frames=32,
+            tile_sample_stride_num_frames=24,
+        )
+
+    if args.enable_slicing:
+        pipe.vae.enable_slicing()
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    initialize_runtime_state(pipe, engine_config)
+    get_runtime_state().set_video_input_parameters(
+        height=input_config.height,
+        width=input_config.width,
+        num_frames=input_config.num_frames,
+        batch_size=1,
+        num_inference_steps=input_config.num_inference_steps,
+        split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
+    )
+
+    parallelize_transformer(pipe)
+
+    if engine_config.runtime_config.use_torch_compile:
+        torch._inductor.config.reorder_for_compute_comm_overlap = True
+        pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")
+
+        # one step to warmup the torch compiler
+        output = pipe(
+            height=input_config.height,
+            width=input_config.width,
+            num_frames=input_config.num_frames,
+            prompt=input_config.prompt,
+            num_inference_steps=1,
+            generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+        ).frames[0]
+
+    torch.cuda.reset_peak_memory_stats()
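+
+    # Only the generation below is timed; the optional torch.compile warm-up
+    # above is excluded, and the peak-memory counter was just reset so the
+    # reported number reflects this run alone.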
+    start_time = time.time()
+
+    output = pipe(
+        height=input_config.height,
+        width=input_config.width,
+        num_frames=input_config.num_frames,
+        prompt=input_config.prompt,
+        num_inference_steps=input_config.num_inference_steps,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+    ).frames[0]
+
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
+    parallel_info = (
+        f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
+        f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
+        f"tp{engine_args.tensor_parallel_degree}_"
+        f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
+    )
+    if is_dp_last_group():
+        resolution = f"{input_config.width}x{input_config.height}"
+        output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4"
+        export_to_video(output, output_filename, fps=15)
+        print(f"output saved to {output_filename}")
+
+    if get_world_group().rank == get_world_group().world_size - 1:
+        print(
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9} GB"
+        )
+    get_runtime_state().destory_distributed_env()
+
+
+# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling
+if __name__ == "__main__":
+    main()
diff --git a/xfuser/core/distributed/runtime_state.py b/xfuser/core/distributed/runtime_state.py
index cfb4c79..7de0606 100644
--- a/xfuser/core/distributed/runtime_state.py
+++ b/xfuser/core/distributed/runtime_state.py
@@ -4,7 +4,7 @@
 import numpy as np
 import torch

-from diffusers import DiffusionPipeline, CogVideoXPipeline
+from diffusers import DiffusionPipeline
 import torch.distributed

 from xfuser.config.config import (
@@ -103,7 +103,7 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
             pipeline=pipeline, parallel_config=config.parallel_config
         )
         self.cogvideox = False
-        if isinstance(pipeline, CogVideoXPipeline):
+        if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")):
             self._set_cogvideox_parameters(
                 vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
                 vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
diff --git a/xfuser/model_executor/layers/attention_processor.py b/xfuser/model_executor/layers/attention_processor.py
index 9869a7e..31c3fff 100644
--- a/xfuser/model_executor/layers/attention_processor.py
+++ b/xfuser/model_executor/layers/attention_processor.py
@@ -15,6 +15,11 @@
     CogVideoXAttnProcessor2_0
 )

+try:
+    from diffusers.models.transformers.transformer_hunyuan_video import HunyuanVideoAttnProcessor2_0
+except ImportError:
+    HunyuanVideoAttnProcessor2_0 = None
+
 from diffusers.models.embeddings import apply_rotary_emb

 from xfuser.core.distributed import (
@@ -1143,3 +1148,200 @@ def __call__(
             [text_seq_length, latent_seq_length], dim=1
         )
         return hidden_states, encoder_hidden_states
+
+
+if HunyuanVideoAttnProcessor2_0 is not None:
+    @xFuserAttentionProcessorRegister.register(HunyuanVideoAttnProcessor2_0)
+    class xFuserHunyuanVideoAttnProcessor2_0(HunyuanVideoAttnProcessor2_0):
+        def __init__(self):
+            super().__init__()
+            use_long_ctx_attn_kvcache = True
+            self.use_long_ctx_attn_kvcache = (
+                HAS_LONG_CTX_ATTN
+                and use_long_ctx_attn_kvcache
+                and get_sequence_parallel_world_size() > 1
+            )
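+
+            # Backend selection (summarised from the branches below): with more
+            # than one sequence-parallel rank and long-context attention
+            # available, use xFuserLongContextAttention (hybrid Ulysses + Ring)
+            # when flash-attn is installed, otherwise xFuserUlyssesAttention
+            # with use_fa=False; with a single rank no hybrid attention object
+            # is created.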
+            if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+                from xfuser.core.long_ctx_attention import (
+                    xFuserLongContextAttention,
+                    xFuserUlyssesAttention,
+                )
+
+                if HAS_FLASH_ATTN:
+                    self.hybrid_seq_parallel_attn = xFuserLongContextAttention(
+                        use_kv_cache=self.use_long_ctx_attn_kvcache
+                    )
+                else:
+                    self.hybrid_seq_parallel_attn = xFuserUlyssesAttention(
+                        use_fa=False,
+                        use_kv_cache=self.use_long_ctx_attn_kvcache,
+                    )
+            else:
+                self.hybrid_seq_parallel_attn = None
+
+        def __call__(
+            self,
+            attn: Attention,
+            hidden_states: torch.Tensor,
+            encoder_hidden_states: torch.Tensor,
+            attention_mask: Optional[torch.Tensor] = None,
+            image_rotary_emb: Optional[torch.Tensor] = None,
+            *args,
+            **kwargs,
+        ) -> torch.Tensor:
+            batch_size, _, _ = (
+                hidden_states.shape
+                if encoder_hidden_states is None
+                else encoder_hidden_states.shape
+            )
+
+            if attn.add_q_proj is None and encoder_hidden_states is not None:
+                hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
+
+            # 1. QKV projections
+            query = attn.to_q(hidden_states)
+            key = attn.to_k(hidden_states)
+            value = attn.to_v(hidden_states)
+
+            query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+            value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+            # 2. QK normalization
+            if attn.norm_q is not None:
+                query = attn.norm_q(query)
+            if attn.norm_k is not None:
+                key = attn.norm_k(key)
+
+            # 3. Rotational positional embeddings applied to latent stream
+            if image_rotary_emb is not None:
+                if attn.add_q_proj is None and encoder_hidden_states is not None:
+                    query = torch.cat(
+                        [
+                            apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                            query[:, :, -encoder_hidden_states.shape[1] :],
+                        ],
+                        dim=2,
+                    )
+                    key = torch.cat(
+                        [
+                            apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
+                            key[:, :, -encoder_hidden_states.shape[1] :],
+                        ],
+                        dim=2,
+                    )
+                else:
+                    query = apply_rotary_emb(query, image_rotary_emb)
+                    key = apply_rotary_emb(key, image_rotary_emb)
+
+            # 4. Encoder condition QKV projection and normalization
+            if attn.add_q_proj is not None and encoder_hidden_states is not None:
+                encoder_query = attn.add_q_proj(encoder_hidden_states)
+                encoder_key = attn.add_k_proj(encoder_hidden_states)
+                encoder_value = attn.add_v_proj(encoder_hidden_states)
+
+                encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+                encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+                encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+                if attn.norm_added_q is not None:
+                    encoder_query = attn.norm_added_q(encoder_query)
+                if attn.norm_added_k is not None:
+                    encoder_key = attn.norm_added_k(encoder_key)
+
+                query = torch.cat([query, encoder_query], dim=2)
+                key = torch.cat([key, encoder_key], dim=2)
+                value = torch.cat([value, encoder_value], dim=2)
+
+            if encoder_hidden_states is not None:
+                num_encoder_hidden_states_tokens = encoder_hidden_states.shape[1]
+                num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens
+            else:
+                num_encoder_hidden_states_tokens = (
+                    get_runtime_state().max_condition_sequence_length
+                )
+                num_query_tokens = query.shape[2] - num_encoder_hidden_states_tokens
+
+            #! ---------------------------------------- ATTENTION ----------------------------------------
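+            # Dispatch (a summary of the three paths below): USP when the text
+            # tokens are split across sequence-parallel ranks alongside the
+            # latent tokens; the hybrid sequence-parallel attention, with the
+            # text q/k/v passed as "rear" joint tensors when the text embedding
+            # is replicated on every rank; otherwise a local fallback through
+            # flash-attn or F.scaled_dot_product_attention.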
+            if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
+                hidden_states = USP(query, key, value, dropout_p=0.0, is_causal=False)
+                hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+            elif HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
+                if get_runtime_state().split_text_embed_in_sp:
+                    encoder_query = None
+                    encoder_key = None
+                    encoder_value = None
+                else:
+                    query, encoder_query = query.split(
+                        [num_query_tokens, num_encoder_hidden_states_tokens], dim=2
+                    )
+                    key, encoder_key = key.split(
+                        [num_query_tokens, num_encoder_hidden_states_tokens], dim=2
+                    )
+                    value, encoder_value = value.split(
+                        [num_query_tokens, num_encoder_hidden_states_tokens], dim=2
+                    )
+
+                    encoder_query = encoder_query.transpose(1, 2)
+                    encoder_key = encoder_key.transpose(1, 2)
+                    encoder_value = encoder_value.transpose(1, 2)
+
+                query = query.transpose(1, 2)
+                key = key.transpose(1, 2)
+                value = value.transpose(1, 2)
+
+                hidden_states = self.hybrid_seq_parallel_attn(
+                    None,
+                    query,
+                    key,
+                    value,
+                    dropout_p=0.0,
+                    causal=False,
+                    joint_tensor_query=encoder_query,
+                    joint_tensor_key=encoder_key,
+                    joint_tensor_value=encoder_value,
+                    joint_strategy="rear",
+                )
+
+                hidden_states = hidden_states.flatten(2, 3)
+            else:
+                if HAS_FLASH_ATTN:
+                    from flash_attn import flash_attn_func
+
+                    query = query.transpose(1, 2)
+                    key = key.transpose(1, 2)
+                    value = value.transpose(1, 2)
+                    hidden_states = flash_attn_func(
+                        query, key, value, dropout_p=0.0, causal=False
+                    )
+                    hidden_states = hidden_states.flatten(2, 3)
+
+                else:
+                    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+                    # TODO: add support for attn.scale when we move to Torch 2.1
+                    hidden_states = F.scaled_dot_product_attention(
+                        query, key, value, dropout_p=0.0, is_causal=False
+                    )
+                    hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+
+            hidden_states = hidden_states.to(query.dtype)
+
+            # 6. Output projection
+            if encoder_hidden_states is not None:
+                hidden_states, encoder_hidden_states = (
+                    hidden_states[:, : -encoder_hidden_states.shape[1]],
+                    hidden_states[:, -encoder_hidden_states.shape[1] :],
+                )
+
+                if getattr(attn, "to_out", None) is not None:
+                    hidden_states = attn.to_out[0](hidden_states)
+                    hidden_states = attn.to_out[1](hidden_states)
+
+                if getattr(attn, "to_add_out", None) is not None:
+                    encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+else:
+    xFuserHunyuanVideoAttnProcessor2_0 = None
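+
+# xFuserHunyuanVideoAttnProcessor2_0 resolves to None when the installed
+# diffusers version does not provide HunyuanVideoAttnProcessor2_0; the USP
+# example above asserts that it is available and then assigns it to
+# block.attn.processor for every double- and single-stream block.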