Skip to content

Commit

Permalink
[FEAT] Slice text embedding in MM-DiT (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
xibosun authored Nov 23, 2024
1 parent c203225 commit a7bcdb8
Show file tree
Hide file tree
Showing 9 changed files with 161 additions and 109 deletions.
2 changes: 2 additions & 0 deletions examples/cogvideox_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def main():

engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank

assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for CogVideo"

pipe = xFuserCogVideoXPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
Expand Down
5 changes: 5 additions & 0 deletions xfuser/core/distributed/runtime_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class DiTRuntimeState(RuntimeState):
pp_patches_token_start_end_idx_global: Optional[List[List[int]]]
pp_patches_token_num: Optional[List[int]]
max_condition_sequence_length: int
split_text_embed_in_sp: bool

def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
super().__init__(config)
Expand Down Expand Up @@ -128,11 +129,13 @@ def set_input_parameters(
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
max_condition_sequence_length: Optional[int] = None,
split_text_embed_in_sp: bool = True,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
self.max_condition_sequence_length = max_condition_sequence_length
self.split_text_embed_in_sp = split_text_embed_in_sp
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
if seed is not None and seed != self.input_config.seed:
Expand All @@ -156,12 +159,14 @@ def set_video_input_parameters(
batch_size: Optional[int] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
split_text_embed_in_sp: bool = True,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
self.split_text_embed_in_sp = split_text_embed_in_sp
if seed is not None and seed != self.input_config.seed:
self.input_config.seed = seed
set_random_seed(seed)
Expand Down
36 changes: 22 additions & 14 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,31 @@ def forward(
Returns:
* output (Tensor): context output
"""
supported_joint_strategy = ["none", "front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy != "none" and joint_tensor_query is None:
is_joint = False
if (joint_tensor_query is not None and
joint_tensor_key is not None and
joint_tensor_value is not None):
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
is_joint = True
else:
query = torch.cat([joint_tensor_query, query], dim=1)
is_joint = True
elif (joint_tensor_query is None and
joint_tensor_key is None and
joint_tensor_value is None):
pass
else:
raise ValueError(
f"joint_tensor_query must not be None when joint_strategy is not None"
f"joint_tensor_query, joint_tensor_key, and joint_tensor_value should be None or not None simultaneously."
)
elif joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
elif joint_strategy == "front":
query = torch.cat([joint_tensor_query, query], dim=1)
else:
pass

if joint_strategy != "none":
if is_joint:
ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
attn_heads_per_ulysses_rank = (
Expand Down
30 changes: 18 additions & 12 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,22 @@ def xdit_ring_flash_attn_forward(
joint_tensor_value=None,
joint_strategy="none",
):
supported_joint_strategy = ["none", "front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy != "none" and (
joint_tensor_key is None or joint_tensor_value is None
):
is_joint = False
if (joint_tensor_key is not None and
joint_tensor_value is not None):
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
else:
is_joint = True
elif (joint_tensor_key is None and
joint_tensor_value is None):
pass
else:
raise ValueError(
f"joint_tensor_key & joint_tensor_value must not be None when joint_strategy is not None"
f"joint_tensor_key and joint_tensor_value should be None or not None simultaneously."
)

comm = RingComm(process_group)
Expand All @@ -57,19 +63,19 @@ def xdit_ring_flash_attn_forward(
next_v: torch.Tensor = comm.send_recv(v)
comm.commit()

if joint_strategy == "rear":
if is_joint and joint_strategy == "rear":
if step + 1 == comm.world_size:
key = torch.cat([k, joint_tensor_key], dim=1)
value = torch.cat([v, joint_tensor_value], dim=1)
else:
key, value = k, v
elif joint_strategy == "front":
elif is_joint and joint_strategy == "front":
if step == 0:
key = torch.cat([joint_tensor_key, k], dim=1)
value = torch.cat([joint_tensor_value, v], dim=1)
else:
key, value = k, v
elif joint_strategy == "none":
else:
key, value = k, v

if not causal or step <= comm.rank:
Expand Down
74 changes: 47 additions & 27 deletions xfuser/model_executor/layers/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,19 +453,28 @@ def __call__(

#! ---------------------------------------- ATTENTION ----------------------------------------
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
if get_runtime_state().split_text_embed_in_sp:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

encoder_hidden_states_query_proj = None
encoder_hidden_states_key_proj = None
encoder_hidden_states_value_proj = None
else:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
hidden_states = self.hybrid_seq_parallel_attn(
attn,
query,
Expand Down Expand Up @@ -682,15 +691,20 @@ def __call__(
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
encoder_hidden_states_query_proj, query = query.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_key_proj, key = key.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_value_proj, value = value.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
if get_runtime_state().split_text_embed_in_sp:
encoder_hidden_states_query_proj = None
encoder_hidden_states_key_proj = None
encoder_hidden_states_value_proj = None
else:
encoder_hidden_states_query_proj, query = query.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_key_proj, key = key.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_value_proj, value = value.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
hidden_states = self.hybrid_seq_parallel_attn(
attn,
query,
Expand Down Expand Up @@ -1043,19 +1057,25 @@ def __call__(

#! ---------------------------------------- ATTENTION ----------------------------------------
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
encoder_query = query[:, :, :text_seq_length, :]
query = query[:, :, text_seq_length:, :]
encoder_key = key[:, :, :text_seq_length, :]
key = key[:, :, text_seq_length:, :]
encoder_value = value[:, :, :text_seq_length, :]
value = value[:, :, text_seq_length:, :]
if get_runtime_state().split_text_embed_in_sp:
encoder_query = None
encoder_key = None
encoder_value = None
else:
encoder_query = query[:, :, :text_seq_length, :]
query = query[:, :, text_seq_length:, :]
encoder_key = key[:, :, :text_seq_length, :]
key = key[:, :, text_seq_length:, :]
encoder_value = value[:, :, :text_seq_length, :]
value = value[:, :, text_seq_length:, :]

encoder_query = encoder_query.transpose(1, 2)
encoder_key = encoder_key.transpose(1, 2)
encoder_value = encoder_value.transpose(1, 2)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
encoder_query = encoder_query.transpose(1, 2)
encoder_key = encoder_key.transpose(1, 2)
encoder_value = encoder_value.transpose(1, 2)

hidden_states = self.hybrid_seq_parallel_attn(
attn,
Expand Down
37 changes: 15 additions & 22 deletions xfuser/model_executor/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
get_runtime_state,
initialize_runtime_state,
is_dp_last_group,
get_sequence_parallel_rank,
)
from xfuser.core.fast_attention import (
get_fast_attn_enable,
Expand Down Expand Up @@ -408,39 +407,33 @@ def _init_async_pipeline(

def _process_cfg_split_batch(
self,
concat_group_0_negative: torch.Tensor,
concat_group_0: torch.Tensor,
concat_group_1_negative: torch.Tensor,
concat_group_1: torch.Tensor,
negative_embeds: torch.Tensor,
embeds: torch.Tensor,
negative_embdes_mask: torch.Tensor = None,
embeds_mask: torch.Tensor = None,
):
if get_classifier_free_guidance_world_size() == 1:
concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0)
concat_group_1 = torch.cat([concat_group_1_negative, concat_group_1], dim=0)
embeds = torch.cat([negative_embeds, embeds], dim=0)
elif get_classifier_free_guidance_rank() == 0:
concat_group_0 = concat_group_0_negative
concat_group_1 = concat_group_1_negative
embeds = negative_embeds
elif get_classifier_free_guidance_rank() == 1:
concat_group_0 = concat_group_0
concat_group_1 = concat_group_1
embeds = embeds
else:
raise ValueError("Invalid classifier free guidance rank")
return concat_group_0, concat_group_1

def _process_cfg_split_batch_latte(
self,
concat_group_0: torch.Tensor,
concat_group_0_negative: torch.Tensor,
):
if negative_embdes_mask is None:
return embeds

if get_classifier_free_guidance_world_size() == 1:
concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0)
embeds_mask = torch.cat([negative_embdes_mask, embeds_mask], dim=0)
elif get_classifier_free_guidance_rank() == 0:
concat_group_0 = concat_group_0_negative
embeds_mask = negative_embdes_mask
elif get_classifier_free_guidance_rank() == 1:
concat_group_0 = concat_group_0
embeds_mask = embeds_mask
else:
raise ValueError("Invalid classifier free guidance rank")
return concat_group_0
return embeds, embeds_mask

def is_dp_last_group(self):
"""Return True if in the last data parallel group, False otherwise.
Also include parallel vae situation.
Expand Down
Loading

0 comments on commit a7bcdb8

Please sign in to comment.