Skip to content

Commit

Permalink
Fix sort so that it works with an unspecified or None default_value. (#…
Browse files Browse the repository at this point in the history
…1040)

* Split tests into unit via local mode (0.04s runtime) and integration via ray (21.25s)
* Fix similarity test to run in local mode to work around GPU requirement failing in ray mode
* Fix duckdb to create temporary file in already gitignored tmp directory
  • Loading branch information
eric-anderson authored Nov 27, 2024
1 parent 796e3a9 commit 9761ec7
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 62 deletions.
13 changes: 11 additions & 2 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,9 +1190,18 @@ def sort(self, descending: bool, field: str, default_val: Optional[Any] = None)
field: Document field in relation to Document using dotted notation, e.g. properties.filetype
default_val: Default value to use if field does not exist in Document
"""
from sycamore.transforms import Sort
from sycamore.transforms.sort import Sort, DropIfMissingField

return DocSet(self.context, Sort(self.plan, descending, field, default_val))
plan = self.plan
if default_val is None:
import logging

logging.warning(
"Default value is none. Adding explicit filter step to drop documents missing the key."
" This includes any metadata.documents."
)
plan = DropIfMissingField(plan, field)
return DocSet(self.context, Sort(plan, descending, field, default_val))

def groupby_count(self, field: str, unique_field: Optional[str] = None, **kwargs) -> "DocSet":
"""
Expand Down
3 changes: 2 additions & 1 deletion lib/sycamore/sycamore/query/operators/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ class Sort(Node):
field: str
"""The name of the database field to sort based on."""

default_value: Any
default_value: Any = None
"""The default value used when sorting if a document is missing the specified field."""
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from sycamore.context import ExecMode
import sycamore.tests.unit.transforms.test_sort as unit


class TestSort(unit.TestSort):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.exec_mode = ExecMode.RAY
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pytest
from unittest import mock
from sycamore.data import Document
Expand All @@ -21,11 +22,12 @@ def mock_duckdb_connection():


def test_duckdb_reader_client_from_client_params(mock_duckdb_connection):
params = DuckDBReaderClientParams(db_url="test_db")
os.makedirs("tmp", exist_ok=True)
params = DuckDBReaderClientParams(db_url="tmp/test_db")
with mock.patch("duckdb.connect", return_value=mock_duckdb_connection) as mock_connect:
client = DuckDBReaderClient.from_client_params(params)
assert isinstance(client, DuckDBReaderClient)
mock_connect.assert_called_once_with(database="test_db", read_only=True)
mock_connect.assert_called_once_with(database="tmp/test_db", read_only=True)


def test_duckdb_reader_client_read_records(mock_duckdb_connection):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pytest
from unittest import mock
from sycamore.connectors.duckdb.duckdb_writer import (
Expand Down Expand Up @@ -26,7 +27,8 @@ def target_params():

@pytest.fixture
def client_params():
return DuckDBWriterClientParams(db_url="test.db")
os.makedirs("tmp", exist_ok=True)
return DuckDBWriterClientParams(db_url="tmp/test.db")


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,20 +331,25 @@ def test_llm_extract_entity():


def test_sort():
context = sycamore.init()
doc_set = Mock(spec=DocSet)
return_doc_set = Mock(spec=DocSet)
doc_set.sort.return_value = return_doc_set
logical_node = Sort(node_id=0, descending=True, field="properties.counter", default_value=0)
sycamore_operator = SycamoreSort(context, logical_node, query_id="test", inputs=[doc_set])
result = sycamore_operator.execute()
# Verify that we can create a sort node without specifying a default value and pydantic won't
# throw an assertion.
Sort(node_id=0, descending=True, field="no-default-value")

doc_set.sort.assert_called_once_with(
descending=logical_node.descending,
field=logical_node.field,
default_val=logical_node.default_value,
)
assert result == return_doc_set
for default_value in [None, 0]:
context = sycamore.init()
doc_set = Mock(spec=DocSet)
return_doc_set = Mock(spec=DocSet)
doc_set.sort.return_value = return_doc_set
logical_node = Sort(node_id=0, descending=True, field="properties.counter", default_value=default_value)
sycamore_operator = SycamoreSort(context, logical_node, query_id="test", inputs=[doc_set])
result = sycamore_operator.execute()

doc_set.sort.assert_called_once_with(
descending=logical_node.descending,
field=logical_node.field,
default_val=logical_node.default_value,
)
assert result == return_doc_set


def test_top_k():
Expand Down
10 changes: 3 additions & 7 deletions lib/sycamore/sycamore/tests/unit/transforms/test_similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def test_transformers_similarity_scorer_no_doc_structure(self):
assert [doc.doc_id for doc in result] == [2, 1, 3, 5, 4]

def test_transformers_similarity_scorer_no_element_id(self):

similarity_scorer = HuggingFaceTransformersSimilarityScorer(RERANKER_MODEL)
score_property_name = "similarity_score"
query = "this is a cat"
Expand Down Expand Up @@ -113,7 +112,6 @@ def test_transformers_similarity_scorer_no_element_id(self):


class TestSimilarityTransform:

def test_transformers_score_similarity(self, mocker):
node = mocker.Mock(spec=Node)
similarity_scorer = HuggingFaceTransformersSimilarityScorer(RERANKER_MODEL, ignore_doc_structure=True)
Expand All @@ -122,15 +120,13 @@ def test_transformers_score_similarity(self, mocker):
{"doc_id": 1, "text_representation": "Members of a strike at Yale University.", "embedding": None},
{"doc_id": 2, "text_representation": "A woman is speaking at a podium outdoors.", "embedding": None},
]
docs = [Document(d) for d in dicts]
input_dataset = ray.data.from_items([{"doc": Document(doc_dict).serialize()} for doc_dict in dicts])
execute = mocker.patch.object(node, "execute")
execute.return_value = input_dataset
input_dataset.show()
output_dataset = score_similarity.execute()
taken = output_dataset.take_all()
taken = score_similarity.local_execute(docs)

for d in taken:
doc = Document.from_row(d)
for doc in taken:
if isinstance(doc, MetadataDocument):
continue
assert float(doc.properties.get("_similarity_score"))
83 changes: 52 additions & 31 deletions lib/sycamore/sycamore/tests/unit/transforms/test_sort.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,62 @@
import string
import pytest
import random
import unittest

import sycamore
from sycamore import DocSet, ExecMode
from sycamore.data import Document, MetadataDocument


class TestSort:
@pytest.fixture()
class TestSort(unittest.TestCase):
NUM_DOCS = 10

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.exec_mode = ExecMode.LOCAL

def docs(self) -> list[Document]:
doc_list = [
# text_representation is random 6 letter strings
Document(
text_representation=f"{''.join(random.choices(string.ascii_letters, k=6))}",
properties={"document_number": random.randint(1, 10000), "even": i if i % 2 == 0 else None},
)
for i in range(10)
for i in range(self.NUM_DOCS)
]

for doc in doc_list:
if doc.properties["even"] is None:
doc.properties.pop("even")
return doc_list

@pytest.fixture(params=(exec_mode for exec_mode in ExecMode if exec_mode != ExecMode.UNKNOWN))
def docset(self, docs: list[Document], exec_mode) -> DocSet:
context = sycamore.init(exec_mode=exec_mode)
return context.read.document(docs)
def docset(self) -> DocSet:
context = sycamore.init(exec_mode=self.exec_mode)
return context.read.document(self.docs())

def test_sort_descending(self, docset: DocSet):
sorted_docset = docset.sort(True, "text_representation")
def test_sort_descending(self):
sorted_docset = self.docset().sort(True, "text_representation")
doc_list = sorted_docset.take_all()

for i in range(1, len(doc_list)):
assert str(doc_list[i].text_representation) <= str(doc_list[i - 1].text_representation)

def test_sort_ascending(self, docset: DocSet):
sorted_docset = docset.sort(False, "properties.document_number")
def test_sort_ascending(self):
sorted_docset = self.docset().sort(False, "properties.document_number")
doc_list = sorted_docset.take_all()

for i in range(1, len(doc_list)):
assert doc_list[i].properties["document_number"] >= doc_list[i - 1].properties["document_number"]

def test_default_value(self, docset: DocSet):
with pytest.raises(Exception):
sorted_docset = docset.sort(False, "properties.even")
doc_list = sorted_docset.take_all()
def test_default_value(self):
sorted_docset = self.docset().sort(False, "properties.even")
doc_list = sorted_docset.take_all()

for i in range(1, len(doc_list)):
assert doc_list[i].properties.get("even", 0) >= doc_list[i - 1].properties.get("even", 0)

assert len(doc_list) == self.NUM_DOCS / 2

sorted_docset = docset.sort(False, "properties.even", 0)
sorted_docset = self.docset().sort(False, "properties.even", 0)
doc_list = sorted_docset.take_all()

for i in range(1, len(doc_list)):
Expand All @@ -61,25 +69,38 @@ def test_metadata_document(self):
MetadataDocument(),
Document(text_representation="C"),
MetadataDocument(),
Document(text_representation=None),
]

context = sycamore.init()
context = sycamore.init(exec_mode=self.exec_mode)
docset = context.read.document(doc_list)

with pytest.raises(Exception):
sorted_docset = docset.sort(False, "text_representation")
sorted_doc_list = sorted_docset.take_all(include_metadata=True)
sorted_docset = docset.sort(False, "text_representation")
sorted_doc_list = sorted_docset.take_all(include_metadata=True)
assert len(sorted_doc_list) == 3
assert sorted_doc_list[0].text_representation == "B"
assert sorted_doc_list[1].text_representation == "C"
assert sorted_doc_list[2].text_representation == "Z"

# must include default value for docsets with MetadataDocuments
sorted_docset = docset.sort(False, "text_representation", "A")
sorted_doc_list = sorted_docset.take_all(include_metadata=True)

for i in range(len(sorted_doc_list)):
if i == 0 or i == 1:
assert isinstance(sorted_doc_list[i], MetadataDocument)
elif i == 2:
assert sorted_doc_list[i].text_representation == "B"
elif i == 3:
assert sorted_doc_list[i].text_representation == "C"
elif i == 4:
assert sorted_doc_list[i].text_representation == "Z"
# Two MetadataDocuments and the Document with text_representation is None from doc_list
# All get the same key, so they can be in any order
for i in range(3):
d = sorted_doc_list[i]
assert isinstance(d, MetadataDocument) or d.text_representation is None

assert sorted_doc_list[3].text_representation == "B"
assert sorted_doc_list[4].text_representation == "C"
assert sorted_doc_list[5].text_representation == "Z"

sorted_docset = docset.sort(True, "text_representation", "A")
sorted_doc_list = sorted_docset.take_all(include_metadata=True)

assert sorted_doc_list[0].text_representation == "Z"
assert sorted_doc_list[1].text_representation == "C"
assert sorted_doc_list[2].text_representation == "B"
for i in range(3):
d = sorted_doc_list[i + 3]
assert isinstance(d, MetadataDocument) or d.text_representation is None
32 changes: 27 additions & 5 deletions lib/sycamore/sycamore/transforms/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,31 @@
if TYPE_CHECKING:
from ray.data import Dataset

from sycamore.plan_nodes import UnaryNode


class DropIfMissingField(UnaryNode):
"""Drop all documents that are missing the specified field, including metadata. This makes ray work because
ray requires a key value that is comparable between all entries which means None can't be used and there
is no easy way to auto-infer the correct type."""

def __init__(self, child, field):
super().__init__(child)
self._field = field

def execute(self, **kwargs) -> "Dataset":
input_dataset = self.child().execute()
result = input_dataset.map_batches(self.ray_drop_documents)
return result

def local_execute(self, all_docs: list[Document]) -> list[Document]:
return [d for d in all_docs if d.field_to_value(self._field) is not None]

def ray_drop_documents(self, ray_input):
all_docs = [Document.deserialize(s) for s in ray_input.get("doc", [])]
out_docs = self.local_execute(all_docs)
return {"doc": [d.serialize() for d in out_docs]}


class Sort(Transform):
"""
Expand Down Expand Up @@ -52,11 +77,8 @@ def ray_callable(input_dict: dict[str, Any]) -> dict[str, Any]:
val = doc.field_to_value(self._field)

if val is None:
if self._default_val is None:
exception_string = f'Field "{self._field}" not present in Document and default value not provided.'
raise Exception(exception_string)
else:
val = self._default_val
assert self._default_val is not None
val = self._default_val

# updates row to include new col
new_doc = doc.to_row()
Expand Down

0 comments on commit 9761ec7

Please sign in to comment.