Skip to content

Commit

Permalink
Merge pull request #1202 from vjsliogeris/fix/fix_copy_metrics
Browse files Browse the repository at this point in the history
Fix/fix copy metrics
  • Loading branch information
penguine-ip authored Dec 3, 2024
2 parents 839153b + f5b8286 commit cd02f9e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
11 changes: 9 additions & 2 deletions deepeval/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,15 @@ def copy_metrics(
metric_class = type(metric)
args = vars(metric)

signature = inspect.signature(metric_class.__init__)
valid_params = signature.parameters.keys()
superclasses = metric_class.__mro__

valid_params = []

for superclass in superclasses:
signature = inspect.signature(superclass.__init__)
superclass_params = signature.parameters.keys()
valid_params.extend(superclass_params)
valid_params = set(valid_params)
valid_args = {key: args[key] for key in valid_params if key in args}

copied_metrics.append(metric_class(**valid_args))
Expand Down
36 changes: 36 additions & 0 deletions tests/test_copy_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from deepeval.metrics import GEval
from deepeval.metrics.utils import copy_metrics
from deepeval.test_case import LLMTestCaseParams
from deepeval.models.gpt_model import GPTModel


class DummyMetric(GEval):
def __init__(self, **kwargs):
kwargs["evaluation_params"] = [
LLMTestCaseParams.ACTUAL_OUTPUT,
LLMTestCaseParams.EXPECTED_OUTPUT,
LLMTestCaseParams.INPUT,
]
kwargs["criteria"] = "All answers are good"
if "name" not in kwargs.keys():
kwargs["name"] = "default_config_name"
super().__init__(**kwargs)


def test_copy_metrics():
# Different than the default, 'gpt-4o'
metric_before = DummyMetric(
model="gpt-4o-mini",
)
metric_after = copy_metrics([metric_before])
vars_before = vars(metric_before)
vars_after = vars(metric_after[0])
for key_before, value_before in vars_before.items():
value_after = vars_after[key_before]
if isinstance(value_before, GPTModel):
assert value_before.model_name == value_after.model_name
else:
assert value_before == value_after
assert key_before in vars_after.keys()

0 comments on commit cd02f9e

Please sign in to comment.