Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support optimized USP in Flux #367

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/cogvideox_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def main():

engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank

assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for CogVideo"

pipe = xFuserCogVideoXPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
Expand Down
169 changes: 169 additions & 0 deletions examples/flux_usp_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# Flux inference with USP
# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_flux.py

import functools
from typing import List, Optional

import logging
import time
import torch
import torch.distributed
from diffusers import DiffusionPipeline, FluxPipeline

from xfuser import xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_data_parallel_world_size,
get_data_parallel_rank,
get_runtime_state,
get_classifier_free_guidance_world_size,
get_classifier_free_guidance_rank,
get_cfg_group,
get_sequence_parallel_world_size,
get_sequence_parallel_rank,
get_sp_group,
is_dp_last_group,
initialize_runtime_state,
get_pipeline_parallel_world_size,
)

from xfuser.model_executor.layers.attention_processor import xFuserFluxAttnProcessor2_0

def parallelize_transformer(pipe: DiffusionPipeline):
transformer = pipe.transformer
original_forward = transformer.forward

@functools.wraps(transformer.__class__.forward)
def new_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
*args,
timestep: torch.LongTensor = None,
img_ids: torch.Tensor = None,
txt_ids: torch.Tensor = None,
**kwargs,
):
if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]:
timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()]
encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
img_ids = torch.chunk(img_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]
txt_ids = torch.chunk(txt_ids, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()]

for block in transformer.transformer_blocks + transformer.single_transformer_blocks:
block.attn.processor = xFuserFluxAttnProcessor2_0()

output = original_forward(
hidden_states,
encoder_hidden_states,
*args,
timestep=timestep,
img_ids=img_ids,
txt_ids=txt_ids,
**kwargs,
)

return_dict = not isinstance(output, tuple)
sample = output[0]
sample = get_sp_group().all_gather(sample, dim=-2)
sample = get_cfg_group().all_gather(sample, dim=0)
if return_dict:
return output.__class__(sample, *output[1:])
return (sample, *output[1:])

new_forward = new_forward.__get__(transformer)
transformer.forward = new_forward


def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)
engine_config, input_config = engine_args.create_config()
engine_config.runtime_config.dtype = torch.bfloat16
local_rank = get_world_group().local_rank

pipe = FluxPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
torch_dtype=torch.bfloat16,
)

if args.enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} sequential CPU offload enabled")
else:
pipe = pipe.to(f"cuda:{local_rank}")

parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

initialize_runtime_state(pipe, engine_config)
get_runtime_state().set_input_parameters(
height=input_config.height,
width=input_config.width,
batch_size=1,
num_inference_steps=input_config.num_inference_steps,
max_condition_sequence_length=512,
split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1,
)

parallelize_transformer(pipe)

if engine_config.runtime_config.use_torch_compile:
torch._inductor.config.reorder_for_compute_comm_overlap = True
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs")

# warmup
output = pipe(
height=input_config.height,
width=input_config.width,
prompt=input_config.prompt,
num_inference_steps=1,
output_type=input_config.output_type,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
).images

torch.cuda.reset_peak_memory_stats()
start_time = time.time()

output = pipe(
height=input_config.height,
width=input_config.width,
prompt=input_config.prompt,
num_inference_steps=input_config.num_inference_steps,
output_type=input_config.output_type,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
)
end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

parallel_info = (
f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"tp{engine_args.tensor_parallel_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if input_config.output_type == "pil":
dp_group_index = get_data_parallel_rank()
num_dp_groups = get_data_parallel_world_size()
dp_batch_size = (input_config.batch_size + num_dp_groups - 1) // num_dp_groups
if is_dp_last_group():
for i, image in enumerate(output.images):
image_rank = dp_group_index * dp_batch_size + i
image_name = f"flux_result_{parallel_info}_{image_rank}_tc_{engine_args.use_torch_compile}.png"
image.save(f"./results/{image_name}")
print(f"image {i} saved to ./results/{image_name}")

if get_world_group().rank == get_world_group().world_size - 1:
print(
f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
)
get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_cuda_version():
install_requires=[
"torch>=2.1.0",
"accelerate>=0.33.0",
"diffusers>=0.31", # NOTE: diffusers>=0.31.0 is necessary for CogVideoX and Flux
"diffusers@git+https://github.com/huggingface/diffusers", # NOTE: diffusers>=0.32.0.dev is necessary for CogVideoX and Flux
"transformers>=4.39.1",
"sentencepiece>=0.1.99",
"beautifulsoup4>=4.12.3",
Expand Down
7 changes: 5 additions & 2 deletions xfuser/core/distributed/group_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,18 @@ def all_gather(
# Convert negative dim to positive.
dim += input_.dim()
# Allocate output tensor.
input_size = input_.size()
input_size = list(input_.size())
input_size[0] *= world_size
output_tensor = torch.empty(
(world_size,) + input_size, dtype=input_.dtype, device=input_.device
input_size, dtype=input_.dtype, device=input_.device
)
# All-gather.
torch.distributed.all_gather_into_tensor(
output_tensor, input_, group=self.device_group
)
if dim != 0:
input_size[0] //= world_size
output_tensor = output_tensor.reshape([world_size, ] + input_size)
output_tensor = output_tensor.movedim(0, dim)

if separate_tensors:
Expand Down
5 changes: 5 additions & 0 deletions xfuser/core/distributed/runtime_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class DiTRuntimeState(RuntimeState):
pp_patches_token_start_end_idx_global: Optional[List[List[int]]]
pp_patches_token_num: Optional[List[int]]
max_condition_sequence_length: int
split_text_embed_in_sp: bool

def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
super().__init__(config)
Expand Down Expand Up @@ -128,11 +129,13 @@ def set_input_parameters(
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
max_condition_sequence_length: Optional[int] = None,
split_text_embed_in_sp: bool = True,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
self.max_condition_sequence_length = max_condition_sequence_length
self.split_text_embed_in_sp = split_text_embed_in_sp
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
if seed is not None and seed != self.input_config.seed:
Expand All @@ -156,12 +159,14 @@ def set_video_input_parameters(
batch_size: Optional[int] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
split_text_embed_in_sp: bool = True,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
self.split_text_embed_in_sp = split_text_embed_in_sp
if seed is not None and seed != self.input_config.seed:
self.input_config.seed = seed
set_random_seed(seed)
Expand Down
36 changes: 22 additions & 14 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,31 @@ def forward(
Returns:
* output (Tensor): context output
"""
supported_joint_strategy = ["none", "front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy != "none" and joint_tensor_query is None:
is_joint = False
if (joint_tensor_query is not None and
joint_tensor_key is not None and
joint_tensor_value is not None):
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
is_joint = True
else:
query = torch.cat([joint_tensor_query, query], dim=1)
is_joint = True
elif (joint_tensor_query is None and
joint_tensor_key is None and
joint_tensor_value is None):
pass
else:
raise ValueError(
f"joint_tensor_query must not be None when joint_strategy is not None"
f"joint_tensor_query, joint_tensor_key, and joint_tensor_value should be None or not None simultaneously."
)
elif joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
elif joint_strategy == "front":
query = torch.cat([joint_tensor_query, query], dim=1)
else:
pass

if joint_strategy != "none":
if is_joint:
ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
attn_heads_per_ulysses_rank = (
Expand Down
30 changes: 18 additions & 12 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@ def xdit_ring_flash_attn_forward(
joint_tensor_value=None,
joint_strategy="none",
):
supported_joint_strategy = ["none", "front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy != "none" and (
joint_tensor_key is None or joint_tensor_value is None
):
is_joint = False
if (joint_tensor_key is not None and
joint_tensor_value is not None):
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
else:
is_joint = True
elif (joint_tensor_key is None and
joint_tensor_value is None):
pass
else:
raise ValueError(
f"joint_tensor_key & joint_tensor_value must not be None when joint_strategy is not None"
f"joint_tensor_key and joint_tensor_value should be None or not None simultaneously."
)

comm = RingComm(process_group)
Expand All @@ -57,19 +63,19 @@ def xdit_ring_flash_attn_forward(
next_v: torch.Tensor = comm.send_recv(v)
comm.commit()

if joint_strategy == "rear":
if is_joint and joint_strategy == "rear":
if step + 1 == comm.world_size:
key = torch.cat([k, joint_tensor_key], dim=1)
value = torch.cat([v, joint_tensor_value], dim=1)
else:
key, value = k, v
elif joint_strategy == "front":
elif is_joint and joint_strategy == "front":
if step == 0:
key = torch.cat([joint_tensor_key, k], dim=1)
value = torch.cat([joint_tensor_value, v], dim=1)
else:
key, value = k, v
elif joint_strategy == "none":
else:
key, value = k, v

if not causal or step <= comm.rank:
Expand Down
Loading