From e82a493a9ed9d06a932489a8e91081e38b4dbd7a Mon Sep 17 00:00:00 2001 From: yanyc428 Date: Tue, 9 Jan 2024 20:03:54 +0800 Subject: [PATCH] [Fix] Fix error in gsm8k evaluator Co-authored-by: jiangjin1999 <1261842974@qq.com> --- opencompass/datasets/gsm8k.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/opencompass/datasets/gsm8k.py b/opencompass/datasets/gsm8k.py index 88e3fa2f4..539780edb 100644 --- a/opencompass/datasets/gsm8k.py +++ b/opencompass/datasets/gsm8k.py @@ -56,6 +56,14 @@ def gsm8k_postprocess(text: str) -> str: class Gsm8kEvaluator(BaseEvaluator): + def is_equal(self, pred, refer): + try: + if pred == refer or abs(float(pred) - int(refer)) < 1e-6: + return True + except Exception: + pass + return False + def score(self, predictions, references): if len(predictions) != len(references): return { @@ -68,7 +76,7 @@ def score(self, predictions, references): for i, j in zip(predictions, references): detail = {'pred': i, 'answer': j, 'correct': False} count += 1 - if i == j: + if self.is_equal(i, j): correct += 1 detail['correct'] = True details.append(detail)