
Commit

[Automated Commit] Format Codebase
mlcommons-bot committed Jan 15, 2025
1 parent 5bcf8e8 commit b8a7537
Showing 4 changed files with 62 additions and 31 deletions.
29 changes: 18 additions & 11 deletions language/mixtral-8x7b/standalone_infer/hf_eval_all.py
@@ -42,7 +42,8 @@ def run_infer(df, ckpt_path, bs):
     # Load the model from local if possible.
     model_path = Path(ckpt_path)
     if not model_path.exists():
-        raise RuntimeError(f"{ckpt_path} not existed. Please download the checkpoint from mlcommon")
+        raise RuntimeError(
+            f"{ckpt_path} not existed. Please download the checkpoint from mlcommon")
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_path, padding_side="left", trust_remote_code=True)
@@ -51,7 +52,8 @@ def run_infer(df, ckpt_path, bs):
     tokenizer.pad_token = tokenizer.eos_token
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
-    # gen parameter. We stop at 1024. Starting from v5.0, min_token is set to 2 to avoid 0-output issue
+    # gen parameter. We stop at 1024. Starting from v5.0, min_token is set to
+    # 2 to avoid 0-output issue
     gen_kwargs = {
         # "min_new_tokens": 1,
         "min_new_tokens": 2,
@@ -80,9 +82,11 @@ def run_infer(df, ckpt_path, bs):
         eidx = min(sidx + BS, len(df))
 
         # We use batch_encode_plus for batch inference.
-        # Note 9/29/2024: Mixtral changed its tokenizer in Jun. Using the Feb 29 2024 version.
+        # Note 9/29/2024: Mixtral changed its tokenizer in Jun. Using the Feb
+        # 29 2024 version.
         batch_texts = df['input'][sidx:eidx].tolist()
-        batch_ids = tokenizer.batch_encode_plus(batch_texts, return_tensors="pt", padding=True)
+        batch_ids = tokenizer.batch_encode_plus(
+            batch_texts, return_tensors="pt", padding=True)
         # tok_input_length = batch_ids['attention_mask'].sum(
         #     axis=1).to(torch.int32).tolist()
         # input_tokens_lens += tok_input_length
@@ -97,7 +101,7 @@ def run_infer(df, ckpt_path, bs):
         batch_ids = batch_ids.to(device)
         _, length = batch_ids.input_ids.shape
         outputs = model.generate(**batch_ids, num_return_sequences=1,
-                                **gen_kwargs)
+                                 **gen_kwargs)
 
         output_ids = outputs[:, length:].cpu().tolist()
         output_tokens += output_ids
@@ -126,6 +130,7 @@ def run_infer(df, ckpt_path, bs):
 
     return output_df
 
+
 def trim_twos(df):
     # Remove all trailing 2s except for 1
     def remove_trailing_twos(lst):
@@ -137,21 +142,25 @@ def remove_trailing_twos(lst):
                 break
         return lst[:-count] if count > 0 else lst
 
-    df['infer_tok_ref_output'] = df['infer_tok_ref_output'].apply(remove_trailing_twos)
+    df['infer_tok_ref_output'] = df['infer_tok_ref_output'].apply(
+        remove_trailing_twos)
     df['trim_lengths'] = df['infer_tok_ref_output'].apply(len)
     df['tok_ref_output'] = df['tok_ref_output'].apply(remove_trailing_twos)
     df['tok_ref_output_len'] = df['tok_ref_output'].apply(len)
     return df
 
+
 def mbxp_stop(df):
     stop_tokens = [13, 13940, 28832, 13]
+
     def modify_list(lst):
         for i in range(len(lst) - len(stop_tokens) + 1):
-            if lst[i:i+len(stop_tokens)] == stop_tokens:
-                return lst[:i+len(stop_tokens)]
+            if lst[i:i + len(stop_tokens)] == stop_tokens:
+                return lst[:i + len(stop_tokens)]
         return lst
 
-    df.loc[df['dataset'] == 'MBXP', 'infer_tok_ref_output'] = df[df['dataset'] == 'MBXP']['infer_tok_ref_output'].apply(modify_list)
+    df.loc[df['dataset'] == 'MBXP', 'infer_tok_ref_output'] = df[df['dataset']
+                                                                  == 'MBXP']['infer_tok_ref_output'].apply(modify_list)
     df['trim_lengths'] = df['infer_tok_ref_output'].apply(len)
     return df
 
@@ -190,5 +199,3 @@ def fix_name(df):
     df = fix_name(df)
 
     df.to_pickle(args.output_pkl)
-
-
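A note on the helpers reformatted above (illustrative, not part of the commit): trim_twos strips the trailing EOS ids (token id 2 for Mixtral) from the reference outputs, and mbxp_stop truncates MBXP outputs at the stop sequence [13, 13940, 28832, 13]. Below is a minimal standalone sketch of that trimming logic, written to follow the "remove all trailing 2s except for 1" comment; the repository implementation may differ in detail.

def remove_trailing_eos(lst, eos_id=2):
    # Count trailing EOS ids but keep a single one, per the comment in trim_twos.
    count = 0
    for tok in reversed(lst):
        if tok == eos_id:
            count += 1
        else:
            break
    count = max(count - 1, 0)
    return lst[:-count] if count > 0 else lst


def truncate_at_stop(lst, stop_tokens=(13, 13940, 28832, 13)):
    # Cut right after the first occurrence of the MBXP stop pattern,
    # as modify_list() in mbxp_stop() does.
    stop = list(stop_tokens)
    for i in range(len(lst) - len(stop) + 1):
        if lst[i:i + len(stop)] == stop:
            return lst[:i + len(stop)]
    return lst


print(remove_trailing_eos([5, 9, 2, 2, 2]))            # [5, 9, 2]
print(truncate_at_stop([7, 13, 13940, 28832, 13, 4]))  # [7, 13, 13940, 28832, 13]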
57 changes: 40 additions & 17 deletions language/mixtral-8x7b/standalone_infer/run_accuracy.py
@@ -59,7 +59,8 @@ def calculate_rouge_score(model_outputs, ref_outputs):
     m_result = metric.compute(
         predictions=m_preds, references=m_targets, use_stemmer=True, use_aggregator=False
     )
-    m_rouge_result = {k: round(np.mean(v) * 100, 4) for k, v in m_result.items()}
+    m_rouge_result = {k: round(np.mean(v) * 100, 4)
+                      for k, v in m_result.items()}
 
     return m_rouge_result
 
@@ -101,30 +102,35 @@ def maybe_remove_comma(x: str) -> str:
 def try_float(x: str):
     try:
         ret = float(x)
-    except:
+    except BaseException:
         ret = None
     return ret
 
+
 def postprocess_golang(code: str) -> str:
-    multi_line_imports = re.compile(r"^import \(\n(.+)((?:\n.+)+)\n\)", re.MULTILINE)
+    multi_line_imports = re.compile(
+        r"^import \(\n(.+)((?:\n.+)+)\n\)", re.MULTILINE)
     line_imports = re.compile(r"^import \".*\"")
     func_main = re.compile(r"^func main.*^}", re.MULTILINE | re.DOTALL)
 
-    code = code.replace("package main", "") # Remove package main
+    code = code.replace("package main", "")  # Remove package main
     code = multi_line_imports.sub("", code)
     code = line_imports.sub("", code)
     code = func_main.sub("", code)
 
     return code
 
+
 def postprocess_scala(code: str) -> str:
     code = code.replace("object Main extends App {", "")
     code = "".join(code.splitlines(True)[:-1])
     return code
 
+
 def postprocess_python(code: str) -> str:
     return code.lstrip()
 
+
 def worker(inp_queue, out_queue):
     while True:
         try:
@@ -143,7 +149,7 @@ def worker(inp_queue, out_queue):
         try:
             solution = solution[:solution.index("```")]
         except ValueError:
-            #Happens when a code block isn't closed properly
+            # Happens when a code block isn't closed properly
            pass
 
        if problem["lang"] == "go":
@@ -153,15 +159,22 @@ def worker(inp_queue, out_queue):
         elif problem["lang"] == "scala":
             solution = postprocess_scala(solution)
 
-        # Mixtral likes escaping underscores for some reason, so let's remove these
-        solution = solution.replace("\_", "_")
+        # Mixtral likes escaping underscores for some reason, so let's remove
+        # these
+        solution = solution.replace("\\_", "_")
 
         # The evaluation script evaluates `code = prompt + solution + tests`
-        # But Mixtral regenerates the prompt in its output, so we should remove this
+        # But Mixtral regenerates the prompt in its output, so we should remove
+        # this
         problem["prompt"] = ""
 
         result = checker(problem, solution, timeout=20.0)
-        out_queue.put((key, problem["lang"], result["passed"], result["result"], problem["response"]))
+        out_queue.put(
+            (key,
+             problem["lang"],
+             result["passed"],
+             result["result"],
+             problem["response"]))
 
 
 def convert_pickle(df: pd.DataFrame, result_keys: dict):
@@ -193,7 +206,8 @@ def evaluate_mbxp(n_works: int, df: pd.DataFrame, result_keys: dict):
     n_problems = 0
 
     for lang, problems in by_lang.items():
-        if lang not in ["cpp", "python", "php", "javascript", "ruby", "typescript"]:
+        if lang not in ["cpp", "python", "php",
+                        "javascript", "ruby", "typescript"]:
             raise RuntimeError(f"{lang} not in supported list.")
 
         n_problems += len(problems)
@@ -213,7 +227,10 @@ def evaluate_mbxp(n_works: int, df: pd.DataFrame, result_keys: dict):
     lang_counts = {}
     for i in tqdm(range(n_problems)):
         key, lang, passed, result, response = out_queue.get()
-        passes[key] = {"passed": passed, "result": result, "response": response}
+        passes[key] = {
+            "passed": passed,
+            "result": result,
+            "response": response}
         n_passed += passed
 
         lang_passed.setdefault(lang, 0)
@@ -244,7 +261,8 @@ def evaluate_openorca(df: pd.DataFrame, result_keys: dict):
     score = calculate_rouge_score(gen_output, gt_output)
     gen_token_len = df[result_keys['length']].tolist()
     gen_token_per_sample = sum(gen_token_len) / len(gen_token_len)
-    print(f"OpenOrca score: {score}, gen_token_per_sample: {gen_token_per_sample}")
+    print(
+        f"OpenOrca score: {score}, gen_token_per_sample: {gen_token_per_sample}")
     return score
 
 
@@ -266,13 +284,18 @@ def evaluate_gsm8k(df: pd.DataFrame, result_keys: dict):
     em = correct / total
     gen_token_len = df[result_keys['length']].tolist()
     gen_token_per_sample = sum(gen_token_len) / len(gen_token_len)
-    print(f"EM: {em}, correct: {correct} / {total}, gen_token_per_sample: {gen_token_per_sample}")
+    print(
+        f"EM: {em}, correct: {correct} / {total}, gen_token_per_sample: {gen_token_per_sample}")
     return em
 
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--n_workers", type=int, default=10, help="The number of processes to use")
+    parser.add_argument(
+        "--n_workers",
+        type=int,
+        default=10,
+        help="The number of processes to use")
     parser.add_argument("--results_path", type=str, default="mixtral_8x7b_15000_greedy_reference_fp16_mintoken2.pkl",
                         help="The path to the results file pickle file")
     parser.add_argument("--result_key", type=str, default="ref_output",
@@ -307,9 +330,9 @@ def evaluate_gsm8k(df: pd.DataFrame, result_keys: dict):
     """
 
     df = pd.read_pickle(args.results_path)
-    df_gsm8k = df[df['dataset']=="GSM8K"].copy()
+    df_gsm8k = df[df['dataset'] == "GSM8K"].copy()
     evaluate_gsm8k(df_gsm8k, result_keys)
-    df_openorca = df[df['dataset']=="OpenOrca"].copy()
+    df_openorca = df[df['dataset'] == "OpenOrca"].copy()
     evaluate_openorca(df_openorca, result_keys)
-    df_mbxp = df[df['dataset']=="MBXP"].copy()
+    df_mbxp = df[df['dataset'] == "MBXP"].copy()
     evaluate_mbxp(args.n_workers, df_mbxp, result_keys)
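For reference, the GSM8K exact-match path exercised by evaluate_gsm8k above parses each prediction into a number before comparing it with the ground truth. The sketch below is illustrative only: maybe_remove_comma's body is not shown in this diff, so the comma-stripping behaviour here is an assumption, while try_float matches the reformatted code above.

def maybe_remove_comma(x: str) -> str:
    # Assumed behaviour: drop thousands separators such as "1,234" -> "1234".
    return x.replace(",", "")


def try_float(x: str):
    try:
        ret = float(x)
    except BaseException:
        ret = None
    return ret


# Toy exact-match computation in the spirit of evaluate_gsm8k().
preds, refs = ["1,234", "7.0", "banana"], [1234.0, 7.0, 3.0]
correct = sum(try_float(maybe_remove_comma(p)) == r for p, r in zip(preds, refs))
print(f"EM: {correct / len(refs)}, correct: {correct} / {len(refs)}")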
3 changes: 2 additions & 1 deletion loadgen/loadgen.cc
@@ -323,7 +323,8 @@ std::vector<QueryMetadata> GenerateQueries(
     size_t pad_size =
         (loaded_samples.size() - samples_per_query % loaded_samples.size());
     samples_per_query += pad_size;
-  } else if ((scenario != TestScenario::Offline) && (min_queries % loaded_samples.size() != 0)) {
+  } else if ((scenario != TestScenario::Offline) &&
+             (min_queries % loaded_samples.size() != 0)) {
     // In Server, SingleStream, MultiStream mode, the min_queries should be
     // padded
     size_t pad_size =
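The wrapped condition above guards loadgen's query padding: when the query count is not a multiple of the loaded sample set size, it is padded up to the next multiple. A small Python illustration of that arithmetic (not loadgen code):

def pad_to_multiple(count, set_size):
    # Mirrors pad_size = set_size - count % set_size, applied only when the
    # count is not already a multiple (as the guarded branches above do).
    remainder = count % set_size
    return count if remainder == 0 else count + (set_size - remainder)


print(pad_to_multiple(96, 24))   # already a multiple -> 96
print(pad_to_multiple(100, 24))  # padded up -> 120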
4 changes: 2 additions & 2 deletions loadgen/test_settings_internal.cc
@@ -757,8 +757,8 @@ int TestSettings::FromConfig(const std::string &path, const std::string &model,
                &performance_issue_same_index, nullptr);
 
   if (lookupkv(model, scenario, "sample_concatenate_permutation", &val,
-                nullptr))
-      sample_concatenate_permutation = (val == 1) ? true : false;
+               nullptr))
+    sample_concatenate_permutation = (val == 1) ? true : false;
   if (lookupkv(model, "Server", "coalesce_queries", &val, nullptr))
     server_coalesce_queries = (val == 0) ? false : true;
   if (lookupkv(model, "Server", "max_async_queries", &val, nullptr))
