[example] print model parameter memory usage (#308)
feifeibear authored Oct 14, 2024
1 parent e439562 commit a58e4fa
Showing 6 changed files with 24 additions and 22 deletions.
docs/performance/cogvideo.md (2 changes: 2 additions & 0 deletions)
@@ -1,6 +1,8 @@
 ## CogVideo Performance
 [Chinese Version](./cogvideo_zh.md)
 
+Details on how to apply xDiT to CogVideoX: [Leveraging xDiT to Parallelize the Open-Sourced Video Generation Model CogVideoX](https://medium.com/@xditproject/boosting-aigc-inference-leveraging-xdit-to-parallelize-the-cogvideox-text-to-video-workflow-8128e45b36e9)
+
 CogVideo functions as a text-to-video model. xDiT presently integrates USP techniques (including Ulysses attention and Ring attention) and CFG parallelism to enhance inference speed, while work on PipeFusion is ongoing. Due to constraints on video generation dimensions in CogVideo, the maximum parallelism level for USP is 2. Thus, xDiT can leverage up to 4 GPUs to execute CogVideo, despite the potential for additional GPUs within the machine.
 
 In a system equipped with L40 (PCIe) GPUs, we compared the inference performance of single-GPU CogVideoX utilizing the `diffusers` library with our parallelized versions for generating 49-frame (6-second) 720x480 videos.
docs/performance/cogvideo_zh.md (2 changes: 2 additions & 0 deletions)
@@ -1,5 +1,7 @@
 ## CogVideo Performance
 
+Details on using xDiT with CogVideo: [Using xDiT to Run the CogVideoX Text-to-Video Pipeline on Multiple GPUs in Parallel](https://medium.com/@xditproject/aigc%E6%8E%A8%E7%90%86%E5%8A%A0%E9%80%9F-%E5%88%A9%E7%94%A8xdit%E5%B9%B6%E8%A1%8Ccogvideox%E6%96%87%E7%94%9F%E8%A7%86%E9%A2%91%E6%B5%81%E7%A8%8B-86255f9979a9)
+
 CogVideo is a text-to-video model. xDiT currently integrates USP techniques (including Ulysses attention and Ring attention) and CFG parallelism to improve inference speed, while work on PipeFusion is in progress. Because CogVideo constrains the dimensions of generated videos, the maximum parallel degree for USP is 2. xDiT can therefore use at most 4 GPUs to run CogVideo, even when more GPUs are available on the machine.
 
 On a platform equipped with L40 (PCIe) GPUs, we analyzed the performance difference between single-GPU CogVideoX inference based on the `diffusers` library and our parallelized version when generating 49-frame (6-second) 720x480 videos.
examples/flux_example.py (6 changes: 5 additions & 1 deletion)
@@ -33,6 +33,8 @@ def main():
     else:
         pipe = pipe.to(f"cuda:{local_rank}")
 
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
     pipe.prepare_run(input_config, steps=1)
 
     torch.cuda.reset_peak_memory_stats()
@@ -69,7 +71,9 @@ def main():
print(f"image {i} saved to ./results/{image_name}")

if get_world_group().rank == get_world_group().world_size - 1:
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
print(
f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
)
get_runtime_state().destory_distributed_env()


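The pattern added here (and repeated in the HunyuanDiT and SD3 examples below) reads `torch.cuda.max_memory_allocated()` right after the weights land on the device, so the value approximates parameter memory, then calls `torch.cuda.reset_peak_memory_stats()` so the later reading reflects only the inference peak. Taking the first reading before `prepare_run` matters, since warmup would otherwise inflate it. A minimal standalone sketch of the idea, using a hypothetical stand-in model rather than the xDiT pipelines:

```python
import torch

# Stand-in for a real pipeline; only the measurement pattern matters here.
model = torch.nn.Linear(4096, 4096).to("cuda:0")

# Right after loading, peak allocation is dominated by the weights,
# so this approximates parameter memory.
parameter_peak_memory = torch.cuda.max_memory_allocated(device="cuda:0")

# Reset the counter so the next reading covers inference only.
torch.cuda.reset_peak_memory_stats()

with torch.no_grad():
    model(torch.randn(8, 4096, device="cuda:0"))

peak_memory = torch.cuda.max_memory_allocated(device="cuda:0")
print(
    f"parameter memory: {parameter_peak_memory/1e9:.2f} GB, "
    f"peak memory: {peak_memory/1e9:.2f} GB"
)
```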
examples/hunyuandit_example.py (7 changes: 6 additions & 1 deletion)
@@ -24,6 +24,9 @@ def main():
         engine_config=engine_config,
         torch_dtype=torch.float16,
     ).to(f"cuda:{local_rank}")
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
     pipe.prepare_run(input_config)
 
     torch.cuda.reset_peak_memory_stats()
@@ -63,7 +66,9 @@ def main():
         )
 
     if get_world_group().rank == get_world_group().world_size - 1:
-        print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
+        print(
+            f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9:.2f} GB"
+        )
     get_runtime_state().destory_distributed_env()


examples/run.sh (24 changes: 5 additions & 19 deletions)
@@ -3,14 +3,14 @@ set -x
 export PYTHONPATH=$PWD:$PYTHONPATH
 
 # Select the model type
-export MODEL_TYPE="Pixart-alpha"
+export MODEL_TYPE="Flux"
 # Configuration for different model types
 # script, model_id, inference_step
 declare -A MODEL_CONFIGS=(
     ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20"
     ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20"
     ["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
-    ["Flux"]="flux_example.py /cfs/dit/FLUX.1-schnell 4"
+    ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28"
     ["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50"
 )

@@ -27,24 +27,10 @@ mkdir -p ./results
 # task args
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
-# Flux only supports SP. Do not set the pipefusion degree.
-if [ "$MODEL_TYPE" = "Flux" ]; then
 N_GPUS=8
-PARALLEL_ARGS="--ulysses_degree $N_GPUS"
-CFG_ARGS=""
+PARALLEL_ARGS="--ulysses_degree 1 --ring_degree 1 --pipefusion_parallel_degree 8"
 
-# HunyuanDiT asserts sp_degree == ulysses_degree*ring_degree <= 2, or the output will be incorrect.
-elif [ "$MODEL_TYPE" = "HunyuanDiT" ]; then
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1"
-CFG_ARGS="--use_cfg_parallel"
-
-else
-# On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch)
-N_GPUS=8
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 2 --ring_degree 1"
-CFG_ARGS="--use_cfg_parallel"
-fi
+# CFG_ARGS="--use_cfg_parallel"

 # By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
 # PIPEFUSION_ARGS="--num_pipeline_patch 8 "
@@ -65,7 +51,7 @@ $PIPEFUSION_ARGS \
 $OUTPUT_ARGS \
 --num_inference_steps $INFERENCE_STEP \
 --warmup_steps 0 \
---prompt "A small dog" \
+--prompt "brown dog laying on the ground with a metal bowl in front of him." \
 $CFG_ARGS \
 $PARALLLEL_VAE \
 $COMPILE_FLAG
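The rewritten script drops the per-model if/elif/else and pins a single configuration: 8 GPUs running PipeFusion at degree 8, with CFG parallelism left commented out. As the removed comment noted ("On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2"), the parallel degrees must multiply out to the GPU count. A small illustrative helper, not part of the repository, that makes this constraint explicit:

```python
# Hypothetical helper (not in xDiT) expressing the launch constraint:
# pipefusion * ulysses * ring * cfg must equal the number of GPUs used.
def required_gpus(pipefusion_degree: int, ulysses_degree: int,
                  ring_degree: int, use_cfg_parallel: bool) -> int:
    cfg_degree = 2 if use_cfg_parallel else 1
    return pipefusion_degree * ulysses_degree * ring_degree * cfg_degree

# New default in run.sh: pure PipeFusion across all 8 GPUs.
assert required_gpus(8, 1, 1, use_cfg_parallel=False) == 8
# The removed else-branch: pp=2, ulysses=2, ring=1, cfg_parallel=2.
assert required_gpus(2, 2, 1, use_cfg_parallel=True) == 8
```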
examples/sd3_example.py (5 changes: 4 additions & 1 deletion)
@@ -24,6 +24,9 @@ def main():
         engine_config=engine_config,
         torch_dtype=torch.float16,
     ).to(f"cuda:{local_rank}")
+
+    parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")
+
     pipe.prepare_run(input_config)
 
     torch.cuda.reset_peak_memory_stats()
@@ -63,7 +66,7 @@ def main():

     if get_world_group().rank == get_world_group().world_size - 1:
         print(
-            f"{parallel_info} epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB"
+            f"{parallel_info} epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, peak memory: {peak_memory/1e9:.2f} GB"
         )
     get_runtime_state().destory_distributed_env()

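One caveat when reading these numbers: `torch.cuda.max_memory_allocated()` counts bytes actually allocated to tensors, while PyTorch's caching allocator may reserve more from the driver. A short sketch of the distinction, applicable to any of the examples above (assumes CUDA is already initialized):

```python
import torch

# max_memory_allocated() tracks peak tensor allocations (what these
# examples report); max_memory_reserved() also includes the allocator's
# cached-but-unused blocks and is usually larger.
allocated = torch.cuda.max_memory_allocated()
reserved = torch.cuda.max_memory_reserved()
print(f"allocated: {allocated/1e9:.2f} GB, reserved: {reserved/1e9:.2f} GB")
```

Note also that the examples divide by 1e9, so the printed values are decimal gigabytes rather than GiB.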
