Local doc cache support for document reconstruct #1066

Open
wants to merge 3 commits into base: main
31 changes: 28 additions & 3 deletions lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
@@ -3,6 +3,7 @@
from sycamore.data import Document, Element
from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data.document import DocumentPropertyTypes, DocumentSource
from sycamore.utils.cache import Cache
from sycamore.utils.import_utils import requires_modules
from dataclasses import dataclass, field
import typing
@@ -23,6 +24,7 @@ class OpenSearchReaderQueryParams(BaseDBReader.QueryParams):
query: Dict = field(default_factory=lambda: {"query": {"match_all": {}}})
kwargs: Dict = field(default_factory=lambda: {})
reconstruct_document: bool = False
document_cache: Cache = None # For caching reconstructed documents


class OpenSearchReaderClient(BaseDBReader.Client):
@@ -47,6 +49,9 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea
if "size" not in query_params.query and "size" not in query_params.kwargs:
query_params.kwargs["size"] = 200
result = []
# We only fetch the minimum required fields for full document retrieval/reconstruction
if query_params.reconstruct_document:
query_params.kwargs["_source_includes"] = "doc_id,parent_id,properties"
# No pagination needed for knn queries
if "query" in query_params.query and "knn" in query_params.query["query"]:
response = self._client.search(
@@ -93,6 +98,7 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse):
def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
assert isinstance(query_params, OpenSearchReaderQueryParams)
result: list[Document] = []
index_name = query_params.index_name
if not query_params.reconstruct_document:
for data in self.output:
doc = Document(
@@ -117,6 +123,7 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
# Get unique documents
unique_docs: dict[str, Document] = {}
query_result_elements_per_doc: dict[str, set[str]] = {}
cached_docs = {}
for data in self.output:
doc = Document(
{
@@ -126,9 +133,21 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
assert doc.doc_id, "Retrieved invalid doc with missing doc_id"
if not doc.parent_id:
if query_params.document_cache is not None:
cached_doc = query_params.document_cache.get(f"{index_name}:{doc.doc_id}")
Contributor: we probably want to move this get_key logic to a shared method?
Contributor (Author): Yes.

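A minimal sketch of the shared key helper being discussed; the _doc_cache_key name and its placement are assumptions, not part of this diff:

# Hypothetical shared helper for building document cache keys.
def _doc_cache_key(index_name: str, doc_id: str) -> str:
    """Cache key for a reconstructed document, scoped by index name."""
    return f"{index_name}:{doc_id}"
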
if cached_doc:
doc = cached_doc
cached_docs[doc.doc_id] = doc
# Always use retrieved doc as the unique parent doc - override any empty parent doc created below
unique_docs[doc.doc_id] = doc
else:
if query_params.document_cache is not None:
cached_doc = query_params.document_cache.get(f"{index_name}:{doc.parent_id}")
if cached_doc:
doc = cached_doc
cached_docs[doc.parent_id] = doc
unique_docs[doc.parent_id] = doc
continue
# Create empty parent documents if no parent document was in result set
unique_docs[doc.parent_id] = unique_docs.get(
doc.parent_id,
@@ -147,7 +166,11 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
query_result_elements_per_doc[doc.parent_id] = elements

# Batched retrieval of all elements belonging to unique docs
doc_ids = list(unique_docs.keys())
# doc_ids = list(unique_docs.keys())
doc_ids = [d for d in list(unique_docs.keys()) if d not in list(cached_docs.keys())]

# We can't safely exclude embeddings since we might need them for 'rerank', e.g.
Contributor: Do we need to do embedding style reranking locally? Won't the knn result just be ordered? Or do we expect to reorder based on something else after?
Contributor (Author): No. When I wrote the comment, I thought we use embeddings for reranking. In general, though, I don't know if we can safely drop them in the context of document reconstruct. Maybe document reconstruct should be a set of options, not just a boolean.

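One possible shape for that set of options, as a sketch only; the ReconstructOptions name and fields are assumptions, not part of this PR:

from dataclasses import dataclass

# Hypothetical replacement for the reconstruct_document boolean flag.
@dataclass
class ReconstructOptions:
    enabled: bool = False
    include_embeddings: bool = True  # keep embeddings in case a later stage reranks on them
    use_document_cache: bool = False
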
# We will need the Planner to determine that and pass that info to the reader.
all_elements_for_docs = self._get_all_elements_for_doc_ids(doc_ids, query_params.index_name)

"""
@@ -172,7 +195,10 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:

# sort elements per doc
for doc in result:
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))
if doc.doc_id not in cached_docs:
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))
if query_params.document_cache is not None:
Contributor: can cached_docs have anything if document_cache is None?
Contributor (Author): It'll be an empty dict.

query_params.document_cache.set(f"{index_name}:{doc.doc_id}", doc)

return result

@@ -183,7 +209,6 @@ def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[
"""
batch_size = 100
page_size = 500

all_elements = []
for i in range(0, len(doc_ids), batch_size):
doc_ids_batch = doc_ids[i : i + batch_size]
21 changes: 17 additions & 4 deletions lib/sycamore/sycamore/reader.py
@@ -1,4 +1,4 @@
from typing import Optional, Union, Callable, Dict, Any
from typing import Optional, Union, Callable, Dict
from pathlib import Path

from pandas import DataFrame
@@ -11,6 +11,7 @@
from sycamore.data import Document
from sycamore.connectors.file import ArrowScan, BinaryScan, DocScan, PandasScan, JsonScan, JsonDocumentScan
from sycamore.connectors.file.file_scan import FileMetadataProvider
from sycamore.utils.cache import cache_from_path
from sycamore.utils.import_utils import requires_modules


@@ -224,7 +225,7 @@ def opensearch(
index_name: str,
query: Optional[Dict] = None,
reconstruct_document: bool = False,
query_kwargs: dict[str, Any] = {},
query_kwargs=None,
**kwargs,
Contributor: can you change to an optional? Just for consistency.
Contributor (Author): Ok.

) -> DocSet:
"""
@@ -287,14 +288,26 @@ def opensearch(
OpenSearchReaderQueryParams,
)

if query_kwargs is None:
query_kwargs = {}

document_cache_dir = kwargs.get("document_cache_dir", None)
document_cache = cache_from_path(document_cache_dir) if document_cache_dir else None
client_params = OpenSearchReaderClientParams(os_client_args=os_client_args)
query_params = (
OpenSearchReaderQueryParams(
index_name=index_name, query=query, reconstruct_document=reconstruct_document, kwargs=query_kwargs
index_name=index_name,
query=query,
reconstruct_document=reconstruct_document,
document_cache=document_cache,
kwargs=query_kwargs,
)
if query is not None
else OpenSearchReaderQueryParams(
index_name=index_name, reconstruct_document=reconstruct_document, kwargs=query_kwargs
index_name=index_name,
reconstruct_document=reconstruct_document,
document_cache=document_cache,
kwargs=query_kwargs,
)
)
osr = OpenSearchReader(client_params=client_params, query_params=query_params)
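For reference, a minimal usage sketch of the new cache option, based on the integration test below; the client args and index name are placeholders:

import os
import tempfile

import sycamore

# Placeholder connection settings; supply real OpenSearch client args here.
os_client_args = {"hosts": [{"host": "localhost", "port": 9200}]}

context = sycamore.init()
docs = context.read.opensearch(
    os_client_args=os_client_args,
    index_name="my-index",  # placeholder index name
    reconstruct_document=True,
    # Reconstructed documents are cached on disk at this path and reused on later reads.
    document_cache_dir=os.path.join(tempfile.gettempdir(), "opensearch_reconstruct_documents"),
).take_all()
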
@@ -1,4 +1,5 @@
import os
import tempfile
import time

import pytest
@@ -110,3 +111,87 @@ def test_ingest_and_read(self, setup_index, exec_mode):

for i in range(len(doc.elements) - 1):
assert doc.elements[i].element_index < doc.elements[i + 1].element_index

def test_ingest_and_read_from_cache(self, setup_index, exec_mode):
"""
Validates that data is readable from OpenSearch and that reconstructed Sycamore documents can be served from the local document cache.
"""

path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
context = sycamore.init(exec_mode=exec_mode)
original_docs = (
context.read.binary(path, binary_format="pdf")
.partition(partitioner=UnstructuredPdfPartitioner())
.explode()
.write.opensearch(
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=TestOpenSearchRead.INDEX,
index_settings=TestOpenSearchRead.INDEX_SETTINGS,
execute=False,
)
.take_all()
)

retrieved_docs = context.read.opensearch(
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX
)
target_doc_id = original_docs[-1].doc_id if original_docs[-1].doc_id else ""
query = {"query": {"term": {"_id": target_doc_id}}}
query_docs = context.read.opensearch(
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX, query=query
)

kwargs = {"document_cache_dir": os.path.join(tempfile.gettempdir(), "opensearch_reconstruct_documents")}
retrieved_docs_reconstructed = context.read.opensearch(
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=TestOpenSearchRead.INDEX,
reconstruct_document=True,
**kwargs
)
original_materialized = sorted(original_docs, key=lambda d: d.doc_id)

# hack to allow opensearch time to index the data. Without this it's possible we try to query the index before
# all the records are available
time.sleep(1)
retrieved_materialized = sorted(retrieved_docs.take_all(), key=lambda d: d.doc_id)
query_materialized = query_docs.take_all()
retrieved_materialized_reconstructed = sorted(retrieved_docs_reconstructed.take_all(), key=lambda d: d.doc_id)
parent_id = retrieved_materialized_reconstructed[0].doc_id

with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client:
os_client.indices.delete(TestOpenSearchRead.INDEX)
os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS)
# We only need a parent doc in the index to test the cache
os_client.index(index=TestOpenSearchRead.INDEX, id=parent_id, body={"doc_id": parent_id})
os_client.indices.refresh(TestOpenSearchRead.INDEX)

retrieved_docs_reconstructed_from_cache = context.read.opensearch(
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
index_name=TestOpenSearchRead.INDEX,
reconstruct_document=True,
**kwargs
)
retrieved_materialized_reconstructed_from_cache = sorted(
retrieved_docs_reconstructed_from_cache.take_all(), key=lambda d: d.doc_id
)

assert len(query_materialized) == 1 # exactly one doc should be returned
compare_connector_docs(original_materialized, retrieved_materialized)

assert len(retrieved_materialized_reconstructed) == 1
doc = retrieved_materialized_reconstructed[0]
assert len(doc.elements) == len(retrieved_materialized) - 1 # drop the document parent record

for i in range(len(doc.elements) - 1):
assert doc.elements[i].element_index < doc.elements[i + 1].element_index

assert len(retrieved_materialized_reconstructed_from_cache) == 1
doc = retrieved_materialized_reconstructed_from_cache[0]
for i in range(len(doc.elements) - 1):
assert doc.elements[i].element_index < doc.elements[i + 1].element_index

# Clean slate between Execution Modes
with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client:
os_client.indices.delete(TestOpenSearchRead.INDEX)
os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS)
os_client.indices.refresh(TestOpenSearchRead.INDEX)