Commit

Fix and refactor ray directory

lihuahua123 committed Dec 20, 2024
1 parent 708a070 commit df22136
Showing 16 changed files with 104 additions and 39 deletions.
64 changes: 64 additions & 0 deletions examples/ray/ray_flux_example.py
@@ -0,0 +1,64 @@
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline

def main():
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16
    model_name = engine_config.model_config.model.split("/")[-1]
    PipelineClass = xFuserFluxPipeline
    text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16)
    if args.use_fp8_t5_encoder:
        from optimum.quanto import freeze, qfloat8, quantize
        quantize(text_encoder_2, weights=qfloat8)
        freeze(text_encoder_2)

    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=PipelineClass,
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
        text_encoder_2=text_encoder_2,
    )
    pipe.prepare_run(input_config)

    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=256,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"elapsed time:{elapsed_time}")
    if not os.path.exists("results"):
        os.mkdir("results")
    # output is a list of results from each worker, we take the last one
    for i, image in enumerate(output[-1].images):
        image.save(
            f"/data/results/{model_name}_result_{i}.png"
        )
        print(
            f"image {i} saved to /data/results/{model_name}_result_{i}.png"
        )


if __name__ == "__main__":
    main()
10 changes: 3 additions & 7 deletions tests/executor/ray_run.sh → examples/ray/ray_run.sh
@@ -7,11 +7,8 @@ export MODEL_TYPE="Sd3"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
["Sd3"]="./test_ray.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
["Sd3"]="ray_sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
["Flux"]="ray_flux_example.py /cfs/dit/FLUX.1-dev 28"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
@@ -28,7 +25,6 @@ mkdir -p ./results
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


-# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
N_GPUS=2
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

@@ -51,7 +47,7 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

export CUDA_VISIBLE_DEVICES=0,1

-python ./tests/executor/$SCRIPT \
+python ./examples/ray/$SCRIPT \
--model $MODEL_ID \
$PARALLEL_ARGS \
$TASK_ARGS \
40 changes: 23 additions & 17 deletions tests/executor/test_ray.py → examples/ray/ray_sd3_example.py
@@ -3,11 +3,25 @@
import torch
import torch.distributed
from transformers import T5EncoderModel
-from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
-from xfuser.executor.gpu_executor import RayDiffusionPipeline
+from xfuser import xFuserArgs
+from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
-from xfuser.executor.gpu_executor import RayDiffusionPipeline
+from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
+import time
+import os
+import torch
+import torch.distributed
+from transformers import T5EncoderModel
+from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
+from xfuser.config import FlexibleArgumentParser
+from xfuser.core.distributed import (
+    get_world_group,
+    is_dp_last_group,
+    get_data_parallel_rank,
+    get_runtime_state,
+)
+from xfuser.core.distributed.parallel_state import get_data_parallel_world_size


def main():
    os.environ["MASTER_ADDR"] = "localhost"
@@ -16,20 +30,12 @@ def main():
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
-    pipeline_map = {
-        "PixArt-XL-2-1024-MS": xFuserPixArtAlphaPipeline,
-        "PixArt-Sigma-XL-2-2K-MS": xFuserPixArtSigmaPipeline,
-        "stable-diffusion-3-medium-diffusers": xFuserStableDiffusion3Pipeline,
-        "HunyuanDiT-v1.2-Diffusers": xFuserHunyuanDiTPipeline,
-        "FLUX.1-schnell": xFuserFluxPipeline,
-    }
    model_name = engine_config.model_config.model.split("/")[-1]
-    PipelineClass = pipeline_map.get(model_name)
-    if PipelineClass is None:
-        raise NotImplementedError(f"{model_name} is currently not supported!")
+    PipelineClass = xFuserStableDiffusion3Pipeline
    text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
    if args.use_fp8_t5_encoder:
        from optimum.quanto import freeze, qfloat8, quantize
+        print(f"rank {local_rank} quantizing text encoder 2")
        quantize(text_encoder_3, weights=qfloat8)
        freeze(text_encoder_3)

@@ -42,6 +48,7 @@
    )
    pipe.prepare_run(input_config)

+    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
        height=input_config.height,
@@ -53,19 +60,18 @@
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
-
    print(f"elapsed time:{elapsed_time}")
    if not os.path.exists("results"):
        os.mkdir("results")
    # output is a list of results from each worker, we take the last one
    for i, image in enumerate(output[-1].images):
        image.save(
-            f"/data/results/stable_diffusion_3_result_{i}.png"
+            f"/data/results/{model_name}_result_{i}.png"
        )
        print(
-            f"image {i} saved to /data/results/stable_diffusion_3_result_{i}.png"
+            f"image {i} saved to /data/results/{model_name}_result_{i}.png"
        )


if __name__ == "__main__":
-    main()
+    main()
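The torch.cuda.reset_peak_memory_stats() call added before the timed section pairs with a peak-memory readout after the run, the same pattern the non-Ray sd3_example.py further down prints. A minimal sketch of that measurement, assuming a single CUDA device:

    import torch

    torch.cuda.reset_peak_memory_stats()
    # ... run the pipeline ...
    peak_memory = torch.cuda.max_memory_allocated(device="cuda")
    print(f"peak memory: {peak_memory/1e9:.2f} GB")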
8 changes: 4 additions & 4 deletions examples/run.sh
@@ -3,7 +3,7 @@ set -x
export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type
-export MODEL_TYPE="Sd3"
+export MODEL_TYPE="Flux"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
@@ -29,8 +29,8 @@ TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=2
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
+N_GPUS=8
+PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 2"

# CFG_ARGS="--use_cfg_parallel"

@@ -49,7 +49,7 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

-export CUDA_VISIBLE_DEVICES=0,1
+# export CUDA_VISIBLE_DEVICES=4,5,6,7

torchrun --nproc_per_node=$N_GPUS ./examples/$SCRIPT \
--model $MODEL_ID \
4 changes: 2 additions & 2 deletions examples/sd3_example.py
@@ -75,11 +75,11 @@ def main():

    if get_world_group().rank == get_world_group().world_size - 1:
        print(
-            f"epoch time: {elapsed_time} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
        )

    get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
-    main()
+    main()
2 changes: 1 addition & 1 deletion setup.py
@@ -39,7 +39,7 @@ def get_cuda_version():
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.7.0", # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops
"flash_attn>=2.6.3",
"ray"
],
extras_require={
2 changes: 1 addition & 1 deletion xfuser/config/args.py
@@ -163,7 +163,7 @@ def add_cli_args(parser: FlexibleArgumentParser):
"--ray_world_size",
type=int,
default=1,
help="World size.",
help="The number of ray workers (world_size for ray)",
)
parallel_group.add_argument(
"--use_cfg_parallel",
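For context, the flag above flows through xFuserArgs into the engine config exactly as in the ray examples. A minimal sketch with an illustrative argument list; only --model and --ray_world_size come from this diff, and the model path is a placeholder:

    from xfuser import xFuserArgs
    from xfuser.config import FlexibleArgumentParser

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    # Request two Ray workers via the CLI flag documented above.
    args = xFuserArgs.add_cli_args(parser).parse_args(
        ["--model", "/cfs/dit/FLUX.1-dev", "--ray_world_size", "2"]
    )
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()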
2 changes: 1 addition & 1 deletion xfuser/config/config.py
@@ -194,7 +194,7 @@ class ParallelConfig:
    tp_config: TensorParallelConfig
    distributed_executor_backend: Optional[str] = None
    world_size: int = 1 # FIXME: remove this
-    worker_cls: str = "xfuser.worker.worker.Worker"
+    worker_cls: str = "xfuser.ray.worker.worker.Worker"

    def __post_init__(self):
        assert self.tp_config is not None, "tp_config must be set"
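worker_cls stays a dotted string, presumably so the worker class is imported only inside the Ray worker process; the worker wrapper further down imports resolve_obj_by_qualname for exactly this. A stand-in sketch of that resolution (the real helper lives in xfuser.ray.worker.utils and may differ in detail):

    import importlib

    def resolve_obj_by_qualname(qualname: str):
        # "xfuser.ray.worker.worker.Worker" -> module "xfuser.ray.worker.worker", attr "Worker"
        module_name, _, obj_name = qualname.rpartition(".")
        return getattr(importlib.import_module(module_name), obj_name)

    WorkerClass = resolve_obj_by_qualname("xfuser.ray.worker.worker.Worker")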
File renamed without changes.
File renamed without changes.
xfuser/executor/gpu_executor.py → xfuser/ray/pipeline/pipeline_utils.py
@@ -7,12 +7,11 @@
from itertools import islice, repeat
from typing import Any, Dict, List, Optional, Tuple

-from xfuser.executor.base_executor import BaseExecutor
-from xfuser.executor.ray_utils import initialize_ray_cluster
+from xfuser.ray.pipeline.base_executor import BaseExecutor
+from xfuser.ray.pipeline.ray_utils import initialize_ray_cluster
from xfuser.logger import init_logger
-from xfuser.worker.worker_wrappers import RayWorkerWrapper
+from xfuser.ray.worker.worker_wrappers import RayWorkerWrapper
from xfuser.config.config import InputConfig, EngineConfig
-from xfuser.model_executor.pipelines.base_pipeline import xFuserPipelineBaseWrapper
logger = init_logger(__name__)


xfuser/executor/ray_utils.py → xfuser/ray/pipeline/ray_utils.py
@@ -292,5 +292,5 @@ def _verify_bundles(
        len(bundles),
        device_str,
        node_id,
-        parallel_config.tensor_parallel_size,
+        parallel_config.tp_degree,
    )
File renamed without changes.
File renamed without changes.
File renamed without changes.
xfuser/worker/worker_wrappers.py → xfuser/ray/worker/worker_wrappers.py
@@ -6,7 +6,7 @@
from abc import ABC
from typing import Any, Dict

-from xfuser.worker.utils import update_environment_variables, resolve_obj_by_qualname
+from xfuser.ray.worker.utils import update_environment_variables, resolve_obj_by_qualname
from xfuser.config.config import EngineConfig

class BaseWorkerWrapper(ABC):
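A corresponding stand-in for update_environment_variables, assuming it simply merges a dict into os.environ; the signature here is a guess based on the name and usage, not the library's actual code:

    import os
    from typing import Dict

    def update_environment_variables(envs: Dict[str, str]) -> None:
        # Overwrite or create each variable in the current process environment.
        for key, value in envs.items():
            os.environ[key] = value

    # e.g. the distributed-init variables the ray examples set by hand:
    update_environment_variables({"MASTER_ADDR": "localhost", "MASTER_PORT": "12355"})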
