Skip to content

Commit

Permalink
Fix broken unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
austintlee committed Jan 10, 2025
1 parent 16e10cc commit 9dfa8de
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea
assert isinstance(
query_params, OpenSearchReaderQueryParams
), f"Wrong kind of query parameters found: {query_params}"

assert "index" not in query_params.kwargs and "body" not in query_params.kwargs
logger.debug(f"OpenSearch query on {query_params.index_name}: {query_params.query}")
if "size" not in query_params.query and "size" not in query_params.kwargs:
Expand Down Expand Up @@ -184,6 +185,9 @@ def get_batches(doc_ids) -> list[list[str]]:
else:
batch_doc_count += doc_count
cur_batch.append(doc_ids[i])

if len(cur_batch) > 0:
batches.append(cur_batch)
return batches

batches = get_batches(doc_ids)
Expand Down Expand Up @@ -265,6 +269,7 @@ def __init__(
self,
client_params: OpenSearchReaderClientParams,
query_params: BaseDBReader.QueryParams,
use_pit: bool = True,
**kwargs,
):
assert isinstance(
Expand All @@ -275,7 +280,7 @@ def __init__(
self._client_params = client_params
self._query_params = query_params
# TODO add support for 'search_after' pagination if a sort field is provided.
self.use_pit = query_params.kwargs.get("use_pit", False)
self.use_pit = use_pit
logger.info(f"OpenSearchReader using PIT: {self.use_pit}")

@timetrace("OpenSearchReader")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def test_parallel_read_reconstruct(self, setup_index_large, os_client):
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=setup_index_large,
reconstruct_document=True,
# parallelism=1,
use_pit=False,
).take_all()
t1 = time.time()
print(f"Retrieved {len(retrieved_docs_reconstructed)} documents in {t1 - t0} seconds")
Expand All @@ -341,7 +341,7 @@ def test_parallel_read(self, setup_index_large, os_client):
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=setup_index_large,
reconstruct_document=False,
parallelism=1,
use_pit=False,
).take_all()
t1 = time.time()

Expand All @@ -360,7 +360,6 @@ def test_parallel_read_reconstruct_with_pit(self, setup_index_large, os_client):
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=setup_index_large,
reconstruct_document=True,
query_kwargs={"use_pit": True},
).take_all()
t1 = time.time()
print(f"Retrieved {len(retrieved_docs_reconstructed)} documents in {t1 - t0} seconds")
Expand All @@ -379,9 +378,7 @@ def test_parallel_read_with_pit(self, setup_index_large, os_client):
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=setup_index_large,
reconstruct_document=False,
query_kwargs={"use_pit": True},
num_cpus=2,
concurrency=5,
concurrency=2,
).take_all()
t1 = time.time()

Expand All @@ -404,9 +401,7 @@ def test_parallel_query_with_pit(self, setup_index_large, os_client):
index_name=setup_index_large,
query=query,
reconstruct_document=False,
query_kwargs={"use_pit": True},
num_cpus=2,
concurrency=5,
concurrency=2,
).take_all()
t1 = time.time()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ def test_to_docs_reconstruct(self, mocker):
}
},
]

total = len(return_val["hits"]["hits"])
return_val["hits"]["total"] = {"value": total}
client.search.return_value = return_val
query_response = OpenSearchReaderQueryResponse(hits, client=client)
query_params = OpenSearchReaderQueryParams(
Expand Down Expand Up @@ -318,7 +321,10 @@ def test_to_docs_reconstruct_no_additional_elements(self, mocker):

# no elements match
hits = [{"_source": record} for record in records]
client.search.return_value = {"hits": {"hits": [hit for hit in hits if hit["_source"].get("parent_id")]}}
return_val = {"hits": {"hits": [hit for hit in hits if hit["_source"].get("parent_id")]}}
total = len(return_val["hits"]["hits"])
return_val["hits"]["total"] = {"value": total}
client.search.return_value = return_val
query_response = OpenSearchReaderQueryResponse(hits, client=client)
query_params = OpenSearchReaderQueryParams(
index_name="some index", query=MATCH_ALL_QUERY, reconstruct_document=True
Expand Down Expand Up @@ -409,6 +415,10 @@ def test_from_document_too_many_fields(self):
assert record._index == tp.index_name


class TestOpenSearchReader:
def test_opensearchreader_knn_query(self):
pass

class TestOpenSearchUtils:

def test_get_knn_query(self):
Expand Down
8 changes: 7 additions & 1 deletion lib/sycamore/sycamore/tests/unit/query/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest
from typing import Dict, List, Optional
from typing import Dict, List, Optional, TYPE_CHECKING

if TYPE_CHECKING:
from ray.data import Dataset

from sycamore import DocSet, Context
from sycamore.connectors.opensearch.opensearch_reader import (
Expand All @@ -21,6 +24,9 @@ class MockOpenSearchReader(OpenSearchReader):
def read_docs(self) -> List[Document]:
return get_mock_docs()

def execute(self, **kwargs) -> "Dataset":
from ray.data import from_items
return from_items(items=[{"doc": doc.serialize()} for doc in self.read_docs()])

class MockDocSetReader(DocSetReader):
"""Mock out DocSetReader for tests."""
Expand Down

0 comments on commit 9dfa8de

Please sign in to comment.