Commit 47e1424 (0 parents): forked fleece-worker, changed bench function
Showing 12 changed files with 1,780 additions and 0 deletions.
`.gitignore`

```
*.pyc
/build/
/dist/
/fleece_worker.egg-info/
/.vscode/
```
`README.md`

## Installation

### Install From PyPI
```
pip install fleece-worker
```

### Install From Source
```
pip install -e .
```
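Installing from source assumes the repository has already been cloned and the command is run from its root; sketched here with a placeholder URL:

```
git clone <repo_url>
cd fleece-worker
pip install -e .
```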

## Connect to a controller

```
python -m fleece-worker -c <controller_url> -t <api_token>
```
Optional: `--worker-nickname abc`, `--heartbeat-interval 10`, `-w <worker_url>`

For example:

```
python -m fleece-worker -c https://serving-api.colearn.cloud:8443 -t <api_token>
```
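The optional flags compose with the required ones; for instance, with an illustrative nickname and heartbeat interval:

```
python -m fleece-worker -c https://serving-api.colearn.cloud:8443 -t <api_token> --worker-nickname abc --heartbeat-interval 10
```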

## Try it out (deprecated)

```
python -m fleece-worker
```

```
curl localhost:8080/forward -H 'Content-Type: application/json' -d '{"task_id":"123","step":0,"round":0,"plan":[["http://127.0.0.1:8080",["llama-2-7b-chat-slice/tok_embeddings", "llama-2-7b-chat-slice/layers.0", "llama-2-7b-chat-slice/layers.1", "llama-2-7b-chat-slice/layers.2", "llama-2-7b-chat-slice/layers.3", "llama-2-7b-chat-slice/layers.4", "llama-2-7b-chat-slice/layers.5", "llama-2-7b-chat-slice/layers.6", "llama-2-7b-chat-slice/layers.7", "llama-2-7b-chat-slice/layers.8", "llama-2-7b-chat-slice/layers.9", "llama-2-7b-chat-slice/layers.10", "llama-2-7b-chat-slice/layers.11", "llama-2-7b-chat-slice/layers.12", "llama-2-7b-chat-slice/layers.13", "llama-2-7b-chat-slice/layers.14", "llama-2-7b-chat-slice/layers.15", "llama-2-7b-chat-slice/layers.16", "llama-2-7b-chat-slice/layers.17", "llama-2-7b-chat-slice/layers.18", "llama-2-7b-chat-slice/layers.19", "llama-2-7b-chat-slice/layers.20", "llama-2-7b-chat-slice/layers.21", "llama-2-7b-chat-slice/layers.22", "llama-2-7b-chat-slice/layers.23", "llama-2-7b-chat-slice/layers.24", "llama-2-7b-chat-slice/layers.25", "llama-2-7b-chat-slice/layers.26", "llama-2-7b-chat-slice/layers.27", "llama-2-7b-chat-slice/layers.28", "llama-2-7b-chat-slice/layers.29", "llama-2-7b-chat-slice/layers.30", "llama-2-7b-chat-slice/layers.31", "llama-2-7b-chat-slice/norm", "llama-2-7b-chat-slice/output"]]],"payload":[[1, 518, 25580, 29962, 825, 338, 278, 9522, 412, 310, 1122, 11586, 895, 29973, 518, 29914, 25580, 29962]]}'
```
```
curl localhost:8080/forward -H 'Content-Type: application/json' -d '{"task_id":"123","step":0,"round":0,"plan":[["http://127.0.0.1:8080",["llama-2-7b-chat-slice/tok_embeddings", "llama-2-7b-chat-slice/layers.0", "llama-2-7b-chat-slice/layers.1", "llama-2-7b-chat-slice/layers.2", "llama-2-7b-chat-slice/layers.3", "llama-2-7b-chat-slice/layers.4", "llama-2-7b-chat-slice/layers.5", "llama-2-7b-chat-slice/layers.6", "llama-2-7b-chat-slice/layers.7", "llama-2-7b-chat-slice/layers.8", "llama-2-7b-chat-slice/layers.9", "llama-2-7b-chat-slice/layers.10", "llama-2-7b-chat-slice/layers.11", "llama-2-7b-chat-slice/layers.12", "llama-2-7b-chat-slice/layers.13", "llama-2-7b-chat-slice/layers.14", "llama-2-7b-chat-slice/layers.15", "llama-2-7b-chat-slice/layers.16", "llama-2-7b-chat-slice/layers.17", "llama-2-7b-chat-slice/layers.18", "llama-2-7b-chat-slice/layers.19", "llama-2-7b-chat-slice/layers.20", "llama-2-7b-chat-slice/layers.21", "llama-2-7b-chat-slice/layers.22", "llama-2-7b-chat-slice/layers.23", "llama-2-7b-chat-slice/layers.24", "llama-2-7b-chat-slice/layers.25", "llama-2-7b-chat-slice/layers.26", "llama-2-7b-chat-slice/layers.27", "llama-2-7b-chat-slice/layers.28", "llama-2-7b-chat-slice/layers.29", "llama-2-7b-chat-slice/layers.30", "llama-2-7b-chat-slice/layers.31", "llama-2-7b-chat-slice/norm", "llama-2-7b-chat-slice/output"]]],"payload":[[1, 518, 25580, 29962, 825, 338, 278, 9522, 412, 310, 1122, 11586, 895, 29973, 518, 29914, 25580, 29962], [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 2499, 1994, 1234, 411, 5952, 18282, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 29902, 626, 2675, 304, 3681, 29892, 825, 881, 306, 1074, 29973, 518, 29914, 25580, 29962], [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 2499, 1994, 1234, 411, 953, 3848, 275, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 5328, 304, 748, 515, 1522, 823, 292, 304, 23526, 29973, 518, 29914, 25580, 29962]]}'
```
> note that the model will be automatically downloaded to `~/.cache`
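In these requests, `plan` pairs a worker URL with the list of model slices that worker should serve, and `payload` is a batch of already-tokenized prompts. The same request can be built from Python; a minimal sketch, assuming a worker listening on localhost:8080 (the slice list is reconstructed programmatically rather than written out):

```
import requests

# embeddings, the 32 transformer layers of llama-2-7b-chat, final norm, and output head
slices = (["llama-2-7b-chat-slice/tok_embeddings"]
          + [f"llama-2-7b-chat-slice/layers.{i}" for i in range(32)]
          + ["llama-2-7b-chat-slice/norm", "llama-2-7b-chat-slice/output"])

body = {
    "task_id": "123",
    "step": 0,
    "round": 0,
    # one (worker_url, slice_list) pair per pipeline stage; here a single stage
    "plan": [["http://127.0.0.1:8080", slices]],
    # a batch with one prompt, given as LLaMA token IDs
    "payload": [[1, 518, 25580, 29962, 825, 338, 278, 9522, 412, 310,
                 1122, 11586, 895, 29973, 518, 29914, 25580, 29962]],
}
requests.post("http://localhost:8080/forward", json=body)
```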
Benchmark script:

```
import torch
import pandas as pd
import argparse

# the package name contains a hyphen, so it cannot be imported with a plain import statement
fleece_worker = __import__("fleece-worker")

worker = fleece_worker.Worker()

worker.start_layer_forward_engine()

parser = argparse.ArgumentParser(description='Run the estimation')
parser.add_argument('--model', '-m', type=str, required=True)
args = parser.parse_args()
model = args.model

# collect the requested model's layer names from the spec sheet, expanding
# repeated layers (capped at 5 instances) into layers.0, layers.1, ...
df_layers = pd.read_csv("./specs/fleece_layers.csv")
layer_names = []
for idx, row in df_layers.iterrows():
    if not row["From_model"] == model:
        continue
    layer_name = row["Layer_name"]
    if row["Repetition"] == 1:
        layer_names.append(layer_name)
    else:
        for i in range(min(row["Repetition"], 5)):
            layer_names.append(f"{layer_name}.{i}")

# example 1: a prompt forward pass from token IDs, then 15 single-token decode steps
print("[")
worker.preload_layers(layer_names)
h = torch.tensor([[1, 518, 25580, 29962, 825, 338, 278, 9522, 412, 310, 1122, 11586, 895, 29973, 518, 29914, 25580, 29962]], device="cuda")
start_pos = 0
is_new_task = start_pos == 0
kv_cache_dict = dict()
for _ in range(16):
    bsz = h.shape[0]
    seqlen = h.shape[1]
    _, kv_cache_dict = worker.layers_forward(h, layer_names, bsz, is_new_task, 0, start_pos, seqlen, kv_cache_dict)
    is_new_task = False
    start_pos += seqlen
    h = torch.tensor([[29962]], device="cuda")

# example 2: two transformer layers fed with random hidden states
hidden_dim = 4096 if model == "llama-2-7b-chat-slice" else 8192
layer_names = [f"{model}/layers.0", f"{model}/layers.1"]
worker.preload_layers(layer_names)
h = torch.randn((1, 18, hidden_dim), dtype=torch.float16, device="cuda")
start_pos = 0
is_new_task = start_pos == 0
kv_cache_dict = dict()
for _ in range(16):
    bsz = h.shape[0]
    seqlen = h.shape[1]
    _, kv_cache_dict = worker.layers_forward(h, layer_names, bsz, is_new_task, 0, start_pos, seqlen, kv_cache_dict)
    is_new_task = False
    start_pos += seqlen
    h = torch.randn((1, 1, hidden_dim), dtype=torch.float16, device="cuda")

print("]")
```
`fleece_worker/__init__.py`

```
__version__ = "0.1.0"

from .worker import Worker
```
`fleece_worker/__main__.py`

```
from typing import List, Tuple, Optional
from fastapi import FastAPI, HTTPException
from peerrtc.peer import Peer
from pydantic import BaseModel
import anyio
import uvicorn
from .worker import Worker
from .__init__ import __version__
import argparse
import requests
import json
import torch
import concurrent.futures
from anyio.from_thread import BlockingPortal

app = FastAPI()
worker = Worker()


class LayersRequest(BaseModel):
    layer_names: List[str]


def preload_layers(req: LayersRequest):
    try:
        worker.preload_layers(req.layer_names)
        return None
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail="Internal Server Error")


def unload_layers(req: LayersRequest):
    try:
        worker.unload_layers(req.layer_names)
        return None
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail="Internal Server Error")


class ForwardRequest(BaseModel):
    task_id: str
    plan: List[Tuple[str, List[str]]]
    step: int
    round: int = -1
    payload: Optional[List] = None
    max_total_len: int = 2048
    temperature: float = 0.0
    top_p: float = 0.9
    task_manager_url: Optional[str] = None
    signature: Optional[str] = None
    timestamp: Optional[int] = None


executor = concurrent.futures.ThreadPoolExecutor(max_workers=64)


def forward(req: ForwardRequest):
    try:
        # run the forward pass on the thread pool so the request returns immediately
        executor.submit(worker.forward, req.task_id, req.plan, req.step, req.round, req.payload, req.max_total_len, req.temperature, req.top_p,
                        req.task_manager_url, req.signature, req.timestamp)
        return None
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail="Internal Server Error")


class GetInfoRequest(BaseModel):
    node_list: List[str] = []
    timeout: int = 30


class GetInfoResponse(BaseModel):
    worker_nickname: str
    gpu_mem_info: Tuple[int, int] = (0, 0)
    latency_list: List[Optional[float]] = []


def get_info(req: GetInfoRequest) -> GetInfoResponse:
    try:
        worker_nickname, gpu_mem_info, latency_list = worker.get_info(
            req.node_list, req.timeout
        )
        return GetInfoResponse(
            worker_nickname=worker_nickname,
            gpu_mem_info=gpu_mem_info,
            latency_list=latency_list,
        )
    except Exception as e:
        print(e)
        raise HTTPException(status_code=500, detail="Internal Server Error")


async def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("-c", "--controller-url")
    parser.add_argument("-w", "--worker-url")
    parser.add_argument("-t", "--api-token")
    parser.add_argument("--port")
    parser.add_argument("--worker-nickname")
    parser.add_argument("--heartbeat-interval")
    args = parser.parse_args()
    if args.worker_url is not None:
        worker_url = args.worker_url
        # derive the listen port from the worker URL, e.g. http://host:port
        parsed = worker_url.split(':')
        if len(parsed) >= 3:
            port = int(parsed[2])
        else:
            port = 8080
    else:
        worker_url = "none"
        port = 8080
    if args.port is not None:
        port = int(args.port)
    worker.port = port
    if args.api_token is not None:
        worker.api_token = args.api_token
    if args.worker_nickname is not None:
        worker.worker_nickname = args.worker_nickname
    if args.heartbeat_interval is not None:
        worker.heartbeat_interval = int(args.heartbeat_interval)
    if args.controller_url is not None:
        worker.controller_url = args.controller_url
    # register this worker with the controller, reporting its GPU model and memory
    data = {"url": worker_url, "version": __version__}
    if worker.worker_nickname is not None:
        data["nickname"] = worker.worker_nickname
    if torch.cuda.is_available():
        model = torch.cuda.get_device_name()
        memory = torch.cuda.mem_get_info()
        data["gpu_model"] = model
        data["gpu_total_memory"] = memory[1]
        data["gpu_remaining_memory"] = memory[0]
    else:
        data["gpu_model"] = "CPU"
        data["gpu_total_memory"] = 0
        data["gpu_remaining_memory"] = 0
    r = requests.post(f"{args.controller_url}/register_worker",
                      json=data,
                      headers={"api-token": worker.api_token})
    res = json.loads(r.content)
    worker.worker_id = res["id"]
    worker.pull_worker_url()
    worker.start_heartbeat_daemon()
    worker.start_layer_forward_engine()

    print("Worker ID: ", worker.worker_id)

    r = requests.get(
        f"{args.controller_url}/get_network_servers",
        headers={"api-token": worker.api_token}
    )

    servers = json.loads(r.content)
    signaling = servers["signaling"]["url"]
    turns = servers["turn"]
    async with BlockingPortal() as portal:
        worker.async_portal = portal
        async with anyio.create_task_group() as tg:
            # connect to the signaling/TURN servers so NATed workers stay reachable
            worker.peer = Peer(
                worker.worker_id,
                signaling,
                [(turn["url"], turn["username"], turn["password"]) for turn in turns],
                {
                    "preload_layers": preload_layers,
                    "unload_layers": unload_layers,
                    "forward": forward,
                    "get_info": get_info,
                },
                tg,
            )

            # start the FastAPI server when a public IP is available
            if worker_url != "none":
                app.add_api_route("/preload_layers", preload_layers, methods=["POST"])
                app.add_api_route("/unload_layers", unload_layers, methods=["POST"])
                app.add_api_route("/forward", forward, methods=["POST"])
                app.add_api_route("/get_info", get_info, methods=["POST"])

                uviconfig = uvicorn.Config(app, host="0.0.0.0", port=port, access_log=True)
                uviserver = uvicorn.Server(uviconfig)
                tg.start_soon(uviserver.serve)
            await portal.sleep_until_stopped()


if __name__ == '__main__':
    anyio.run(main)
```
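When the worker is started with a public `-w <worker_url>`, these handlers are also reachable over HTTP. For example, a sketch of querying `/get_info` on a local worker (the body values are just the schema defaults):

```
curl localhost:8080/get_info -H 'Content-Type: application/json' -d '{"node_list": [], "timeout": 30}'
```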