From 9dfa8de2888b96fb60f078566be7430d72aff6be Mon Sep 17 00:00:00 2001 From: Austin Lee Date: Thu, 9 Jan 2025 18:30:37 -0800 Subject: [PATCH] Fix broken unit tests --- .../connectors/opensearch/opensearch_reader.py | 7 ++++++- .../connectors/opensearch/test_opensearch_read.py | 13 ++++--------- .../unit/connectors/opensearch/test_opensearch.py | 12 +++++++++++- lib/sycamore/sycamore/tests/unit/query/conftest.py | 8 +++++++- 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py index 45ccf56c3..03549f0ce 100644 --- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py +++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py @@ -50,6 +50,7 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea assert isinstance( query_params, OpenSearchReaderQueryParams ), f"Wrong kind of query parameters found: {query_params}" + assert "index" not in query_params.kwargs and "body" not in query_params.kwargs logger.debug(f"OpenSearch query on {query_params.index_name}: {query_params.query}") if "size" not in query_params.query and "size" not in query_params.kwargs: @@ -184,6 +185,9 @@ def get_batches(doc_ids) -> list[list[str]]: else: batch_doc_count += doc_count cur_batch.append(doc_ids[i]) + + if len(cur_batch) > 0: + batches.append(cur_batch) return batches batches = get_batches(doc_ids) @@ -265,6 +269,7 @@ def __init__( self, client_params: OpenSearchReaderClientParams, query_params: BaseDBReader.QueryParams, + use_pit: bool = True, **kwargs, ): assert isinstance( @@ -275,7 +280,7 @@ def __init__( self._client_params = client_params self._query_params = query_params # TODO add support for 'search_after' pagination if a sort field is provided. - self.use_pit = query_params.kwargs.get("use_pit", False) + self.use_pit = use_pit logger.info(f"OpenSearchReader using PIT: {self.use_pit}") @timetrace("OpenSearchReader") diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py index a45407743..63de318ec 100644 --- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py +++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py @@ -322,7 +322,7 @@ def test_parallel_read_reconstruct(self, setup_index_large, os_client): os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=setup_index_large, reconstruct_document=True, - # parallelism=1, + use_pit=False, ).take_all() t1 = time.time() print(f"Retrieved {len(retrieved_docs_reconstructed)} documents in {t1 - t0} seconds") @@ -341,7 +341,7 @@ def test_parallel_read(self, setup_index_large, os_client): os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=setup_index_large, reconstruct_document=False, - parallelism=1, + use_pit=False, ).take_all() t1 = time.time() @@ -360,7 +360,6 @@ def test_parallel_read_reconstruct_with_pit(self, setup_index_large, os_client): os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=setup_index_large, reconstruct_document=True, - query_kwargs={"use_pit": True}, ).take_all() t1 = time.time() print(f"Retrieved {len(retrieved_docs_reconstructed)} documents in {t1 - t0} seconds") @@ -379,9 +378,7 @@ def test_parallel_read_with_pit(self, setup_index_large, os_client): os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=setup_index_large, reconstruct_document=False, - query_kwargs={"use_pit": True}, - num_cpus=2, - concurrency=5, + concurrency=2, ).take_all() t1 = time.time() @@ -404,9 +401,7 @@ def test_parallel_query_with_pit(self, setup_index_large, os_client): index_name=setup_index_large, query=query, reconstruct_document=False, - query_kwargs={"use_pit": True}, - num_cpus=2, - concurrency=5, + concurrency=2, ).take_all() t1 = time.time() diff --git a/lib/sycamore/sycamore/tests/unit/connectors/opensearch/test_opensearch.py b/lib/sycamore/sycamore/tests/unit/connectors/opensearch/test_opensearch.py index 23c1a6451..4faa98526 100644 --- a/lib/sycamore/sycamore/tests/unit/connectors/opensearch/test_opensearch.py +++ b/lib/sycamore/sycamore/tests/unit/connectors/opensearch/test_opensearch.py @@ -246,6 +246,9 @@ def test_to_docs_reconstruct(self, mocker): } }, ] + + total = len(return_val["hits"]["hits"]) + return_val["hits"]["total"] = {"value": total} client.search.return_value = return_val query_response = OpenSearchReaderQueryResponse(hits, client=client) query_params = OpenSearchReaderQueryParams( @@ -318,7 +321,10 @@ def test_to_docs_reconstruct_no_additional_elements(self, mocker): # no elements match hits = [{"_source": record} for record in records] - client.search.return_value = {"hits": {"hits": [hit for hit in hits if hit["_source"].get("parent_id")]}} + return_val = {"hits": {"hits": [hit for hit in hits if hit["_source"].get("parent_id")]}} + total = len(return_val["hits"]["hits"]) + return_val["hits"]["total"] = {"value": total} + client.search.return_value = return_val query_response = OpenSearchReaderQueryResponse(hits, client=client) query_params = OpenSearchReaderQueryParams( index_name="some index", query=MATCH_ALL_QUERY, reconstruct_document=True @@ -409,6 +415,10 @@ def test_from_document_too_many_fields(self): assert record._index == tp.index_name +class TestOpenSearchReader: + def test_opensearchreader_knn_query(self): + pass + class TestOpenSearchUtils: def test_get_knn_query(self): diff --git a/lib/sycamore/sycamore/tests/unit/query/conftest.py b/lib/sycamore/sycamore/tests/unit/query/conftest.py index b00cd37b1..afe657d96 100644 --- a/lib/sycamore/sycamore/tests/unit/query/conftest.py +++ b/lib/sycamore/sycamore/tests/unit/query/conftest.py @@ -1,5 +1,8 @@ import pytest -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from ray.data import Dataset from sycamore import DocSet, Context from sycamore.connectors.opensearch.opensearch_reader import ( @@ -21,6 +24,9 @@ class MockOpenSearchReader(OpenSearchReader): def read_docs(self) -> List[Document]: return get_mock_docs() + def execute(self, **kwargs) -> "Dataset": + from ray.data import from_items + return from_items(items=[{"doc": doc.serialize()} for doc in self.read_docs()]) class MockDocSetReader(DocSetReader): """Mock out DocSetReader for tests."""