From e8fadb2239b0cd5ba532e0800dbb033e9d807260 Mon Sep 17 00:00:00 2001
From: moutozf <73951041+moutozf@users.noreply.github.com>
Date: Tue, 12 Mar 2024 20:58:05 +0800
Subject: [PATCH] Bird eval all zf (#240)

great work, thanks a lot ~
---
 dbgpt_hub/eval/evaluation_bird.py | 163 +++++++++++++++++++++++-------
 1 file changed, 125 insertions(+), 38 deletions(-)
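
Note on the new metric: compute_ves() implements the Valid Efficiency Score (VES) used by the
BIRD benchmark. execute_sql() now times both queries and stores
time_ratio = ground-truth time / predicted time only when the predicted result set matches the
ground truth (otherwise time_ratio stays 0), and compute_ves() averages sqrt(time_ratio) * 100
over all evaluated queries. A minimal standalone sketch of the same calculation, with invented
timings (every value and name below is illustrative only, not part of the patch):

    import math

    # Toy stand-ins for the per-query dicts produced by execute_model();
    # the time_ratio values are made up for illustration.
    exec_results = [
        {"sql_idx": 0, "res": 1, "time_ratio": 1.44},  # correct; gold SQL 1.44x slower than prediction
        {"sql_idx": 1, "res": 0, "time_ratio": 0},     # wrong result set, contributes nothing
        {"sql_idx": 2, "res": 1, "time_ratio": 0.25},  # correct; prediction 4x slower than gold SQL
    ]

    def compute_ves(exec_results):
        # Same logic as the patched compute_ves(): only entries with a non-zero
        # time_ratio add sqrt(time_ratio) * 100, and the sum is divided by the
        # total number of queries, correct or not.
        total_ratio = 0
        for result in exec_results:
            if result["time_ratio"] != 0:
                total_ratio += math.sqrt(result["time_ratio"]) * 100
        return total_ratio / len(exec_results)

    print(round(compute_ves(exec_results), 2))  # (1.2 + 0.5) * 100 / 3 = 56.67

Because the sum is divided by the number of all queries rather than only the correct ones, a
wrong prediction lowers VES just as it lowers execution accuracy.
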
content["difficulty"] == "challenging": challenging_results.append(exec_results[i]) - - simple_acc = sum([res["res"] for res in simple_results]) / len(simple_results) - moderate_acc = sum([res["res"] for res in moderate_results]) / len(moderate_results) - challenging_acc = sum([res["res"] for res in challenging_results]) / len( - challenging_results - ) - all_acc = sum(results) / num_queries + if metric in ["res", "match"]: + simple_acc = sum([res[metric] for res in simple_results]) / len(simple_results) + moderate_acc = sum([res[metric] for res in moderate_results]) / len( + moderate_results + ) + challenging_acc = sum([res[metric] for res in challenging_results]) / len( + challenging_results + ) + all_acc = sum(results) / num_queries + elif metric in ["time_ratio"]: + simple_acc = compute_ves(simple_results) + moderate_acc = compute_ves(moderate_results) + challenging_acc = compute_ves(challenging_results) + all_acc = compute_ves(exec_results) + else: + raise NotImplementedError(f"metric: {metric} is not supported") count_lists = [ len(simple_results), len(moderate_results), len(challenging_results), num_queries, ] - return ( - simple_acc * 100, - moderate_acc * 100, - challenging_acc * 100, - all_acc * 100, - count_lists, - ) + if metric in ["res", "match"]: + return ( + simple_acc * 100, + moderate_acc * 100, + challenging_acc * 100, + all_acc * 100, + count_lists, + ) + else: + return simple_acc, moderate_acc, challenging_acc, all_acc, count_lists -def print_data(score_lists, count_lists): +def print_data(score_lists, count_lists, metric="Exec ACCURACY"): levels = ["simple", "moderate", "challenging", "total"] print("{:20} {:20} {:20} {:20} {:20}".format("", *levels)) print("{:20} {:<20} {:<20} {:<20} {:<20}".format("count", *count_lists)) print( - "====================================== ACCURACY =====================================" + f"====================================== {metric} =====================================" ) print( "{:20} {:<20.2f} {:<20.2f} {:<20.2f} {:<20.2f}".format("accuracy", *score_lists) @@ -157,17 +196,33 @@ def print_data(score_lists, count_lists): if __name__ == "__main__": args_parser = argparse.ArgumentParser() args_parser.add_argument( - "--predicted_sql_path", type=str, required=True, default="" + "--predicted_sql_path", + type=str, + default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql", + ) + args_parser.add_argument( + "--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql" + ) + args_parser.add_argument("--data_mode", type=str, default="dev") + args_parser.add_argument( + "--db_root_path", + type=str, + default="../../dbgpt_hub/data/bird/dev/dev_databases/", ) - args_parser.add_argument("--ground_truth_path", type=str, required=True, default="") - args_parser.add_argument("--data_mode", type=str, required=True, default="dev") - args_parser.add_argument("--db_root_path", type=str, required=True, default="") args_parser.add_argument("--num_cpus", type=int, default=1) args_parser.add_argument("--meta_time_out", type=float, default=30.0) args_parser.add_argument("--mode_gt", type=str, default="gt") args_parser.add_argument("--mode_predict", type=str, default="gpt") args_parser.add_argument("--difficulty", type=str, default="simple") args_parser.add_argument("--diff_json_path", type=str, default="") + args_parser.add_argument( + "--etype", + dest="etype", + type=str, + default="match", + choices=("all", "exec", "match", "ves"), + ) + args = args_parser.parse_args() exec_result = [] @@ -186,20 +241,52 @@ def 
@@ -157,17 +196,33 @@ def print_data(score_lists, count_lists):
 if __name__ == "__main__":
     args_parser = argparse.ArgumentParser()
     args_parser.add_argument(
-        "--predicted_sql_path", type=str, required=True, default=""
+        "--predicted_sql_path",
+        type=str,
+        default="../../pred_sql/pred_sql_bird_qwen14b_1212.sql",
+    )
+    args_parser.add_argument(
+        "--ground_truth_path", type=str, default="../../dbgpt_hub/data/bird/dev/dev.sql"
+    )
+    args_parser.add_argument("--data_mode", type=str, default="dev")
+    args_parser.add_argument(
+        "--db_root_path",
+        type=str,
+        default="../../dbgpt_hub/data/bird/dev/dev_databases/",
     )
-    args_parser.add_argument("--ground_truth_path", type=str, required=True, default="")
-    args_parser.add_argument("--data_mode", type=str, required=True, default="dev")
-    args_parser.add_argument("--db_root_path", type=str, required=True, default="")
     args_parser.add_argument("--num_cpus", type=int, default=1)
     args_parser.add_argument("--meta_time_out", type=float, default=30.0)
     args_parser.add_argument("--mode_gt", type=str, default="gt")
     args_parser.add_argument("--mode_predict", type=str, default="gpt")
     args_parser.add_argument("--difficulty", type=str, default="simple")
     args_parser.add_argument("--diff_json_path", type=str, default="")
+    args_parser.add_argument(
+        "--etype",
+        dest="etype",
+        type=str,
+        default="match",
+        choices=("all", "exec", "match", "ves"),
+    )
+
     args = args_parser.parse_args()
     exec_result = []
 
@@ -186,20 +241,52 @@ def print_data(score_lists, count_lists):
         db_paths = db_paths_gt
     query_pairs = list(zip(pred_queries, gt_queries))
 
-    run_sqls_parallel(
-        query_pairs,
-        db_places=db_paths,
-        num_cpus=args.num_cpus,
-        meta_time_out=args.meta_time_out,
-    )
+    if args.etype in ["all", "exec", "ves"]:
+        run_sqls_parallel(
+            query_pairs,
+            db_places=db_paths,
+            num_cpus=args.num_cpus,
+            meta_time_out=args.meta_time_out,
+        )
+    else:
+        for i, sql_pair in enumerate(query_pairs):
+            predicted_sql, ground_truth = sql_pair
+            exec_result.append(
+                {"sql_idx": i, "match": int(predicted_sql == ground_truth)}
+            )
     exec_result = sort_results(exec_result)
     print("start calculate")
 
-    simple_acc, moderate_acc, challenging_acc, acc, count_lists = compute_acc_by_diff(
-        exec_result, args.diff_json_path
-    )
-    score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
-    print_data(score_lists, count_lists)
+    if args.etype in ["all", "exec"]:
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "res")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Exec Accuracy")
+    if args.etype in ["all", "match"]:
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "match")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Match Accuracy")
+    if args.etype in ["all", "ves"]:
+        (
+            simple_acc,
+            moderate_acc,
+            challenging_acc,
+            acc,
+            count_lists,
+        ) = compute_acc_by_diff(exec_result, args.diff_json_path, "time_ratio")
+        score_lists = [simple_acc, moderate_acc, challenging_acc, acc]
+        print_data(score_lists, count_lists, metric="Ves")
     print(
         "==========================================================================================="
     )
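
With the new --etype switch the script can report execution accuracy, exact-match accuracy and
VES from a single run ("exec" and "ves" execute the SQL against the dev databases, "match" only
compares the predicted and gold SQL strings). An illustrative way to invoke the updated script;
every path below is a placeholder and must point at your own prediction file and BIRD dev data:

    import subprocess

    # Illustrative invocation only -- all paths are placeholders for a local BIRD dev setup.
    subprocess.run(
        [
            "python", "dbgpt_hub/eval/evaluation_bird.py",
            "--predicted_sql_path", "pred_sql/pred_sql_bird.sql",
            "--ground_truth_path", "dbgpt_hub/data/bird/dev/dev.sql",
            "--db_root_path", "dbgpt_hub/data/bird/dev/dev_databases/",
            "--diff_json_path", "dbgpt_hub/data/bird/dev/dev.json",
            "--num_cpus", "8",
            "--meta_time_out", "30.0",
            "--etype", "all",  # prints the Exec Accuracy, Match Accuracy and Ves tables
        ],
        check=True,
    )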