Add support to vLLM backend #34

Open · wants to merge 8 commits into main
22 changes: 16 additions & 6 deletions mix_eval/evaluate.py
@@ -76,6 +76,13 @@ def parse_args():
"Set this properly will allocate more memory for activations, "
"so you can use longer context lengths or larger batch sizes."
)
parser.add_argument(
"--cpu_offload_gb",
type=int,
default=None,
help="Amount of memory (in GB) to offload to CPU for loading the weights. "
"Only valid with vLLM models."
)
parser.add_argument(
"--api_parallel_num",
type=int,
@@ -152,7 +159,7 @@ def parse_args():
return parser.parse_args()


def _eval(args):
def _eval(args, model=None):
print(f"\n\nStart to evaluate {args.model_name}'s {args.split} split. \n\n")
time_elapsed = 0
start_time = time.time()
@@ -192,7 +199,8 @@ def _eval(args):
"lines as recorded in cached metadadta. Please check the response file. "
"You might consider delete the response and metadata file to start from scratch.")

model = mix_eval.api.registry.get_model(args.model_name)(args)
if model is None:
model = mix_eval.api.registry.get_model(args.model_name)(args)
eval_dataset = get_eval_dataset(args)
dataloader = DataLoader(
eval_dataset,
@@ -235,18 +243,20 @@ def _eval(args):
print(f"Finished evaluating {args.model_name}'s {args.split} split. "
f"Used {round(time_elapsed / 60, 2)} minutes.")

return model

def eval(args):
model = None
if args.benchmark == "mixeval":
args.split = "close_freeform"
_eval(args)
model = _eval(args, model)
args.split = "close_multichoice"
_eval(args)
_eval(args, model)
elif args.benchmark == "mixeval_hard":
args.split = "close_freeform_hard"
_eval(args)
model = _eval(args, model)
args.split = "close_multichoice_hard"
_eval(args)
_eval(args, model)
else:
raise ValueError(f"Benchmark {args.benchmark} not supported.")

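Note on the change above: threading model=None through _eval() lets the vLLM engine, which is expensive to initialize, be built once and reused by both close-ended splits. A minimal sketch of the intended call pattern, assuming args is the Namespace returned by parse_args() with the harness's other required flags (only --cpu_offload_gb is new in this PR):

# Sketch only: `args` is assumed to come from parse_args(), e.g. with
# --model_name llama_3_8b_instruct_vllm and, optionally, --cpu_offload_gb.
model = None
for split in ("close_freeform", "close_multichoice"):
    args.split = split
    model = _eval(args, model)  # first call builds the engine, later calls reuse it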
2 changes: 2 additions & 0 deletions mix_eval/models/__init__.py
@@ -10,6 +10,8 @@
"llama_3_70b": "Llama_3_70B",
"llama_3_70b_instruct": "Llama_3_70B_Instruct",

"llama_3_8b_instruct_vllm": "Llama_3_8B_Instruct_vLLM",

"qwen_15_4b": "Qwen_15_4B",
"qwen_15_7b": "Qwen_15_7B",
"qwen_15_32b": "Qwen_15_32B",
58 changes: 58 additions & 0 deletions mix_eval/models/llama_3_8b_instruct_vllm.py
@@ -0,0 +1,58 @@
from dotenv import load_dotenv
import os

import torch
from transformers import AutoTokenizer

from mix_eval.models.vllm import ChatModelVLLM
from mix_eval.api.registry import register_model
from mix_eval.utils.common_utils import get_gpu_memory

@register_model("llama_3_8b_instruct_vllm")
class Llama_3_8B_Instruct_vLLM(ChatModelVLLM):
def __init__(self, args):
super().__init__(args)
self.model_name = "meta-llama/Meta-Llama-3-8B-Instruct"

self.SYSTEM_MESSAGE = {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"} # set to None if no system message
self.USER_MESSAGE_TEMPLATE = lambda x: {"role": "user", "content": x}
self.ASSISTANT_MESSAGE_TEMPLATE = lambda x: {"role": "assistant", "content": x}

self.model_dtype = torch.bfloat16

load_dotenv()
self.hf_token = os.getenv('_FADKLFHAKH_')
self.model = self.build_model()
self.model_max_len = 8192
self.tokenizer = self.build_tokenizer()
self.tokenizer.pad_token = self.tokenizer.eos_token
self.max_input_length_closeend = min(
self.model_max_len,
self.max_input_length
) - self.closeended_max_new_tokens
self.max_input_length_openend = min(
self.model_max_len,
self.max_input_length
) - self.openended_max_new_tokens


terminators = [
self.tokenizer.eos_token_id,
self.tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

self.gen_kwargs = {
'temperature': 0.6,
'top_p': 0.9,
'stop_token_ids': terminators,
}

def build_tokenizer(self):
tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
model_max_length=self.model_max_len,
padding_side=self.padding_side,
use_fast=self.use_fast_tokenizer,
trust_remote_code=self.trust_remote_code,
token=self.hf_token,)
return tokenizer
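The class above can also serve as a template: a new vLLM-served chat model only needs to set its HF repo id, message templates, and gen_kwargs, then register itself and add a key to mix_eval/models/__init__.py. A hedged sketch with hypothetical names (my_model_vllm, my-org/My-Model-Instruct):

# Hedged template: the registry key and HF repo id below are hypothetical
# placeholders; everything else mirrors Llama_3_8B_Instruct_vLLM above.
import torch
from transformers import AutoTokenizer

from mix_eval.api.registry import register_model
from mix_eval.models.vllm import ChatModelVLLM

@register_model("my_model_vllm")  # hypothetical key; add a matching entry in mix_eval/models/__init__.py
class My_Model_vLLM(ChatModelVLLM):
    def __init__(self, args):
        super().__init__(args)
        self.model_name = "my-org/My-Model-Instruct"  # hypothetical HF repo id

        self.SYSTEM_MESSAGE = {"role": "system", "content": "You are a helpful assistant."}
        self.USER_MESSAGE_TEMPLATE = lambda x: {"role": "user", "content": x}
        self.ASSISTANT_MESSAGE_TEMPLATE = lambda x: {"role": "assistant", "content": x}

        self.model_dtype = torch.bfloat16
        self.model = self.build_model()  # inherited from ChatModelVLLM
        self.model_max_len = 8192
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.max_input_length_closeend = min(
            self.model_max_len, self.max_input_length
        ) - self.closeended_max_new_tokens
        self.max_input_length_openend = min(
            self.model_max_len, self.max_input_length
        ) - self.openended_max_new_tokens

        # consumed as SamplingParams kwargs in mix_eval/models/vllm.py
        self.gen_kwargs = {
            'temperature': 0.6,
            'top_p': 0.9,
            'stop_token_ids': [self.tokenizer.eos_token_id],
        }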
106 changes: 106 additions & 0 deletions mix_eval/models/vllm.py
@@ -0,0 +1,106 @@
from .base import ChatModel, BaseModel
from vllm import LLM, SamplingParams
import torch
import json
# Assumed location of the few-shot prefixes used by BaseModelVLLM below; adjust if the repo keeps them elsewhere.
from mix_eval.prompts.evaluation_prompts import FIVE_SHOT_PREFIX_FREEFORM, FIVE_SHOT_PREFIX_MULTIPLECHOICE

class ChatModelVLLM(ChatModel):
def build_model(self):
num_gpus = torch.cuda.device_count()

if self.args.cpu_offload_gb:
return LLM(model=self.model_name, tensor_parallel_size=num_gpus, enable_chunked_prefill=True, distributed_executor_backend="ray", cpu_offload_gb=self.args.cpu_offload_gb)
else:
return LLM(model=self.model_name, tensor_parallel_size=num_gpus, enable_chunked_prefill=True, distributed_executor_backend="ray")

def get_closeended_responses(self, batch, response_file):
sampling_params = SamplingParams(max_tokens=self.closeended_max_new_tokens, **self.gen_kwargs)
formated_prompts = [d['raw_inputs']['formated_input'] for d in batch]
inputs = [self.apply_chat_template(self.get_messages(prompt)) for prompt in formated_prompts]

outputs = self.model.generate(inputs, sampling_params)
responses = [output.outputs[0].text for output in outputs]

with open(response_file, "a") as f:
for raw_dict, response in zip(batch, responses):
raw_dict = raw_dict['raw_inputs']
raw_dict['response'] = response
f.write(json.dumps(raw_dict) + "\n")

def get_openended_responses(self, batch, response_file):
sampling_params = SamplingParams(max_tokens=self.openended_max_new_tokens, **self.gen_kwargs)

messages_batch = [
[
self.SYSTEM_MESSAGE.copy(),
] if self.SYSTEM_MESSAGE is not None else []
for _ in batch
]
turns_batch = [d['raw_inputs']['turns'] for d in batch]
turn_num = len(turns_batch[0])
for turns in turns_batch:
assert len(turns) == turn_num, "All dialogues should have the same number of turns."

responses_all = []
for i in range(turn_num):
for turns, messages in zip(turns_batch, messages_batch):
messages.append(self.USER_MESSAGE_TEMPLATE(turns[i]))
inputs = [self.apply_chat_template(messages) for messages in messages_batch]

outputs = self.model.generate(inputs, sampling_params)
responses = [output.outputs[0].text for output in outputs]

responses_all.append(responses)
for response, messages in zip(responses, messages_batch):
messages.append(self.ASSISTANT_MESSAGE_TEMPLATE(response))

responses_all = list(zip(*responses_all))

with open(response_file, "a") as f:
for raw_dict, response in zip(batch, responses_all):
raw_dict = raw_dict['raw_inputs']
raw_dict['response'] = response
f.write(json.dumps(raw_dict) + "\n")

class BaseModelVLLM(BaseModel):
def build_model(self):
num_gpus = torch.cuda.device_count()

if self.args.cpu_offload_gb:
return LLM(model=self.model_name, tensor_parallel_size=num_gpus, enable_chunked_prefill=True, distributed_executor_backend="ray", cpu_offload_gb=self.args.cpu_offload_gb)
else:
return LLM(model=self.model_name, tensor_parallel_size=num_gpus, enable_chunked_prefill=True, distributed_executor_backend="ray")

def get_closeended_responses(self, batch, response_file):
sampling_params = SamplingParams(max_tokens=self.closeended_max_new_tokens, **self.gen_kwargs)
formated_prompts = [d['raw_inputs']['formated_input'] for d in batch]

# add few-shot prompts
if self.args.split == 'close_multichoice' or self.args.split == 'close_multichoice_hard':
formated_prompts = [
FIVE_SHOT_PREFIX_MULTIPLECHOICE + prompt + '\n'
for prompt in formated_prompts
]
elif self.args.split == 'close_freeform' or self.args.split == 'close_freeform_hard':
formated_prompts = [
FIVE_SHOT_PREFIX_FREEFORM + prompt + '\n'
for prompt in formated_prompts]
else:
raise ValueError(f"Split {self.args.split} not supported in "
f"{self.__class__.__name__}: get_closeended_responses()")

for _fp, _b in zip(formated_prompts, batch):
_b['raw_inputs']['formated_input'] = _fp

outputs = self.model.generate(formated_prompts, sampling_params)
responses = [output.outputs[0].text for output in outputs]

with open(response_file, "a") as f:
for raw_dict, response in zip(batch, responses):
raw_dict = raw_dict['raw_inputs']
raw_dict['response'] = response
f.write(json.dumps(raw_dict) + "\n")

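
As a quick sanity check outside the harness, the vLLM calls this backend relies on can be exercised directly. The snippet below mirrors build_model() and the SamplingParams usage above; the model id, prompt, and cpu_offload_gb value are illustrative only:

# Standalone sketch of the generate path used by ChatModelVLLM / BaseModelVLLM.
import torch
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    tensor_parallel_size=torch.cuda.device_count(),
    enable_chunked_prefill=True,
    distributed_executor_backend="ray",
    cpu_offload_gb=4,  # same knob the new --cpu_offload_gb flag forwards; omit to keep all weights on GPU
)

sampling_params = SamplingParams(max_tokens=64, temperature=0.6, top_p=0.9)
outputs = llm.generate(["Say hello in one short sentence."], sampling_params)
print(outputs[0].outputs[0].text)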