Skip to content

Commit

Permalink
add hunyuan_video_usp_example.py
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Dec 19, 2024
1 parent 46c0d54 commit df8f853
Show file tree
Hide file tree
Showing 3 changed files with 506 additions and 2 deletions.
302 changes: 302 additions & 0 deletions examples/hunyuan_video_usp_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
import functools
from typing import Any, Dict, Union
import logging
import time

import torch

from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import export_to_video

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_world_size,
get_data_parallel_rank,
get_runtime_state,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_cfg_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
is_dp_last_group,
initialize_runtime_state,
get_pipeline_parallel_world_size,
)

from xfuser.model_executor.layers.attention_processor import xFuserHunyuanVideoAttnProcessor2_0

assert xFuserHunyuanVideoAttnProcessor2_0 is not None


def parallelize_transformer(pipe: DiffusionPipeline):
transformer = pipe.transformer

@functools.wraps(transformer.__class__.forward)
def new_forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
batch_size, num_channels, num_frames, height, width = hidden_states.shape

assert batch_size % get_classifier_free_guidance_world_size(
) == 0, f"Cannot split dim 0 of hidden_states ({batch_size}) into {get_classifier_free_guidance_world_size()} parts."

p, p_t = self.config.patch_size, self.config.patch_size_t
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p

# 1. RoPE
image_rotary_emb = self.rope(hidden_states)

# 2. Conditional embeddings
temb = self.time_text_embed(timestep, guidance, pooled_projections)
hidden_states = self.x_embedder(hidden_states)
encoder_hidden_states = self.context_embedder(encoder_hidden_states,
timestep,
encoder_attention_mask)

encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
encoder_hidden_states_indices = torch.arange(
encoder_hidden_states.shape[1],
device=encoder_hidden_states.device)
encoder_hidden_states_indices = encoder_hidden_states_indices[
encoder_attention_mask]
encoder_hidden_states = encoder_hidden_states[
..., encoder_hidden_states_indices, :]
if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size(
) != 0:
get_runtime_state().split_text_embed_in_sp = False
else:
get_runtime_state().split_text_embed_in_sp = True

hidden_states = torch.chunk(hidden_states,
get_classifier_free_guidance_world_size(),
dim=0)[get_classifier_free_guidance_rank()]
hidden_states = torch.chunk(hidden_states,
get_sequence_parallel_world_size(),
dim=-2)[get_sequence_parallel_rank()]
encoder_hidden_states = torch.chunk(
encoder_hidden_states,
get_classifier_free_guidance_world_size(),
dim=0)[get_classifier_free_guidance_rank()]
if get_runtime_state().split_text_embed_in_sp:
encoder_hidden_states = torch.chunk(
encoder_hidden_states,
get_sequence_parallel_world_size(),
dim=-2)[get_sequence_parallel_rank()]

freqs_cos, freqs_sin = image_rotary_emb

def get_rotary_emb_chunk(freqs):
freqs = torch.chunk(freqs,
get_sequence_parallel_world_size(),
dim=0)[get_sequence_parallel_rank()]
return freqs

freqs_cos = get_rotary_emb_chunk(freqs_cos)
freqs_sin = get_rotary_emb_chunk(freqs_sin)
image_rotary_emb = (freqs_cos, freqs_sin)

# 4. Transformer blocks
if torch.is_grad_enabled() and self.gradient_checkpointing:

def create_custom_forward(module, return_dict=None):

def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)

return custom_forward

ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}

for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
None,
image_rotary_emb,
**ckpt_kwargs,
)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
None,
image_rotary_emb,
**ckpt_kwargs,
)

else:
for block in self.transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, None,
image_rotary_emb)

for block in self.single_transformer_blocks:
hidden_states, encoder_hidden_states = block(
hidden_states, encoder_hidden_states, temb, None,
image_rotary_emb)

# 5. Output projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)
hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)

hidden_states = hidden_states.reshape(batch_size,
post_patch_num_frames,
post_patch_height,
post_patch_width, -1, p_t, p, p)
hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if not return_dict:
return (hidden_states, )

return Transformer2DModelOutput(sample=hidden_states)

new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward

for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
block.attn.processor = xFuserHunyuanVideoAttnProcessor2_0()


def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)

engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank

assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for HunyuanVideo"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
subfolder="transformer",
torch_dtype=torch.bfloat16,
revision="refs/pr/18",
)
pipe = HunyuanVideoPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
transformer=transformer,
torch_dtype=torch.float16,
revision="refs/pr/18",
)

if args.enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} sequential CPU offload enabled")
elif args.enable_model_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} model CPU offload enabled")
else:
device = torch.device(f"cuda:{local_rank}")
pipe = pipe.to(device)

if args.enable_tiling:
pipe.vae.enable_tiling(
# Make it runnable on GPUs with 48GB memory
tile_sample_min_height=128,
tile_sample_stride_height=96,
tile_sample_min_width=128,
tile_sample_stride_width=96,
tile_sample_min_num_frames=32,
tile_sample_stride_num_frames=24,
)

if args.enable_slicing:
pipe.vae.enable_slicing()

parameter_peak_memory = torch.cuda.max_memory_allocated(
device=f"cuda:{local_rank}")

initialize_runtime_state(pipe, engine_config)
get_runtime_state().set_video_input_parameters(
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
batch_size=1,
num_inference_steps=input_config.num_inference_steps,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)

parallelize_transformer(pipe)

if engine_config.runtime_config.use_torch_compile:
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer,
mode="max-autotune-no-cudagraphs")

# one step to warmup the torch compiler
output = pipe(
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
prompt=input_config.prompt,
num_inference_steps=1,
generator=torch.Generator(device="cuda").manual_seed(
input_config.seed),
).frames[0]

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

output = pipe(
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
prompt=input_config.prompt,
num_inference_steps=input_config.num_inference_steps,
generator=torch.Generator(device="cuda").manual_seed(
input_config.seed),
).frames[0]

end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

parallel_info = (
f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"tp{engine_args.tensor_parallel_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if is_dp_last_group():
resolution = f"{input_config.width}x{input_config.height}"
output_filename = f"results/hunyuan_video_{parallel_info}_{resolution}.mp4"
export_to_video(output, output_filename, fps=15)
print(f"output saved to {output_filename}")

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9} GB"
)
get_runtime_state().destory_distributed_env()


# mkdir -p results && torchrun --nproc_per_node=2 examples/hunyuan_video_usp_example.py --model tencent/HunyuanVideo --ulysses_degree 2 --num_inference_steps 30 --warmup_steps 0 --prompt "A cat walks on the grass, realistic" --height 320 --width 512 --num_frames 61 --enable_tiling
if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions xfuser/core/distributed/runtime_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import torch
from diffusers import DiffusionPipeline, CogVideoXPipeline
from diffusers import DiffusionPipeline
import torch.distributed

from xfuser.config.config import (
Expand Down Expand Up @@ -103,7 +103,7 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
pipeline=pipeline, parallel_config=config.parallel_config
)
self.cogvideox = False
if isinstance(pipeline, CogVideoXPipeline):
if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")):
self._set_cogvideox_parameters(
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
Expand Down
Loading

0 comments on commit df8f853

Please sign in to comment.