From 47d4df4392a66768bf1ba0e74c1e5081308d2115 Mon Sep 17 00:00:00 2001
From: Austin Lee
Date: Mon, 18 Nov 2024 01:27:57 -0800
Subject: [PATCH 1/2] Allow OpenSearchReader to output to MaterializedDataset
 consisting of refs

---
 .../sycamore/connectors/base_reader.py        |  16 +-
 .../opensearch/opensearch_reader.py           | 204 +++++++++++++++++-
 lib/sycamore/sycamore/reader.py               |   2 +-
 .../opensearch/test_opensearch_read.py        |  61 +++++-
 4 files changed, 271 insertions(+), 12 deletions(-)

diff --git a/lib/sycamore/sycamore/connectors/base_reader.py b/lib/sycamore/sycamore/connectors/base_reader.py
index bd27c6534..d6293f3c8 100644
--- a/lib/sycamore/sycamore/connectors/base_reader.py
+++ b/lib/sycamore/sycamore/connectors/base_reader.py
@@ -3,6 +3,9 @@
 
 from abc import ABC, abstractmethod
 
+from ray.data.dataset import MaterializedDataset
+
 from sycamore.data.document import Document
 from sycamore.plan_nodes import Scan
 from sycamore.utils.time_trace import TimeTrace
@@ -34,7 +37,7 @@
     # Type param for the objects that are read from the db
     class QueryResponse(ABC):
         @abstractmethod
-        def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
+        def to_docs(self, query_params: "BaseDBReader.QueryParams", use_refs: bool = False) -> list[Document]:
             pass
 
     # Type param for the object used to establish the read target
@@ -58,6 +61,7 @@
         super().__init__(**kwargs)
         self._client_params = client_params
         self._query_params = query_params
+        self._use_refs = kwargs.get("use_refs", False)
 
     def read_docs(self) -> list[Document]:
         try:
@@ -66,7 +70,7 @@
             if not client.check_target_presence(self._query_params):
                 raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n")
             records = client.read_records(query_params=self._query_params)
-            docs = records.to_docs(query_params=self._query_params)
+            docs = records.to_docs(query_params=self._query_params, use_refs=self._use_refs)
         except Exception as e:
             raise ValueError(f"Error reading from target: {e}")
         finally:
@@ -80,10 +84,16 @@ def execute(self, **kwargs) -> "Dataset":
         from ray.data import from_items
 
         with TimeTrace("Reader"):
-            return from_items(items=[{"doc": doc.serialize()} for doc in self.read_docs()])
+            ds: MaterializedDataset = from_items(items=[{"doc": doc.serialize()} for doc in self.read_docs()])
+            if self._use_refs:
+                return self.read_datasource(ds)
+            return ds
 
     def local_source(self) -> list[Document]:
         return self.read_docs()
 
     def format(self):
         return "reader"
+
+    def read_datasource(self, ds: "Dataset") -> "Dataset":
+        # Connectors that support use_refs override this to wrap the dataset
+        # of serialized refs in a connector-specific MaterializedDataset.
+        raise NotImplementedError
diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
index 769d481c3..6f4480f90 100644
--- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
+++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
@@ -1,5 +1,9 @@
 import logging
 
+from ray.data import Dataset
+from ray.data.dataset import MaterializedDataset
+
 from sycamore.data import Document, Element
 from sycamore.connectors.base_reader import BaseDBReader
 from sycamore.data.document import DocumentPropertyTypes, DocumentSource
@@ -47,6 +51,8 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchReaderQueryResponse":
query_params.kwargs["scroll"] = "10m" if "size" not in query_params.query and "size" not in query_params.kwargs: query_params.kwargs["size"] = 200 + + query_params.kwargs['_source_includes'] = ['doc_id', 'parent_id', 'properties'] logging.debug(f"OpenSearch query on {query_params.index_name}: {query_params.query}") response = self._client.search(index=query_params.index_name, body=query_params.query, **query_params.kwargs) scroll_id = response["_scroll_id"] @@ -78,16 +84,22 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse): """ client: typing.Optional["OpenSearch"] = None - def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: + use_refs: bool = False + + def to_docs(self, query_params: "BaseDBReader.QueryParams", use_refs: bool = False) -> list[Document]: assert isinstance(query_params, OpenSearchReaderQueryParams) result: list[Document] = [] if not query_params.reconstruct_document: for data in self.output: - doc = Document( + doc = OpenSearchDocument( + query_params.index_name, + **data.get("_source", {}), + ) if use_refs else Document( { **data.get("_source", {}), } ) + doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY doc.properties["score"] = data["_score"] result.append(doc) @@ -136,7 +148,11 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: # Batched retrieval of all elements belong to unique docs doc_ids = list(unique_docs.keys()) - all_elements_for_docs = self._get_all_elements_for_doc_ids(doc_ids, query_params.index_name) + + if use_refs: + return [OpenSearchDocument(query_params.index_name, **doc.data) for doc in unique_docs.values()] + + all_elements_for_docs = self._get_all_elements_for_doc_ids(doc_ids, query_params.index_name, use_refs) """ Add elements to unique docs. 
            If they were not part of the original result, we set
             properties.DocumentPropertyTypes.SOURCE = DOCUMENT_RECONSTRUCTION_RETRIEVAL
             """
             for element in all_elements_for_docs:
                 doc = Document(
                     {
                         **element.get("_source", {}),
                     }
                 )
+
                 assert doc.parent_id, "Got non-element record from OpenSearch reconstruction query"
                 if doc.doc_id not in query_result_elements_per_doc.get(doc.parent_id, {}):
                     doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DOCUMENT_RECONSTRUCTION_RETRIEVAL
@@ -164,7 +181,7 @@
 
         return result
 
-    def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[typing.Any]:
+    def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str, use_refs: bool = False) -> list[typing.Any]:
         assert self.client, "_get_all_elements_for_doc_ids requires an OpenSearch client instance in this class"
         """
         Returns all records in OpenSearch belonging to a list of Document ids (element.parent_id)
         """
         batch_size = 100
         page_size = 500
 
+        query_params = {}
+        if use_refs:
+            query_params["_source_includes"] = ["doc_id", "parent_id"]
+
         all_elements = []
         for i in range(0, len(doc_ids), batch_size):
             doc_ids_batch = doc_ids[i : i + batch_size]
             from_offset = 0
             while True:
                 query = {
                     "query": {"terms": {"parent_id.keyword": doc_ids_batch}},
                     "size": page_size,
                     "from": from_offset,
                 }
-                response = self.client.search(index=index, body=query)
+                response = self.client.search(index=index, body=query, **query_params)
                 hits = response["hits"]["hits"]
                 all_elements.extend(hits)
                 if len(hits) < page_size:
                     break
                 from_offset += page_size
@@ -196,3 +217,176 @@ class OpenSearchReader(BaseDBReader):
     Record = OpenSearchReaderQueryResponse
     ClientParams = OpenSearchReaderClientParams
     QueryParams = OpenSearchReaderQueryParams
+
+    def read_datasource(self, ds) -> "Dataset":
+        return OpenSearchMaterializedDataset(ds, self._client_params, self._query_params)
+
+
+class OpenSearchDocument(Document):
+    def __init__(self, index_name: str, **data):
+        super().__init__(**data)
+        self.index_name = index_name
+
+    @classmethod
+    def from_document(cls, doc: Document) -> "OpenSearchDocument":
+        # Not implemented yet: a plain Document does not carry its index name.
+        raise NotImplementedError
+
+    def serialize(self) -> bytes:
+        """Serialize this document to bytes."""
+        from pickle import dumps
+
+        return dumps({"index_name": self.index_name, "data": self.data})
+
+    @staticmethod
+    def deserialize(raw: bytes) -> "OpenSearchDocument":
+        """Deserialize from bytes to an OpenSearchDocument."""
+        from pickle import loads
+
+        data = loads(raw)
+        return OpenSearchDocument(data["index_name"], **data["data"])
+
+
+class OpenSearchMaterializedDataset(MaterializedDataset):
+
+    def __init__(
+        self,
+        ds: MaterializedDataset,
+        os_client_params: BaseDBReader.ClientParams,
+        os_query_params: BaseDBReader.QueryParams,
+    ):
+        self.os_client_params = os_client_params
+        self.os_query_params = os_query_params
+        super().__init__(ds._plan, ds._logical_plan)
+
+    @staticmethod
+    def _get_all_elements_for_doc_ids(
+        os_client: "OpenSearch", doc_ids: list[str], index: str, exclude_embedding=True
+    ) -> list[typing.Any]:
+        """
+        Returns all records in OpenSearch belonging to a list of Document ids (element.parent_id)
+        """
+        batch_size = 100
+        page_size = 500
+
+        query_params = {}
+        if exclude_embedding:
+            query_params["_source_excludes"] = ["embedding"]
+
+        all_elements = []
+        for i in range(0, len(doc_ids), batch_size):
+            doc_ids_batch = doc_ids[i : i + batch_size]
+            from_offset = 0
+            while True:
+                query = {
+                    "query": {"terms": {"parent_id.keyword": doc_ids_batch}},
+                    "size": page_size,
+                    "from": from_offset,
+                }
+                response = os_client.search(index=index, body=query, **query_params)
+                hits = response["hits"]["hits"]
+                all_elements.extend(hits)
+                if len(hits) < page_size:
+                    break
+                from_offset += page_size
+
+        logging.debug(f"Got all elements: {len(all_elements)}")
+        return all_elements
+
+    def to_docs(
+        self, client, index_name, docs: list[OpenSearchDocument], reconstruct_document: bool = True
+    ) -> list[Document]:
+        """
+        Hydrate full Documents from OpenSearch for the given refs, optionally
+        reconstructing each parent document from its elements.
+        """
+        logging.debug(f"In to_docs, reconstruct_document: {reconstruct_document}")
+
+        result: list[Document] = []
+        if not reconstruct_document:
+            doc_ids = [doc.data["doc_id"] for doc in docs]
+            res = client.mget(index=index_name, body={"ids": doc_ids})
+            for data in res["docs"]:
+                doc = Document(
+                    {
+                        **data.get("_source", {}),
+                    }
+                )
+                doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
+                # mget responses carry no relevance score, so none is attached here.
+                result.append(doc)
+        else:
+            """
+            Document reconstruction:
+            1. Construct a map of all unique parent Documents (i.e. no parent_id field)
+               1.1 If we find doc_ids without parent documents, we create empty parent Documents
+            2. Perform a terms query to retrieve all (including non-matched) other records for that parent_id
+            3. Add elements to unique parent Documents
+            """
+            # Get unique documents
+            unique_docs: dict[str, Document] = {}
+            query_result_elements_per_doc: dict[str, set[str]] = {}
+            for data in docs:
+                doc = Document(
+                    {
+                        **data.data,
+                    }
+                )
+                doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
+                assert doc.doc_id, "Retrieved invalid doc with missing doc_id"
+                if not doc.parent_id:
+                    # Always use retrieved doc as the unique parent doc - override any empty parent doc created below
+                    unique_docs[doc.doc_id] = doc
+                else:
+                    # Create empty parent documents if no parent document was in result set
+                    unique_docs[doc.parent_id] = unique_docs.get(
+                        doc.parent_id,
+                        Document(
+                            {
+                                "doc_id": doc.parent_id,
+                                "properties": {
+                                    **doc.properties,
+                                    DocumentPropertyTypes.SOURCE: DocumentSource.DOCUMENT_RECONSTRUCTION_PARENT,
+                                },
+                            }
+                        ),
+                    )
+                    elements = query_result_elements_per_doc.get(doc.parent_id, set())
+                    elements.add(doc.doc_id)
+                    query_result_elements_per_doc[doc.parent_id] = elements
+
+            # Batched retrieval of all elements belonging to unique docs
+            doc_ids = list(unique_docs.keys())
+            all_elements_for_docs = self._get_all_elements_for_doc_ids(client, doc_ids, index_name)
+
+            """
+            Add elements to unique docs.
+            If they were not part of the original result, we set
+            properties.DocumentPropertyTypes.SOURCE = DOCUMENT_RECONSTRUCTION_RETRIEVAL
+            """
+            for element in all_elements_for_docs:
+                doc = Document(
+                    {
+                        **element.get("_source", {}),
+                    }
+                )
+
+                assert doc.parent_id, "Got non-element record from OpenSearch reconstruction query"
+                if doc.doc_id not in query_result_elements_per_doc.get(doc.parent_id, {}):
+                    doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DOCUMENT_RECONSTRUCTION_RETRIEVAL
+                else:
+                    doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
+                parent = unique_docs[doc.parent_id]
+                parent.elements.append(Element(doc.data))
+
+            result = list(unique_docs.values())
+
+            # sort elements per doc
+            for doc in result:
+                doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))
+
+        return result
+
+    def iter_rows(
+        self, *, prefetch_batches: int = 1, prefetch_blocks: int = 0
+    ) -> typing.Iterable[typing.Dict[str, typing.Any]]:
+        os_client = OpenSearchReaderClient.from_client_params(self.os_client_params)._client
+
+        os_docs = []
+        for row in super().iter_rows():
+            os_doc = OpenSearchDocument.deserialize(row["doc"])
+            index_name = os_doc.index_name
+            os_docs.append(os_doc)
+            docs = self.to_docs(os_client, index_name, os_docs, self.os_query_params.reconstruct_document)
+            for doc in docs:
+                print(f"Yielding doc: {doc}")
+                yield {"doc": doc.serialize()}
diff --git a/lib/sycamore/sycamore/reader.py b/lib/sycamore/sycamore/reader.py
index e46d4a090..545800a7c 100644
--- a/lib/sycamore/sycamore/reader.py
+++ b/lib/sycamore/sycamore/reader.py
@@ -297,7 +297,7 @@ def opensearch(
                 index_name=index_name, reconstruct_document=reconstruct_document, kwargs=query_kwargs
             )
         )
-        osr = OpenSearchReader(client_params=client_params, query_params=query_params)
+        osr = OpenSearchReader(client_params=client_params, query_params=query_params, **kwargs)
         return DocSet(self._context, osr)
 
     @requires_modules("duckdb", extra="duckdb")
diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
index 8549ecb1c..edb02a849 100644
--- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
+++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
@@ -1,13 +1,17 @@
+import os
 import time
 
 import pytest
 from opensearchpy import OpenSearch
 
 import sycamore
 from sycamore.tests.integration.connectors.common import compare_connector_docs
 from sycamore.tests.config import TEST_DIR
 from sycamore.transforms.partition import UnstructuredPdfPartitioner
 
+os_admin_password = os.getenv("OS_ADMIN_PASSWORD", "admin")
+
 
 @pytest.fixture(scope="class")
 def setup_index():
@@ -45,7 +49,7 @@ class TestOpenSearchRead:
     OS_CLIENT_ARGS = {
         "hosts": [{"host": "localhost", "port": 9200}],
         "http_compress": True,
-        "http_auth": ("admin", "admin"),
+        "http_auth": ("admin", os_admin_password),
         "use_ssl": True,
         "verify_certs": False,
         "ssl_assert_hostname": False,
@@ -73,19 +77,22 @@ def test_ingest_and_read(self, setup_index, exec_mode):
             .take_all()
         )
 
+        kwargs = {"use_refs": True}
+
         retrieved_docs = context.read.opensearch(
-            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX
+            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
+            index_name=TestOpenSearchRead.INDEX,
+            **kwargs,
         )
 
         target_doc_id = original_docs[-1].doc_id if original_docs[-1].doc_id else ""
         query = {"query": {"term": {"_id": target_doc_id}}}
         query_docs = context.read.opensearch(
-            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX, query=query
+            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX, query=query, **kwargs
         )
 
         retrieved_docs_reconstructed = context.read.opensearch(
             os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
             index_name=TestOpenSearchRead.INDEX,
             reconstruct_document=True,
+            **kwargs,
         )
 
         original_materialized = sorted(original_docs, key=lambda d: d.doc_id)
@@ -107,3 +114,51 @@
 
             for i in range(len(doc.elements) - 1):
                 assert doc.elements[i].element_index < doc.elements[i + 1].element_index
+
+    def test_ingest_and_count(self, setup_index, exec_mode):
+        """
+        Validates that document counts read from OpenSearch match the indexed
+        data, both for raw records and for reconstructed documents.
+        """
+
+        client = OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS)
+
+        path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
+        context = sycamore.init(exec_mode=exec_mode)
+        original_docs = (
+            context.read.binary(path, binary_format="pdf")
+            .partition(partitioner=UnstructuredPdfPartitioner())
+            .explode()
+            .write.opensearch(
+                os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
+                index_name=TestOpenSearchRead.INDEX,
+                index_settings=TestOpenSearchRead.INDEX_SETTINGS,
+                execute=False,
+            )
+            .take_all()
+        )
+
+        client.indices.refresh(index=TestOpenSearchRead.INDEX)
+
+        use_refs = False
+        kwargs = {"use_refs": use_refs}
+
+        query = {"query": {"match_all": {}}}
+        ds1 = context.read.opensearch(
+            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX, query=query, **kwargs
+        ).take_all()
+
+        print(f"ExecMode: {exec_mode}, count: {len(ds1)}")
+
+        ds2 = context.read.opensearch(
+            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
+            index_name=TestOpenSearchRead.INDEX,
+            query=query,
+            reconstruct_document=True,
+            **kwargs
+        ).take_all()
+
+        print(f"ExecMode: {exec_mode}, count2: {len(ds2)}")
+
+        assert len(ds2) == 1
+        assert len(ds1) == 580

From b68d2bb1f59dba8abdacd72af2ddd17e30f940a1 Mon Sep 17 00:00:00 2001
From: Austin Lee
Date: Mon, 18 Nov 2024 09:28:22 -0800
Subject: [PATCH 2/2] Fix a bug in iter_rows.
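
Previously, iter_rows() accumulated every deserialized ref document in a
list and re-ran to_docs() over the entire list for each row, so every
previously seen document was hydrated and yielded again on each later row.
Roughly (a simplified sketch of the old loop):

    os_docs = []
    for row in super().iter_rows():
        os_doc = OpenSearchDocument.deserialize(row["doc"])
        os_docs.append(os_doc)
        # re-hydrates and re-yields every previously seen document
        docs = self.to_docs(os_client, os_doc.index_name, os_docs,
                            self.os_query_params.reconstruct_document)
        for doc in docs:
            yield {"doc": doc.serialize()}

Each ref is now hydrated and yielded exactly once, and the count test
compares against the live index count instead of a hard-coded total.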
---
 .../opensearch/opensearch_reader.py           | 11 ++++---
 .../opensearch/test_opensearch_read.py        | 29 ++++++++++++++-----
 2 files changed, 28 insertions(+), 12 deletions(-)

diff --git a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
index 6f4480f90..96b9c2dad 100644
--- a/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
+++ b/lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
@@ -68,6 +68,9 @@
             response = self._client.scroll(scroll_id=scroll_id, scroll=query_params.kwargs["scroll"])
         finally:
             self._client.clear_scroll(scroll_id=scroll_id)
+
+        logging.debug(f"Got {len(result)} records")
+
         return OpenSearchReaderQueryResponse(result, self._client)
 
     def check_target_presence(self, query_params: BaseDBReader.QueryParams):
@@ -179,6 +182,8 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams", use_refs: bool = False) -> list[Document]:
         for doc in result:
             doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))
 
+        logging.debug(f"Returning {len(result)} documents")
+
         return result
 
     def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str, use_refs: bool = False) -> list[typing.Any]:
@@ -381,12 +386,10 @@
     ) -> typing.Iterable[typing.Dict[str, typing.Any]]:
         os_client = OpenSearchReaderClient.from_client_params(self.os_client_params)._client
 
-        os_docs = []
         for row in super().iter_rows():
            os_doc = OpenSearchDocument.deserialize(row["doc"])
             index_name = os_doc.index_name
-            os_docs.append(os_doc)
-            docs = self.to_docs(os_client, index_name, os_docs, self.os_query_params.reconstruct_document)
+            docs = self.to_docs(os_client, index_name, [os_doc], self.os_query_params.reconstruct_document)
             for doc in docs:
-                print(f"Yielding doc: {doc}")
                 yield {"doc": doc.serialize()}
diff --git a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
index edb02a849..2abed5bab 100644
--- a/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
+++ b/lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
@@ -77,6 +77,9 @@ def test_ingest_and_read(self, setup_index, exec_mode):
             .take_all()
         )
 
+        client = OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS)
+        client.indices.refresh(index=TestOpenSearchRead.INDEX)
+
         kwargs = {"use_refs": True}
 
         retrieved_docs = context.read.opensearch(
@@ -115,6 +118,10 @@
 
             for i in range(len(doc.elements) - 1):
                 assert doc.elements[i].element_index < doc.elements[i + 1].element_index
+
+        # Clean up
+        client.delete_by_query(index=TestOpenSearchRead.INDEX, body={"query": {"match_all": {}}})
+        client.indices.refresh(index=TestOpenSearchRead.INDEX)
 
     def test_ingest_and_count(self, setup_index, exec_mode):
         """
         Validates that document counts read from OpenSearch match the indexed
@@ -139,15 +146,18 @@ def test_ingest_and_count(self, setup_index, exec_mode):
 
         client.indices.refresh(index=TestOpenSearchRead.INDEX)
 
-        use_refs = False
+        use_refs = True
         kwargs = {"use_refs": use_refs}
 
         query = {"query": {"match_all": {}}}
+        expected = client.count(index=TestOpenSearchRead.INDEX, body=query)["count"]
+
         ds1 = context.read.opensearch(
             os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX, query=query, **kwargs
-        ).take_all()
+        ).count()
 
-        print(f"ExecMode: {exec_mode}, count: {len(ds1)}")
+        print(f"ExecMode: {exec_mode}, count: {ds1}")
 
         ds2 = context.read.opensearch(
             os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
             index_name=TestOpenSearchRead.INDEX,
             query=query,
             reconstruct_document=True,
             **kwargs
-        ).take_all()
+        ).count()
 
-        print(f"ExecMode: {exec_mode}, count2: {len(ds2)}")
-
-        assert len(ds2) == 1
-        assert len(ds1) == 580
+        print(f"ExecMode: {exec_mode}, count2: {ds2}")
+
+        assert ds2 == 1
+        assert ds1 == expected
+
+        # Clean up
+        client.delete_by_query(index=TestOpenSearchRead.INDEX, body={"query": {"match_all": {}}})
+        client.indices.refresh(index=TestOpenSearchRead.INDEX)
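
Usage, for context (a minimal sketch mirroring the integration tests above;
the client args and index name are placeholders):

    import sycamore

    context = sycamore.init()
    # With use_refs=True the reader first materializes lightweight refs
    # (doc_id / parent_id / properties) and wraps them in an
    # OpenSearchMaterializedDataset that re-fetches the full documents
    # from OpenSearch as rows are iterated.
    docs = context.read.opensearch(
        os_client_args=OS_CLIENT_ARGS,  # e.g. TestOpenSearchRead.OS_CLIENT_ARGS
        index_name="demo_index",
        reconstruct_document=True,
        use_refs=True,
    ).take_all()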