Skip to content

Commit

Permalink
FIX: Decouple runtime_state from USP (#382)
Browse files Browse the repository at this point in the history
  • Loading branch information
xibosun authored Dec 5, 2024
1 parent e3c7955 commit 6096bd4
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 14 deletions.
3 changes: 1 addition & 2 deletions xfuser/core/long_ctx_attention/ring/ring_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from xfuser.core.cache_manager.cache_manager import get_cache_manager
from yunchang.ring.utils import RingComm, update_out_and_lse
from yunchang.ring.ring_flash_attn import RingFlashAttnFunc
from xfuser.core.distributed import get_runtime_state


def xdit_ring_flash_attn_forward(
Expand Down Expand Up @@ -49,7 +48,7 @@ def xdit_ring_flash_attn_forward(

next_k, next_v = None, None

if get_runtime_state().num_pipeline_patch > 1 and attn_layer is not None:
if attn_layer is not None:
k, v = get_cache_manager().update_and_get_kv_cache(
new_kv=[k, v],
layer=attn_layer,
Expand Down
14 changes: 2 additions & 12 deletions xfuser/model_executor/layers/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,7 +717,7 @@ def __call__(
[num_encoder_hidden_states_tokens, num_query_tokens], dim=1
)
hidden_states = self.hybrid_seq_parallel_attn(
attn,
attn if get_runtime_state().num_pipeline_patch > 1 else None,
query,
key,
value,
Expand Down Expand Up @@ -1056,16 +1056,6 @@ def __call__(
key[:, :, text_seq_length:], image_rotary_emb
)

#! ---------------------------------------- KV CACHE ----------------------------------------
if get_pipeline_parallel_world_size() == 1 and not self.use_long_ctx_attn_kvcache:
key, value = get_cache_manager().update_and_get_kv_cache(
new_kv=[key, value],
layer=attn,
slice_dim=2,
layer_type="attn",
)
#! ---------------------------------------- KV CACHE ----------------------------------------

#! ---------------------------------------- ATTENTION ----------------------------------------
if get_pipeline_parallel_world_size() == 1 and get_runtime_state().split_text_embed_in_sp:
hidden_states = USP(
Expand Down Expand Up @@ -1096,7 +1086,7 @@ def __call__(
value = value.transpose(1, 2)

hidden_states = self.hybrid_seq_parallel_attn(
attn,
None,
query,
key,
value,
Expand Down

0 comments on commit 6096bd4

Please sign in to comment.