Skip to content

Commit

Permalink
Merge pull request #1038 from parea-ai/PAI-1443-advanced-test-case-fi…
Browse files Browse the repository at this point in the history
…ltering

Pai 1443 advanced test case filtering
  • Loading branch information
jalexanderII authored Aug 6, 2024
2 parents 53b9352 + 3fce48a commit 8f4604a
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 3 deletions.
72 changes: 70 additions & 2 deletions parea/schemas/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,26 @@ class TestCase:
tags: List[str] = field(factory=list)


class TestCaseResult(list):
def __init__(self, test_cases: Union[TestCase, List[TestCase]]):
super().__init__([test_cases] if isinstance(test_cases, TestCase) else test_cases)

def __getitem__(self, key):
result = super().__getitem__(key)
if isinstance(key, slice):
return TestCaseResult(result)
return result

def get_all_test_case_inputs(self) -> List[Dict[str, str]]:
return [case.inputs for case in self]

def get_all_test_case_targets(self) -> List[str]:
return [case.target for case in self if case.target is not None]

def get_all_test_inputs_and_targets_dict(self) -> List[Dict[str, Any]]:
return [{"inputs": case.inputs, "target": case.target} for case in self]


@define
class TestCaseCollection:
id: int
Expand All @@ -262,6 +282,51 @@ class TestCaseCollection:
column_names: List[str] = field(factory=list)
test_cases: Dict[int, TestCase] = field(factory=dict)

@property
def testcases(self) -> TestCaseResult:
return TestCaseResult(list(self.test_cases.values()))

def __getitem__(self, key):
return self.testcases[key]

def filter_testcases(self, **kwargs) -> TestCaseResult:
def matches_criteria(case: TestCase) -> bool:
for key, value in kwargs.items():
if key == "inputs":
if isinstance(value, dict):
if not all(case.inputs.get(k) == v for k, v in value.items()):
return False
elif isinstance(value, list):
for input_filter in value:
input_key, condition_func = input_filter
if input_key not in case.inputs or not condition_func(case.inputs[input_key]):
return False
elif key == "tags":
if isinstance(value, dict):
match_type = value.get("match", "any")
tags_to_match = value.get("tags", [])
if match_type == "all":
if not all(tag in case.tags for tag in tags_to_match):
return False
else: # 'any'
if not any(tag in case.tags for tag in tags_to_match):
return False
elif isinstance(value, list):
if not any(tag in case.tags for tag in value):
return False
elif key == "target":
if callable(value):
if not value(case.target):
return False
elif case.target != value:
return False
elif not hasattr(case, key) or getattr(case, key) != value:
return False
return True

filtered_cases = [case for case in self.test_cases.values() if matches_criteria(case)]
return TestCaseResult(filtered_cases)

def get_all_test_case_inputs(self) -> Iterable[Dict[str, str]]:
return (test_case.inputs for test_case in self.test_cases.values())

Expand All @@ -286,11 +351,14 @@ def write_to_finetune_jsonl(self, file_path: str):
function_call = json.loads(target)
if isinstance(function_call, List):
function_call = function_call[0]
if not "arguments" in function_call:
if "arguments" not in function_call:
# tool use format, need to convert
function_call = function_call["function"]
function_call["arguments"] = json.dumps(function_call["arguments"])
assistant_response = {"role": "assistant", "function_call": function_call}
assistant_response = {
"role": "assistant",
"function_call": function_call,
}
except json.JSONDecodeError:
assistant_response = {"role": "assistant", "content": target}
messages.append(assistant_response)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.191"
version = "0.2.192"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down
242 changes: 242 additions & 0 deletions tests/test_test_case_collection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
import unittest
from datetime import datetime

from parea.schemas import TestCase, TestCaseCollection, TestCaseResult


class TestTestCaseCollection(unittest.TestCase):
def setUp(self):
self.collection = TestCaseCollection(
id=1,
name="Test Collection",
created_at="2023-05-24",
last_updated_at="2023-05-24",
test_cases={
1: TestCase(
id=1,
test_case_collection_id=0,
inputs={"messages": "Answer this question", "context": "Short context"},
target="Certainly!",
tags=["important", "easy"],
),
2: TestCase(
id=2,
test_case_collection_id=0,
inputs={"messages": "Solve this problem", "context": "Long context with more than 100 characters" * 3},
target="Sure, I can help!",
tags=["important", "hard"],
),
3: TestCase(
id=3,
test_case_collection_id=0,
inputs={"messages": "Explain this concept", "word_count": "75"},
target="Of course!",
tags=["medium"],
),
4: TestCase(
id=4,
test_case_collection_id=0,
inputs={"messages": "Analyze this data", "data_size": "1000", "timestamp": "2023-05-25T10:30:00"},
target="Here's the analysis:",
tags=["data", "analysis", "important"],
),
5: TestCase(
id=5,
test_case_collection_id=0,
inputs={"messages": "Summarize this article", "word_count": "500", "language": "English"},
target="Here's a summary:",
tags=["summary", "medium", "language"],
),
6: TestCase(
id=6,
test_case_collection_id=0,
inputs={"messages": "Translate this sentence", "source_language": "English", "target_language": "French"},
target="Voici la traduction:",
tags=["translation", "language", "easy"],
),
},
)

def test_testcases_property(self):
"""Test the 'testcases' property of TestCaseCollection works like a list."""
self.assertEqual(len(self.collection.testcases), 6)
self.assertIsInstance(self.collection.testcases[0], TestCase)

def test_getitem_testcases(self):
"""Test the indexing and slicing capabilities of TestCaseCollection testcases property."""
self.assertEqual(self.collection.testcases[0].id, 1) # Test single index access
self.assertEqual(len(self.collection.testcases[:2]), 2) # Test slicing
self.assertIsInstance(self.collection.testcases[:2], TestCaseResult) # Slicing should return TestCaseResult
self.assertIsInstance(self.collection.testcases[:2], list) # but still a list

def test_getitem(self):
"""Test the indexing and slicing capabilities of TestCaseCollection."""
self.assertEqual(self.collection[0].id, 1) # Test single index access
self.assertEqual(len(self.collection[:2]), 2) # Test slicing
self.assertIsInstance(self.collection[:2], TestCaseResult) # Slicing should return TestCaseResult
self.assertIsInstance(self.collection[:2], list) # but still a list

def test_filter_by_id(self):
"""Test filtering TestCases by their id."""
result = self.collection.filter_testcases(id=2)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 2)

def test_filter_by_target(self):
"""Test filtering TestCases by their target field."""
result = self.collection.filter_testcases(target="Certainly!")
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 1)

def test_filter_by_inputs_basic(self):
"""Test basic filtering of TestCases by their inputs."""
result = self.collection.filter_testcases(inputs={"messages": "Answer this question"})
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 1)

def test_filter_by_tags_any(self):
"""
Test filtering TestCases by tags using 'any' match (default behavior).
This should return TestCases that have at least one of the specified tags.
"""
result = self.collection.filter_testcases(tags=["important", "medium"])
self.assertEqual(len(result), 5)
result = self.collection.filter_testcases(tags=["medium"])
self.assertEqual(len(result), 2)

def test_filter_by_tags_all(self):
"""
Test filtering TestCases by tags using 'all' match.
This should return TestCases that have all the specified tags.
"""
result = self.collection.filter_testcases(tags={"match": "all", "tags": ["important", "hard"]})
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 2)

def test_filter_by_inputs_advanced(self):
"""
Test advanced filtering of TestCases by their inputs using custom functions.
This test demonstrates how to use lambda functions to create complex filtering conditions.
"""
result = self.collection.filter_testcases(
inputs=[
("messages", lambda x: "question" in x.lower()), # Check if 'question' is in the message
("context", lambda x: len(x) < 50), # Check if context is less than 50 characters
]
)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 1)

# Filter by word count range
result = self.collection.filter_testcases(
inputs=[
("word_count", lambda x: x.isdigit() and 50 < int(x) < 100),
# Check if word count is between 50 and 100
]
)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 3)

# Filter by timestamp
result = self.collection.filter_testcases(
inputs=[
("timestamp", lambda x: datetime.fromisoformat(x) > datetime(2023, 5, 25)), # Check if timestamp is after May 25, 2023
]
)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 4)

# Filter by multiple input fields
result = self.collection.filter_testcases(
inputs=[
("messages", lambda x: "translate" in x.lower()), # Check if 'translate' is in the message
("source_language", lambda x: x == "English"), # Check if source language is English
("target_language", lambda x: x == "French"), # Check if target language is French
]
)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 6)

def test_combined_filters(self):
"""
Test combining multiple filters.
This test demonstrates how to use input filtering and tag filtering together.
"""
result = self.collection.filter_testcases(inputs=[("messages", lambda x: x.startswith("Answer"))], tags=["important"])
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 1)

# Combine input filtering, tag filtering, and id filtering
result = self.collection.filter_testcases(id=4, inputs=[("data_size", lambda x: int(x) > 500)], tags=["data", "important"])
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 4)

# Combine multiple input filters with tag filtering
result = self.collection.filter_testcases(
inputs=[
("messages", lambda x: "summarize" in x.lower()),
("word_count", lambda x: int(x) > 400),
],
tags={"match": "all", "tags": ["medium", "language"]},
)
self.assertEqual(len(result), 1)
self.assertEqual(result[0].id, 5)

# Complex combination of filters
result = self.collection.filter_testcases(
inputs=[
("messages", lambda x: len(x.split()) > 2), # Messages with more than 2 words
("word_count", lambda x: int(x) > 50 if x.isdigit() else True), # Word count > 50 if present
],
tags=["important", "medium", "language"], # Match any of these tags
target=lambda x: len(x) < 20, # Target response less than 20 characters
)
self.assertEqual(len(result), 2)
self.assertIn(result[0].id, [3, 5])
self.assertIn(result[1].id, [3, 5])

def test_no_match(self):
"""Test the behavior when no TestCases match the filter criteria."""
result = self.collection.filter_testcases(id=999) # No TestCase has this id
self.assertEqual(len(result), 0)

def test_get_all_test_case_inputs(self):
"""Test the get_all_test_case_inputs method."""
result = self.collection.filter_testcases(inputs={"messages": "Answer this question"}).get_all_test_case_inputs()
self.assertEqual(len(result), 1)
self.assertIn("messages", result[0])
self.assertEqual(result[0]["messages"], "Answer this question")

def test_get_all_test_case_targets(self):
"""Test the get_all_test_case_targets method."""
result = self.collection.filter_testcases(
inputs=[
("messages", lambda x: "question" in x.lower()),
("context", lambda x: len(x) < 50),
]
).get_all_test_case_targets()
self.assertEqual(len(result), 1)
self.assertEqual(result[0], "Certainly!")

def test_get_all_test_inputs_and_targets_dict(self):
"""Test the get_all_test_inputs_and_targets_dict method."""
result = self.collection.filter_testcases(tags=["important"]).get_all_test_inputs_and_targets_dict()
self.assertEqual(len(result), 3)
for item in result:
self.assertIn("inputs", item)
self.assertIn("target", item)

def test_chaining_methods(self):
"""Test chaining multiple methods."""
result = self.collection.filter_testcases(tags=["important"])[1:3].get_all_test_case_targets()
self.assertEqual(len(result), 2)
self.assertIn("Sure, I can help!", result)
self.assertIn("Here's the analysis:", result)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8f4604a

Please sign in to comment.