-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Local doc cache support for document reconstruct #1066
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
from sycamore.data import Document, Element | ||
from sycamore.connectors.base_reader import BaseDBReader | ||
from sycamore.data.document import DocumentPropertyTypes, DocumentSource | ||
from sycamore.utils.cache import Cache | ||
from sycamore.utils.import_utils import requires_modules | ||
from dataclasses import dataclass, field | ||
import typing | ||
|
@@ -23,6 +24,7 @@ class OpenSearchReaderQueryParams(BaseDBReader.QueryParams): | |
query: Dict = field(default_factory=lambda: {"query": {"match_all": {}}}) | ||
kwargs: Dict = field(default_factory=lambda: {}) | ||
reconstruct_document: bool = False | ||
document_cache: Cache = None # For caching reconstructed documents | ||
|
||
|
||
class OpenSearchReaderClient(BaseDBReader.Client): | ||
|
@@ -47,6 +49,9 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea | |
if "size" not in query_params.query and "size" not in query_params.kwargs: | ||
query_params.kwargs["size"] = 200 | ||
result = [] | ||
# We only fetch the minimum required fields for full document retrieval/reconstruction | ||
if query_params.reconstruct_document: | ||
query_params.kwargs["_source_includes"] = "doc_id,parent_id,properties" | ||
# No pagination needed for knn queries | ||
if "query" in query_params.query and "knn" in query_params.query["query"]: | ||
response = self._client.search( | ||
|
@@ -93,6 +98,7 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse): | |
def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: | ||
assert isinstance(query_params, OpenSearchReaderQueryParams) | ||
result: list[Document] = [] | ||
index_name = query_params.index_name | ||
if not query_params.reconstruct_document: | ||
for data in self.output: | ||
doc = Document( | ||
|
@@ -117,6 +123,7 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: | |
# Get unique documents | ||
unique_docs: dict[str, Document] = {} | ||
query_result_elements_per_doc: dict[str, set[str]] = {} | ||
cached_docs = {} | ||
for data in self.output: | ||
doc = Document( | ||
{ | ||
|
@@ -126,9 +133,21 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: | |
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY | ||
assert doc.doc_id, "Retrieved invalid doc with missing doc_id" | ||
if not doc.parent_id: | ||
if query_params.document_cache is not None: | ||
cached_doc = query_params.document_cache.get(f"{index_name}:{doc.doc_id}") | ||
if cached_doc: | ||
doc = cached_doc | ||
cached_docs[doc.doc_id] = doc | ||
# Always use retrieved doc as the unique parent doc - override any empty parent doc created below | ||
unique_docs[doc.doc_id] = doc | ||
else: | ||
if query_params.document_cache is not None: | ||
cached_doc = query_params.document_cache.get(f"{index_name}:{doc.parent_id}") | ||
if cached_doc: | ||
doc = cached_doc | ||
cached_docs[doc.parent_id] = doc | ||
unique_docs[doc.parent_id] = doc | ||
continue | ||
# Create empty parent documents if no parent document was in result set | ||
unique_docs[doc.parent_id] = unique_docs.get( | ||
doc.parent_id, | ||
|
@@ -147,7 +166,11 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: | |
query_result_elements_per_doc[doc.parent_id] = elements | ||
|
||
# Batched retrieval of all elements belong to unique docs | ||
doc_ids = list(unique_docs.keys()) | ||
# doc_ids = list(unique_docs.keys()) | ||
doc_ids = [d for d in list(unique_docs.keys()) if d not in list(cached_docs.keys())] | ||
|
||
# We can't safely exclude embeddings since we might need them for 'rerank', e.g. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to do embedding style reranking locally? Won't the knn result just be ordered? Or do we expect to reorder based on something else after There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No. When I wrote the comment, I thought we use embeddings for reranking. In general, though, I don't know if we can safely drop them in the context of document reconstruct. Maybe document reconstruct should be a set of options not just a boolean. |
||
# We will need the Planner to determine that and pass that info to the reader. | ||
all_elements_for_docs = self._get_all_elements_for_doc_ids(doc_ids, query_params.index_name) | ||
|
||
""" | ||
|
@@ -172,7 +195,10 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: | |
|
||
# sort elements per doc | ||
for doc in result: | ||
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf")) | ||
if doc.doc_id not in cached_docs: | ||
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf")) | ||
if query_params.document_cache is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can cached_docs have anything if document_cache is None? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It'll be an empty dict. |
||
query_params.document_cache.set(f"{index_name}:{doc.doc_id}", doc) | ||
|
||
return result | ||
|
||
|
@@ -183,7 +209,6 @@ def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[ | |
""" | ||
batch_size = 100 | ||
page_size = 500 | ||
|
||
all_elements = [] | ||
for i in range(0, len(doc_ids), batch_size): | ||
doc_ids_batch = doc_ids[i : i + batch_size] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from typing import Optional, Union, Callable, Dict, Any | ||
from typing import Optional, Union, Callable, Dict | ||
from pathlib import Path | ||
|
||
from pandas import DataFrame | ||
|
@@ -11,6 +11,7 @@ | |
from sycamore.data import Document | ||
from sycamore.connectors.file import ArrowScan, BinaryScan, DocScan, PandasScan, JsonScan, JsonDocumentScan | ||
from sycamore.connectors.file.file_scan import FileMetadataProvider | ||
from sycamore.utils.cache import cache_from_path | ||
from sycamore.utils.import_utils import requires_modules | ||
|
||
|
||
|
@@ -224,7 +225,7 @@ def opensearch( | |
index_name: str, | ||
query: Optional[Dict] = None, | ||
reconstruct_document: bool = False, | ||
query_kwargs: dict[str, Any] = {}, | ||
query_kwargs=None, | ||
**kwargs, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you change to an optional? Just for consistency There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok. |
||
) -> DocSet: | ||
""" | ||
|
@@ -287,14 +288,26 @@ def opensearch( | |
OpenSearchReaderQueryParams, | ||
) | ||
|
||
if query_kwargs is None: | ||
query_kwargs = {} | ||
|
||
document_cache_dir = kwargs.get("document_cache_dir", None) | ||
document_cache = cache_from_path(document_cache_dir) if document_cache_dir else None | ||
client_params = OpenSearchReaderClientParams(os_client_args=os_client_args) | ||
query_params = ( | ||
OpenSearchReaderQueryParams( | ||
index_name=index_name, query=query, reconstruct_document=reconstruct_document, kwargs=query_kwargs | ||
index_name=index_name, | ||
query=query, | ||
reconstruct_document=reconstruct_document, | ||
document_cache=document_cache, | ||
kwargs=query_kwargs, | ||
) | ||
if query is not None | ||
else OpenSearchReaderQueryParams( | ||
index_name=index_name, reconstruct_document=reconstruct_document, kwargs=query_kwargs | ||
index_name=index_name, | ||
reconstruct_document=reconstruct_document, | ||
document_cache=document_cache, | ||
kwargs=query_kwargs, | ||
) | ||
) | ||
osr = OpenSearchReader(client_params=client_params, query_params=query_params) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we probably want to move this get_key logic to a shared method?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.