[FEAT] Slice text embedding in MM-DiT #361

Merged 2 commits on Nov 23, 2024
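This PR adds a split_text_embed_in_sp option to the DiT runtime state. When it is enabled (it defaults to True in this diff), the text embedding is sliced across sequence-parallel ranks and fused into the image or video token sequence inside the MM-DiT attention processors, instead of being replicated on every rank and attached as joint_tensor_query/key/value. The sketch below is a toy illustration of the two layouts, not code from this PR; the helper names and the naive chunk-based sharding are assumptions.

import torch

def shard_seq(x: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # naive sequence-parallel shard along the token dimension (dim=1)
    return x.chunk(world_size, dim=1)[rank]

def build_local_tokens(image_tokens, text_tokens, rank, world_size, split_text_embed_in_sp):
    # toy illustration only; the real logic lives in the processors changed below
    image_local = shard_seq(image_tokens, rank, world_size)
    if split_text_embed_in_sp:
        # text tokens are sliced across SP ranks too and fused into the local
        # sequence, so no joint tensors are needed downstream
        text_local = shard_seq(text_tokens, rank, world_size)
        return torch.cat([image_local, text_local], dim=1), None
    # otherwise every rank keeps the full text embedding and later hands it to
    # the attention layer as joint_tensor_query/key/value ("front" or "rear")
    return image_local, text_tokens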
2 changes: 2 additions & 0 deletions examples/cogvideox_example.py
@@ -22,6 +22,8 @@ def main():

engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank

assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for CogVideo"

pipe = xFuserCogVideoXPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
5 changes: 5 additions & 0 deletions xfuser/core/distributed/runtime_state.py
@@ -93,6 +93,7 @@ class DiTRuntimeState(RuntimeState):
pp_patches_token_start_end_idx_global: Optional[List[List[int]]]
pp_patches_token_num: Optional[List[int]]
max_condition_sequence_length: int
split_text_embed_in_sp: bool

def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
super().__init__(config)
@@ -128,11 +129,13 @@ def set_input_parameters(
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
max_condition_sequence_length: Optional[int] = None,
split_text_embed_in_sp: bool = True,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
self.max_condition_sequence_length = max_condition_sequence_length
self.split_text_embed_in_sp = split_text_embed_in_sp
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
if seed is not None and seed != self.input_config.seed:
@@ -156,12 +159,14 @@ def set_video_input_parameters(
batch_size: Optional[int] = None,
num_inference_steps: Optional[int] = None,
seed: Optional[int] = None,
split_text_embed_in_sp: bool = True,
):
self.input_config.num_inference_steps = (
num_inference_steps or self.input_config.num_inference_steps
)
if self.runtime_config.warmup_steps > self.input_config.num_inference_steps:
self.runtime_config.warmup_steps = self.input_config.num_inference_steps
self.split_text_embed_in_sp = split_text_embed_in_sp
if seed is not None and seed != self.input_config.seed:
self.input_config.seed = seed
set_random_seed(seed)
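Both setters accept the new flag and default it to True. A minimal usage sketch, assuming get_runtime_state is exported from xfuser.core.distributed, that the runtime state has already been initialized by the pipeline, and that the remaining keyword arguments keep their existing meaning; only arguments visible in this diff are used.

from xfuser.core.distributed import get_runtime_state

# hypothetical call site inside a pipeline, before denoising starts
get_runtime_state().set_input_parameters(
    num_inference_steps=28,
    seed=42,
    max_condition_sequence_length=512,
    split_text_embed_in_sp=True,   # new flag: slice the text embedding with the image tokens
)

# video pipelines use the video variant with the same flag
get_runtime_state().set_video_input_parameters(
    num_inference_steps=50,
    split_text_embed_in_sp=False,  # fall back to the joint-tensor path
)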
36 changes: 22 additions & 14 deletions xfuser/core/long_ctx_attention/hybrid/attn_layer.py
@@ -85,23 +85,31 @@ def forward(
Returns:
* output (Tensor): context output
"""
supported_joint_strategy = ["none", "front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy != "none" and joint_tensor_query is None:
is_joint = False
if (joint_tensor_query is not None and
joint_tensor_key is not None and
joint_tensor_value is not None):
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
is_joint = True
else:
query = torch.cat([joint_tensor_query, query], dim=1)
is_joint = True
elif (joint_tensor_query is None and
joint_tensor_key is None and
joint_tensor_value is None):
pass
else:
raise ValueError(
f"joint_tensor_query must not be None when joint_strategy is not None"
f"joint_tensor_query, joint_tensor_key, and joint_tensor_value should be None or not None simultaneously."
)
elif joint_strategy == "rear":
query = torch.cat([query, joint_tensor_query], dim=1)
elif joint_strategy == "front":
query = torch.cat([joint_tensor_query, query], dim=1)
else:
pass

if joint_strategy != "none":
if is_joint:
ulysses_world_size = torch.distributed.get_world_size(self.ulysses_pg)
ulysses_rank = torch.distributed.get_rank(self.ulysses_pg)
attn_heads_per_ulysses_rank = (
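The forward path above no longer keys off a "none" strategy; it checks the joint tensors themselves and takes the joint branch only when query, key, and value are all provided, rejecting a mix of None and non-None. A standalone restatement of that rule (hypothetical helper, not part of the diff):

def joint_path_enabled(joint_q, joint_k, joint_v, joint_strategy):
    provided = [t is not None for t in (joint_q, joint_k, joint_v)]
    if all(provided):
        if joint_strategy not in ("front", "rear"):
            raise ValueError(f"joint_strategy: {joint_strategy} not supported")
        return True   # concatenate the joint query and run the joint branch
    if not any(provided):
        return False  # plain sequence-parallel attention, joint_strategy is ignored
    raise ValueError(
        "joint_tensor_query, joint_tensor_key and joint_tensor_value "
        "should be None or not None simultaneously."
    )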
30 changes: 18 additions & 12 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
@@ -22,16 +22,22 @@ def xdit_ring_flash_attn_forward(
joint_tensor_value=None,
joint_strategy="none",
):
supported_joint_strategy = ["none", "front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
elif joint_strategy != "none" and (
joint_tensor_key is None or joint_tensor_value is None
):
is_joint = False
if (joint_tensor_key is not None and
joint_tensor_value is not None):
supported_joint_strategy = ["front", "rear"]
if joint_strategy not in supported_joint_strategy:
raise ValueError(
f"joint_strategy: {joint_strategy} not supprted. supported joint strategy: {supported_joint_strategy}"
)
else:
is_joint = True
elif (joint_tensor_key is None and
joint_tensor_value is None):
pass
else:
raise ValueError(
f"joint_tensor_key & joint_tensor_value must not be None when joint_strategy is not None"
f"joint_tensor_key and joint_tensor_value should be None or not None simultaneously."
)

comm = RingComm(process_group)
@@ -57,19 +63,19 @@
next_v: torch.Tensor = comm.send_recv(v)
comm.commit()

if joint_strategy == "rear":
if is_joint and joint_strategy == "rear":
if step + 1 == comm.world_size:
key = torch.cat([k, joint_tensor_key], dim=1)
value = torch.cat([v, joint_tensor_value], dim=1)
else:
key, value = k, v
elif joint_strategy == "front":
elif is_joint and joint_strategy == "front":
if step == 0:
key = torch.cat([joint_tensor_key, k], dim=1)
value = torch.cat([joint_tensor_value, v], dim=1)
else:
key, value = k, v
elif joint_strategy == "none":
else:
key, value = k, v

if not causal or step <= comm.rank:
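In the ring loop, the text key/value are now appended only when joint tensors were actually passed in, and only at one ring step: the last step for the "rear" strategy and the first step for "front", so each local query block attends to the text tokens exactly once. A simplified restatement of that step selection (the RingComm send/recv machinery is omitted):

import torch

def kv_for_ring_step(k, v, step, world_size, is_joint, joint_k, joint_v, joint_strategy):
    # mirrors the branch above: joint key/value join the local key/value
    # only on the step that matches the chosen strategy
    if is_joint and joint_strategy == "rear" and step + 1 == world_size:
        return torch.cat([k, joint_k], dim=1), torch.cat([v, joint_v], dim=1)
    if is_joint and joint_strategy == "front" and step == 0:
        return torch.cat([joint_k, k], dim=1), torch.cat([joint_v, v], dim=1)
    return k, v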
74 changes: 47 additions & 27 deletions xfuser/model_executor/layers/attention_processor.py
@@ -453,19 +453,28 @@ def __call__(

#! ---------------------------------------- ATTENTION ----------------------------------------
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
if get_runtime_state().split_text_embed_in_sp:
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

encoder_hidden_states_query_proj = None
encoder_hidden_states_key_proj = None
encoder_hidden_states_value_proj = None
else:
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
query = query.view(batch_size, -1, attn.heads, head_dim)
key = key.view(batch_size, -1, attn.heads, head_dim)
value = value.view(batch_size, -1, attn.heads, head_dim)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
)
hidden_states = self.hybrid_seq_parallel_attn(
attn,
query,
@@ -682,15 +691,20 @@ def __call__(
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
encoder_hidden_states_query_proj, query = query.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_key_proj, key = key.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_value_proj, value = value.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
if get_runtime_state().split_text_embed_in_sp:
encoder_hidden_states_query_proj = None
encoder_hidden_states_key_proj = None
encoder_hidden_states_value_proj = None
else:
encoder_hidden_states_query_proj, query = query.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_key_proj, key = key.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
encoder_hidden_states_value_proj, value = value.split(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
hidden_states = self.hybrid_seq_parallel_attn(
attn,
query,
@@ -1043,19 +1057,25 @@ def __call__(

#! ---------------------------------------- ATTENTION ----------------------------------------
if HAS_LONG_CTX_ATTN and get_sequence_parallel_world_size() > 1:
encoder_query = query[:, :, :text_seq_length, :]
query = query[:, :, text_seq_length:, :]
encoder_key = key[:, :, :text_seq_length, :]
key = key[:, :, text_seq_length:, :]
encoder_value = value[:, :, :text_seq_length, :]
value = value[:, :, text_seq_length:, :]
if get_runtime_state().split_text_embed_in_sp:
encoder_query = None
encoder_key = None
encoder_value = None
else:
encoder_query = query[:, :, :text_seq_length, :]
query = query[:, :, text_seq_length:, :]
encoder_key = key[:, :, :text_seq_length, :]
key = key[:, :, text_seq_length:, :]
encoder_value = value[:, :, :text_seq_length, :]
value = value[:, :, text_seq_length:, :]

encoder_query = encoder_query.transpose(1, 2)
encoder_key = encoder_key.transpose(1, 2)
encoder_value = encoder_value.transpose(1, 2)

query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
encoder_query = encoder_query.transpose(1, 2)
encoder_key = encoder_key.transpose(1, 2)
encoder_value = encoder_value.transpose(1, 2)

hidden_states = self.hybrid_seq_parallel_attn(
attn,
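All three processors follow the same pattern: with split_text_embed_in_sp the text projections stay inside (or are folded into) query, key, and value and the encoder-side tensors are set to None, so the hybrid sequence-parallel attention runs without joint tensors; otherwise the text projections are split out and passed as joint tensors as before. A condensed sketch of the two resulting call shapes; the attn_fn stand-in, the keyword-only call style, and the default joint_strategy are assumptions, and the real calls also pass the attention module and flags such as dropout and causal.

import torch
from typing import Callable, Optional

def run_sp_attention(
    attn_fn: Callable[..., torch.Tensor],   # stands in for self.hybrid_seq_parallel_attn
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
    text_q: Optional[torch.Tensor], text_k: Optional[torch.Tensor], text_v: Optional[torch.Tensor],
    split_text_embed_in_sp: bool, joint_strategy: str = "rear",
) -> torch.Tensor:
    if split_text_embed_in_sp:
        # text tokens already live inside query/key/value on this rank
        return attn_fn(query, key, value)
    # text tokens ride along as replicated joint tensors
    return attn_fn(
        query, key, value,
        joint_tensor_query=text_q, joint_tensor_key=text_k, joint_tensor_value=text_v,
        joint_strategy=joint_strategy,
    )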
37 changes: 15 additions & 22 deletions xfuser/model_executor/pipelines/base_pipeline.py
@@ -28,7 +28,6 @@
get_runtime_state,
initialize_runtime_state,
is_dp_last_group,
get_sequence_parallel_rank,
)
from xfuser.core.fast_attention import (
get_fast_attn_enable,
@@ -408,39 +407,33 @@ def _init_async_pipeline(

def _process_cfg_split_batch(
self,
concat_group_0_negative: torch.Tensor,
concat_group_0: torch.Tensor,
concat_group_1_negative: torch.Tensor,
concat_group_1: torch.Tensor,
negative_embeds: torch.Tensor,
embeds: torch.Tensor,
negative_embdes_mask: torch.Tensor = None,
embeds_mask: torch.Tensor = None,
):
if get_classifier_free_guidance_world_size() == 1:
concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0)
concat_group_1 = torch.cat([concat_group_1_negative, concat_group_1], dim=0)
embeds = torch.cat([negative_embeds, embeds], dim=0)
elif get_classifier_free_guidance_rank() == 0:
concat_group_0 = concat_group_0_negative
concat_group_1 = concat_group_1_negative
embeds = negative_embeds
elif get_classifier_free_guidance_rank() == 1:
concat_group_0 = concat_group_0
concat_group_1 = concat_group_1
embeds = embeds
else:
raise ValueError("Invalid classifier free guidance rank")
return concat_group_0, concat_group_1

def _process_cfg_split_batch_latte(
self,
concat_group_0: torch.Tensor,
concat_group_0_negative: torch.Tensor,
):
if negative_embdes_mask is None:
return embeds

if get_classifier_free_guidance_world_size() == 1:
concat_group_0 = torch.cat([concat_group_0_negative, concat_group_0], dim=0)
embeds_mask = torch.cat([negative_embdes_mask, embeds_mask], dim=0)
elif get_classifier_free_guidance_rank() == 0:
concat_group_0 = concat_group_0_negative
embeds_mask = negative_embdes_mask
elif get_classifier_free_guidance_rank() == 1:
concat_group_0 = concat_group_0
embeds_mask = embeds_mask
else:
raise ValueError("Invalid classifier free guidance rank")
return concat_group_0
return embeds, embeds_mask

def is_dp_last_group(self):
"""Return True if in the last data parallel group, False otherwise.
This also accounts for the parallel VAE case.
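The two CFG batching helpers are merged into a single _process_cfg_split_batch that takes a negative/positive embedding pair plus an optional mask pair: with a classifier-free-guidance world size of 1 the pair is concatenated into one batch, and with two CFG ranks each rank keeps only its half. A toy single-process restatement of the embedding path, with made-up shapes:

import torch

def cfg_split_batch(negative, positive, cfg_world_size, cfg_rank):
    # single-process stand-in for the helper above
    if cfg_world_size == 1:
        return torch.cat([negative, positive], dim=0)    # both prompts in one batch
    return negative if cfg_rank == 0 else positive       # each CFG rank keeps its half

neg = torch.zeros(1, 226, 4096)   # hypothetical negative prompt embeddings
pos = torch.ones(1, 226, 4096)    # hypothetical positive prompt embeddings
assert cfg_split_batch(neg, pos, cfg_world_size=1, cfg_rank=0).shape == (2, 226, 4096)
assert cfg_split_batch(neg, pos, cfg_world_size=2, cfg_rank=1) is pos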