concurrent access to text-to-image http service. (#359)
feifeibear authored Nov 22, 2024
1 parent 353bba9 commit c203225
Showing 6 changed files with 157 additions and 93 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -468,3 +468,15 @@ We also welcome developers to join and contribute more features and models to th
}
```
[Unveiling Redundancy in Diffusion Transformers (DiTs): A Systematic Study](https://arxiv.org/abs/2411.13588)
@misc{sun2024unveilingredundancydiffusiontransformers,
title={Unveiling Redundancy in Diffusion Transformers (DiTs): A Systematic Study},
author={Xibo Sun and Jiarui Fang and Aoyu Li and Jinzhe Pan},
year={2024},
eprint={2411.13588},
archivePrefix={arXiv},
primaryClass={cs.CV},
url={https://arxiv.org/abs/2411.13588},
}
32 changes: 10 additions & 22 deletions docs/developer/Http_Service.md
@@ -1,26 +1,14 @@
## Launching a Text-to-Image Http Service
## Launch a Text-to-Image Service

### Creating the Service Image
Launch an HTTP-based text-to-image service that generates images from textual descriptions (prompts) using the DiT model. The generated images can either be returned directly to users or saved to a specified disk location. To enhance processing efficiency, we've implemented a concurrent processing mechanism: requests containing prompts are stored in a request queue, and DiT processes these requests in parallel across multiple GPUs.

```bash
python ./http-service/launch_host.py --config ./http-service/config.json --max_queue_size 4
```
docker build -t xdit-server:0.3.1 -f ./docker/Dockerfile .
```

or (version number may need to be updated)

```
docker pull thufeifeibear/xdit-service:0.3.1
```

Start the service using the following command. The service-related parameters are written in the configuration script `config.json`. We have mapped disk files to the Docker container because we need to pass the downloaded model files. Note the mapping of port 6000; if there is a conflict, please modify it.

```
docker run --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -p 6000:6000 -v /cfs:/cfs xdit-server:0.3.1 --config ./config.json
```

The content of `./config.json` includes the number of GPUs to use, the mixed parallelism strategy, the size of the output images, the storage location for generated images, and other information.
The default content in `./config.json` is shown below, which includes settings for the number of GPU cards, hybrid parallelism strategy, output image dimensions, and image storage location:

```
```json
{
"nproc_per_node": 2,
"model": "/cfs/dit/HunyuanDiT-v1.2-Diffusers",
@@ -34,10 +22,10 @@ The content of `./config.json` includes the number of GPUs to use, the mixed par
}
```

Access the service using an HTTP request. The `save_disk_path` is an optional parameter. If not set, an image is returned; if set, the generated image is saved in the corresponding directory on the disk.
To interact with the service, send HTTP requests as shown below. The `save_disk_path` parameter is optional - if not set, the image will be returned directly; if set, the generated image will be saved to the specified directory on disk.

```
curl -X POST http://127.0.0.1:6001/generate \
```bash
curl -X POST http://127.0.0.1:6000/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "A lovely rabbit",
@@ -46,4 +34,4 @@ curl -X POST http://127.0.0.1:6001/generate \
"cfg": 7.5,
"save_disk_path": "/tmp"
}'
```
```
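
Going by the `host.py` changes in this commit, `POST /generate` now appears to answer with `202` and a `request_id` rather than the image itself, and the result is fetched from `/status/<request_id>`. The following is a minimal client sketch under that assumption. `submit` and `wait_for_result` are hypothetical helper names, and treating `404` as "possibly still in flight" is a guess based on the handler returning `not_found` for any ID that is neither queued nor stored.

```python
import json
import time
import urllib.error
import urllib.request


def submit(base_url, payload):
    """POST a prompt to /generate and return the request_id from the reply."""
    req = urllib.request.Request(
        f"{base_url}/generate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    # A full queue is answered with 503, which urlopen raises as HTTPError.
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)["request_id"]


def wait_for_result(base_url, request_id, interval=1.0, timeout=300.0):
    """Poll /status/<request_id> until a result is available (HTTP 200)."""
    url = f"{base_url}/status/{request_id}"
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with urllib.request.urlopen(url) as resp:
                body = json.load(resp)
                if resp.status == 200:
                    return body
                # 202 means the request is still waiting in the queue.
        except urllib.error.HTTPError as err:
            # Assumption: 404 may also mean "dequeued, result not stored yet".
            if err.code != 404:
                raise
        time.sleep(interval)
    raise TimeoutError(f"request {request_id} not finished after {timeout}s")
```

A failed generation is also stored and served with status `200`, so a caller would likely want to check the returned body for an `error` key before using `output`.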
20 changes: 4 additions & 16 deletions docs/developer/Http_Service_zh.md
@@ -1,23 +1,13 @@
## Launch a Text-to-Image Service

### Creating the Service Image

```
docker build -t xdit-service -f ./docker/Dockerfile .
```

Or pull it directly from Docker Hub (the version number may need to be updated)
```
docker pull thufeifeibear/xdit-service
```

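Start the service as shown below. The service parameters are written in the configuration script config.json. We map disk files into the docker container because the downloaded model files need to be passed in. Note the mapping of port 6000; change it if there is a conflict.
Launch an HTTP-based text-to-image service. The service receives a text description (prompt) from the user and uses the DiT model to generate the corresponding image. The generated image can be returned directly to the user or saved to a specified disk location. To improve processing efficiency, we implemented a concurrent processing mechanism: a request queue stores incoming requests, and xdit processes the queued requests in parallel across multiple GPUs.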

```
docker run --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -p 6000:6000 -v /cfs:/cfs xdit-service --config ./config.json
python ./http-service/launch_host.py --config ./http-service/config.json --max_queue_size 4
```

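The content of ./config.json is shown below. It includes the number of GPUs to launch, the hybrid parallelism strategy, the output image size, the storage location for generated images, and other information.
The default content of ./config.json is shown below. It includes the number of GPUs to launch, the hybrid parallelism strategy, the output image size, the storage location for generated images, and other information.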

```
{
Expand All @@ -35,9 +25,8 @@ docker run --gpus all -it --ipc=host --ulimit memlock=-1 --ulimit stack=67108864

使用http请求访问服务,"save_disk_path"是一个可选项,如果不设置则返回一个图片,如果设置则将生成图片存在磁盘上对应位置的目录中。


```
curl -X POST http://127.0.0.1:6001/generate \
curl -X POST http://127.0.0.1:6000/generate \
-H "Content-Type: application/json" \
-d '{
"prompt": "A lovely rabbit",
@@ -47,4 +36,3 @@ curl -X POST http://127.0.0.1:6001/generate \
"save_disk_path": "/tmp"
}'
```

2 changes: 1 addition & 1 deletion http-service/config.json
@@ -8,4 +8,4 @@
"width": 512,
"save_disk_path": "/cfs/dit/output",
"use_cfg_parallel": false
}
}
181 changes: 128 additions & 53 deletions http-service/host.py
@@ -7,6 +7,10 @@
import logging
import base64
import torch.multiprocessing as mp
from queue import Queue
import threading
import asyncio
from collections import deque

from PIL import Image
from flask import Flask, request, jsonify
@@ -39,6 +43,13 @@
local_rank = None
logger = None
initialized = False
args = None

# a global queue to store request prompts
request_queue = deque()
queue_lock = threading.Lock()
queue_event = threading.Event()
results_store = {} # store request results


def setup_logger():
@@ -62,10 +73,12 @@ def check_initialize():


def initialize():
    global pipe, engine_config, input_config, local_rank, initialized
    global pipe, engine_config, input_config, local_rank, initialized, args
    mp.set_start_method("spawn", force=True)

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    parser.add_argument("--max_queue_size", type=int, default=4,
                        help="Maximum size of the request queue")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
@@ -161,65 +174,127 @@ def generate_image_parallel(


@app.route("/generate", methods=["POST"])
def generate_image():
def queue_image_request():
    logger.info("Received POST request for image generation")
    data = request.json
    prompt = data.get("prompt", input_config.prompt)
    num_inference_steps = data.get(
        "num_inference_steps", input_config.num_inference_steps
    )
    seed = data.get("seed", input_config.seed)
    cfg = data.get("cfg", 8.0)
    save_disk_path = data.get("save_disk_path")

    # Check if save_disk_path is valid, if not, set it to a default directory
    if save_disk_path:
        if not os.path.isdir(save_disk_path):
            default_path = os.path.join(os.path.expanduser("~"), "tacodit_output")
            os.makedirs(default_path, exist_ok=True)
            logger.warning(
                f"Invalid save_disk_path. Using default path: {default_path}"
            )
            save_disk_path = default_path
    else:
        save_disk_path = None

    logger.info(
        f"Request parameters: prompt='{prompt}', steps={num_inference_steps}, seed={seed}, save_disk_path={save_disk_path}"
    )
    # Broadcast request parameters to all processes
    params = [prompt, num_inference_steps, seed, cfg, save_disk_path]
    dist.broadcast_object_list(params, src=0)
    logger.info("Parameters broadcasted to all processes")

    output, elapsed_time = generate_image_parallel(*params)

    if save_disk_path:
        # output is a disk path
        output_base64 = ""
        image_path = save_disk_path
    else:
        # Ensure output is not None before accessing its attributes
        if output and hasattr(output, "images") and output.images:
            output_base64 = base64.b64encode(output.images[0].tobytes()).decode("utf-8")
        else:
            output_base64 = ""
        image_path = ""

    response = {
        "message": "Image generated successfully",
        "elapsed_time": f"{elapsed_time:.2f} sec",
        "output": output_base64 if not save_disk_path else output,
        "save_to_disk": save_disk_path is not None,
    }

    # logger.info(f"Sending response: {response}")
    return jsonify(response)
    request_id = str(time.time())

    with queue_lock:
        # Check queue size
        if len(request_queue) >= args.max_queue_size:
            return jsonify({
                "error": "Queue is full, please try again later",
                "queue_size": len(request_queue)
            }), 503

        request_params = {
            "id": request_id,
            "prompt": data.get("prompt", input_config.prompt),
            "num_inference_steps": data.get("num_inference_steps", input_config.num_inference_steps),
            "seed": data.get("seed", input_config.seed),
            "cfg": data.get("cfg", 8.0),
            "save_disk_path": data.get("save_disk_path")
        }

        request_queue.append(request_params)
        queue_event.set()

    return jsonify({
        "message": "Request accepted",
        "request_id": request_id,
        "status_url": f"/status/{request_id}"
    }), 202
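
The handler above derives `request_id` from `str(time.time())`, so two requests that arrive within the same clock tick could in principle end up with the same ID and overwrite each other's entry in `results_store`. One possible alternative, which is not part of this commit, is a random UUID. `new_request_id` is a hypothetical helper name:

```python
import uuid


def new_request_id():
    """Return a random 32-character hex ID (hypothetical helper, not in the commit)."""
    return uuid.uuid4().hex
```

The resulting string contains no `/`, so it should still fit a path such as `/status/<request_id>`.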

@app.route("/status/<request_id>", methods=["GET"])
def check_status(request_id):
    if request_id in results_store:
        result = results_store.pop(request_id)
        return jsonify(result), 200

    position = None
    with queue_lock:
        for i, req in enumerate(request_queue):
            if req["id"] == request_id:
                position = i
                break

    if position is not None:
        return jsonify({
            "status": "pending",
            "queue_position": position
        }), 202

    return jsonify({"status": "not_found"}), 404
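
When `save_disk_path` is omitted, the worker in this commit stores `output` as base64 of `output.images[0].tobytes()`. If `images[0]` is a PIL image, `tobytes()` yields raw pixel data rather than an encoded file such as PNG, so a client would need the image size and mode to rebuild it. A small sketch, assuming RGB mode and the `height`/`width` from `config.json`; `decode_raw_rgb` is a hypothetical helper:

```python
import base64


def decode_raw_rgb(output_b64, width, height):
    """Decode the base64 `output` field into raw RGB bytes and check its size."""
    raw = base64.b64decode(output_b64)
    expected = width * height * 3  # 3 bytes per pixel, assuming mode "RGB"
    if len(raw) != expected:
        raise ValueError(f"expected {expected} bytes, got {len(raw)}")
    return raw
```

With Pillow available, `Image.frombytes("RGB", (width, height), raw)` should then rebuild the image from the returned bytes.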

def process_queue():
    while True:
        queue_event.wait()

        with queue_lock:
            if not request_queue:
                queue_event.clear()
                continue

            params = request_queue.popleft()
            if not request_queue:
                queue_event.clear()

        try:
            # Extract parameters
            request_id = params["id"]
            prompt = params["prompt"]
            num_inference_steps = params["num_inference_steps"]
            seed = params["seed"]
            cfg = params["cfg"]
            save_disk_path = params["save_disk_path"]

            # Broadcast parameters to all processes
            broadcast_params = [prompt, num_inference_steps, seed, cfg, save_disk_path]
            dist.broadcast_object_list(broadcast_params, src=0)

            # Generate image and get results
            output, elapsed_time = generate_image_parallel(*broadcast_params)

            # Process output results
            if save_disk_path:
                # output is disk path
                result = {
                    "message": "Image generated successfully",
                    "elapsed_time": f"{elapsed_time:.2f} sec",
                    "output": output,  # This is the file path
                    "save_to_disk": True
                }
            else:
                # Process base64 output
                if output and hasattr(output, "images") and output.images:
                    output_base64 = base64.b64encode(output.images[0].tobytes()).decode("utf-8")
                else:
                    output_base64 = ""

                result = {
                    "message": "Image generated successfully",
                    "elapsed_time": f"{elapsed_time:.2f} sec",
                    "output": output_base64,
                    "save_to_disk": False
                }

            # Store results
            results_store[request_id] = result

        except Exception as e:
            logger.error(f"Error processing request {params['id']}: {str(e)}")
            results_store[request_id] = {
                "error": str(e),
                "status": "failed"
            }
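
The queueing mechanism in this commit (a bounded `deque`, a lock, an event, and one worker thread) can be tried in isolation, without Flask or GPUs. The sketch below is a simplified stand-in, not the code from `host.py`: `enqueue`, `worker`, `handle`, and the `stop` event are hypothetical names, and the wait timeout exists only so the loop can be shut down cleanly.

```python
import threading
from collections import deque

MAX_QUEUE_SIZE = 4  # mirrors the --max_queue_size default

request_queue = deque()
queue_lock = threading.Lock()
queue_event = threading.Event()
results_store = {}


def enqueue(request_id, prompt):
    """Return True if the request was queued, False if the queue is full."""
    with queue_lock:
        if len(request_queue) >= MAX_QUEUE_SIZE:
            return False  # the HTTP handler answers 503 in this case
        request_queue.append({"id": request_id, "prompt": prompt})
        queue_event.set()
        return True


def worker(handle, stop):
    """Pop requests one at a time and store either a result or an error."""
    while not stop.is_set():
        # The timeout only lets the loop notice `stop`; host.py waits forever.
        if not queue_event.wait(timeout=0.1):
            continue
        with queue_lock:
            if not request_queue:
                queue_event.clear()
                continue
            params = request_queue.popleft()
            if not request_queue:
                queue_event.clear()
        try:
            results_store[params["id"]] = {"output": handle(params["prompt"])}
        except Exception as exc:
            results_store[params["id"]] = {"error": str(exc), "status": "failed"}
```

Because a single worker drains the queue, requests are served one after another in arrival order; the parallelism comes from each request being spread across GPUs, not from handling several requests at once.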


def run_host():
    if dist.get_rank() == 0:
        logger.info("Starting Flask host on rank 0")
        # process 0 will process the queue in a separate thread
        queue_thread = threading.Thread(target=process_queue, daemon=True)
        queue_thread.start()
        app.run(host="0.0.0.0", port=6000)
    else:
        while True:
3 changes: 2 additions & 1 deletion http-service/launch_host.py
@@ -20,6 +20,7 @@ def build_command(config):
        f"--ring_degree={config['ring_degree']}",
        f"--height={config['height']}",
        f"--width={config['width']}",
        f"--max_queue_size={config.get('max_queue_size', 4)}",
    ]
    if config.get("use_cfg_parallel", False):
        cmd.append("--use_cfg_parallel")
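
As the hunk above suggests, the launcher reads `max_queue_size` from the parsed `config.json` and falls back to 4 when the key is absent. A reduced sketch of that mapping follows; `build_flags` is a hypothetical stand-in that covers only the flags visible in this diff, not the full `build_command`:

```python
def build_flags(config):
    """Turn a parsed config dict into CLI flags (subset shown in the diff)."""
    flags = [
        f"--height={config['height']}",
        f"--width={config['width']}",
        # Optional key: absent from the shipped config.json, so default to 4.
        f"--max_queue_size={config.get('max_queue_size', 4)}",
    ]
    if config.get("use_cfg_parallel", False):
        flags.append("--use_cfg_parallel")
    return flags
```

Under this reading, raising the queue limit would be a matter of adding a `"max_queue_size"` entry to `config.json`.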
@@ -31,7 +32,7 @@ def main():
    parser.add_argument(
        "--config",
        type=str,
        default="config.json",
        default="./http-service/config.json",
        help="Path to the configuration file",
    )
    args = parser.parse_args()