-
Notifications
You must be signed in to change notification settings - Fork 49
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow Doc reconstruct via function #1072
Changes from 8 commits
1b03c8b
4797add
4fea3d7
dec7b43
3d043e1
8627035
c0c3674
4026a52
7890fe0
3bf9cc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
from typing import Callable | ||
|
||
from sycamore.data import Document | ||
|
||
|
||
class DocumentReconstructor: | ||
def __init__(self, index_name: str, reconstruct_fn: Callable[[str, str], Document]): | ||
self.index_name = index_name | ||
self.reconstruct_fn = reconstruct_fn | ||
|
||
def get_required_source_fields(self) -> list[str]: | ||
return ["parent_id"] | ||
|
||
def get_doc_id(self, data: dict) -> str: | ||
return data["_source"]["parent_id"] or data["_id"] | ||
|
||
def reconstruct(self, data) -> Document: | ||
return self.reconstruct_fn(self.index_name, self.get_doc_id(data)) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,13 @@ | ||
import logging | ||
|
||
from sycamore.connectors.doc_reconstruct import DocumentReconstructor | ||
from sycamore.data import Document, Element | ||
from sycamore.connectors.base_reader import BaseDBReader | ||
from sycamore.data.document import DocumentPropertyTypes, DocumentSource | ||
from sycamore.utils.import_utils import requires_modules | ||
from dataclasses import dataclass, field | ||
import typing | ||
from typing import Dict | ||
from typing import Dict, Optional | ||
|
||
if typing.TYPE_CHECKING: | ||
from opensearchpy import OpenSearch | ||
|
@@ -20,9 +21,10 @@ class OpenSearchReaderClientParams(BaseDBReader.ClientParams): | |
@dataclass | ||
class OpenSearchReaderQueryParams(BaseDBReader.QueryParams): | ||
index_name: str | ||
query: Dict = field(default_factory=lambda: {"query": {"match_all": {}}}) | ||
query: Dict | ||
kwargs: Dict = field(default_factory=lambda: {}) | ||
reconstruct_document: bool = False | ||
doc_reconstructor: Optional[DocumentReconstructor] = None | ||
|
||
|
||
class OpenSearchReaderClient(BaseDBReader.Client): | ||
|
@@ -47,6 +49,10 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea | |
if "size" not in query_params.query and "size" not in query_params.kwargs: | ||
query_params.kwargs["size"] = 200 | ||
result = [] | ||
if query_params.reconstruct_document: | ||
eric-anderson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
query_params.kwargs["_source_includes"] = ["doc_id", "parent_id", "properties"] | ||
if query_params.doc_reconstructor is not None: | ||
query_params.kwargs["_source_includes"] = query_params.doc_reconstructor.get_required_source_fields() | ||
# No pagination needed for knn queries | ||
if "query" in query_params.query and "knn" in query_params.query["query"]: | ||
response = self._client.search( | ||
|
@@ -93,7 +99,15 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse): | |
def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]: | ||
assert isinstance(query_params, OpenSearchReaderQueryParams) | ||
result: list[Document] = [] | ||
if not query_params.reconstruct_document: | ||
if query_params.doc_reconstructor is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How are you planning to implement parallel reconstruction? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we can discuss this idea in #1029 |
||
logging.info("Using DocID to Document reconstructor") | ||
unique = set() | ||
for data in self.output: | ||
doc_id = query_params.doc_reconstructor.get_doc_id(data) | ||
if doc_id not in unique: | ||
result.append(query_params.doc_reconstructor.reconstruct(data)) | ||
unique.add(doc_id) | ||
elif not query_params.reconstruct_document: | ||
for data in self.output: | ||
doc = Document( | ||
{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,11 @@ | ||
from typing import Optional, Union, Callable, Dict, Any | ||
from typing import Optional, Union, Callable, Dict | ||
from pathlib import Path | ||
|
||
from pandas import DataFrame | ||
from pyarrow import Table | ||
from pyarrow.filesystem import FileSystem | ||
|
||
from sycamore.connectors.doc_reconstruct import DocumentReconstructor | ||
from sycamore.context import context_params | ||
from sycamore.plan_nodes import Node | ||
from sycamore import Context, DocSet | ||
|
@@ -224,7 +225,8 @@ def opensearch( | |
index_name: str, | ||
query: Optional[Dict] = None, | ||
reconstruct_document: bool = False, | ||
query_kwargs: dict[str, Any] = {}, | ||
doc_reconstructor: Optional[DocumentReconstructor] = None, | ||
query_kwargs=None, | ||
**kwargs, | ||
) -> DocSet: | ||
""" | ||
|
@@ -240,6 +242,7 @@ def opensearch( | |
reconstruct_document: Used to decide whether the returned DocSet comprises reconstructed documents, | ||
i.e. by collecting all elements belong to a single parent document (parent_id). This requires OpenSearch | ||
to be an index of docset.explode() type. Default to false. | ||
doc_reconstructor: (Optional) A custom document reconstructor to use when reconstructing documents. | ||
query_kwargs: (Optional) Parameters to configure the underlying OpenSearch search query. | ||
**kwargs: (Optional) kwargs to pass undefined parameters around. | ||
|
||
|
@@ -281,22 +284,26 @@ def opensearch( | |
os_client_args=OS_CLIENT_ARGS, index_name=INDEX, query=query | ||
) | ||
""" | ||
if query_kwargs is None: | ||
query_kwargs = {} | ||
from sycamore.connectors.opensearch import ( | ||
OpenSearchReader, | ||
OpenSearchReaderClientParams, | ||
OpenSearchReaderQueryParams, | ||
) | ||
|
||
client_params = OpenSearchReaderClientParams(os_client_args=os_client_args) | ||
query_params = ( | ||
OpenSearchReaderQueryParams( | ||
index_name=index_name, query=query, reconstruct_document=reconstruct_document, kwargs=query_kwargs | ||
) | ||
if query is not None | ||
else OpenSearchReaderQueryParams( | ||
index_name=index_name, reconstruct_document=reconstruct_document, kwargs=query_kwargs | ||
) | ||
if query is None: | ||
query = {"query": {"match_all": {}}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove the default in OpenSearchReaderQueryParams There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ok |
||
|
||
query_params = OpenSearchReaderQueryParams( | ||
index_name=index_name, | ||
query=query, | ||
reconstruct_document=reconstruct_document, | ||
doc_reconstructor=doc_reconstructor, | ||
kwargs=query_kwargs, | ||
) | ||
|
||
osr = OpenSearchReader(client_params=client_params, query_params=query_params) | ||
return DocSet(self._context, osr) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
import os | ||
import time | ||
import tempfile | ||
import uuid | ||
|
||
import pytest | ||
from opensearchpy import OpenSearch | ||
|
||
import sycamore | ||
from sycamore import EXEC_LOCAL | ||
from sycamore.connectors.doc_reconstruct import DocumentReconstructor | ||
from sycamore.data import Document | ||
from sycamore.tests.integration.connectors.common import compare_connector_docs | ||
from sycamore.tests.config import TEST_DIR | ||
from sycamore.transforms.partition import UnstructuredPdfPartitioner | ||
|
@@ -13,17 +17,28 @@ | |
|
||
|
||
@pytest.fixture(scope="class") | ||
def setup_index(): | ||
def os_client(): | ||
client = OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) | ||
yield client | ||
|
||
|
||
@pytest.fixture(scope="class") | ||
def setup_index(os_client): | ||
# client = OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) | ||
|
||
# Recreate before | ||
client.indices.delete(TestOpenSearchRead.INDEX, ignore_unavailable=True) | ||
client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS) | ||
os_client.indices.delete(TestOpenSearchRead.INDEX, ignore_unavailable=True) | ||
os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS) | ||
|
||
yield TestOpenSearchRead.INDEX | ||
|
||
# Delete after | ||
client.indices.delete(TestOpenSearchRead.INDEX, ignore_unavailable=True) | ||
os_client.indices.delete(TestOpenSearchRead.INDEX, ignore_unavailable=True) | ||
|
||
|
||
def get_doc_count(os_client, index_name: str) -> int: | ||
res = os_client.cat.indices(format="json", index=index_name) | ||
return int(res[0]["docs.count"]) | ||
|
||
|
||
class TestOpenSearchRead: | ||
|
@@ -56,7 +71,7 @@ class TestOpenSearchRead: | |
"timeout": 120, | ||
} | ||
|
||
def test_ingest_and_read(self, setup_index, exec_mode): | ||
def test_ingest_and_read(self, setup_index, os_client, exec_mode): | ||
""" | ||
Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents. | ||
""" | ||
|
@@ -76,6 +91,13 @@ def test_ingest_and_read(self, setup_index, exec_mode): | |
.take_all() | ||
) | ||
|
||
os_client.indices.refresh(TestOpenSearchRead.INDEX) | ||
|
||
expected_count = len(original_docs) | ||
actual_count = get_doc_count(os_client, TestOpenSearchRead.INDEX) | ||
# refresh should have made all ingested docs immediately available for search | ||
assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI: Golang taught me to use got & want, which are shorter and still clear. Fine to keep as is. |
||
|
||
retrieved_docs = context.read.opensearch( | ||
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX | ||
) | ||
|
@@ -92,15 +114,11 @@ def test_ingest_and_read(self, setup_index, exec_mode): | |
) | ||
original_materialized = sorted(original_docs, key=lambda d: d.doc_id) | ||
|
||
# hack to allow opensearch time to index the data. Without this it's possible we try to query the index before | ||
# all the records are available | ||
time.sleep(1) | ||
retrieved_materialized = sorted(retrieved_docs.take_all(), key=lambda d: d.doc_id) | ||
query_materialized = query_docs.take_all() | ||
retrieved_materialized_reconstructed = sorted(retrieved_docs_reconstructed.take_all(), key=lambda d: d.doc_id) | ||
|
||
with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client: | ||
os_client.indices.delete(TestOpenSearchRead.INDEX) | ||
os_client.indices.delete(TestOpenSearchRead.INDEX) | ||
assert len(query_materialized) == 1 # exactly one doc should be returned | ||
compare_connector_docs(original_materialized, retrieved_materialized) | ||
|
||
|
@@ -110,3 +128,116 @@ def test_ingest_and_read(self, setup_index, exec_mode): | |
|
||
for i in range(len(doc.elements) - 1): | ||
assert doc.elements[i].element_index < doc.elements[i + 1].element_index | ||
|
||
def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir): | ||
""" | ||
Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents. | ||
""" | ||
|
||
print(f"Using cache dir: {cache_dir}") | ||
|
||
def doc_reconstructor(index_name: str, doc_id: str) -> Document: | ||
import pickle | ||
|
||
data = pickle.load(open(f"{cache_dir}/{TestOpenSearchRead.INDEX}-{doc_id}", "rb")) | ||
return Document(**data) | ||
|
||
def doc_to_name(doc: Document, bin: bytes) -> str: | ||
return f"{TestOpenSearchRead.INDEX}-{doc.doc_id}" | ||
|
||
context = sycamore.init(exec_mode=EXEC_LOCAL) | ||
hidden = str(uuid.uuid4()) | ||
# make sure we read from pickle files -- this part won't be written into opensearch. | ||
dicts = [ | ||
eric-anderson marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
"doc_id": "1", | ||
"hidden": hidden, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
"elements": [ | ||
{"properties": {"_element_index": 1}, "text_representation": "here is an animal that meows"}, | ||
], | ||
}, | ||
{ | ||
"doc_id": "2", | ||
"elements": [ | ||
{"id": 7, "properties": {"_element_index": 7}, "text_representation": "this is a cat"}, | ||
{ | ||
"id": 1, | ||
"properties": {"_element_index": 1}, | ||
"text_representation": "here is an animal that moos", | ||
}, | ||
], | ||
}, | ||
{ | ||
"doc_id": "3", | ||
"elements": [ | ||
{"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"}, | ||
], | ||
}, | ||
{ | ||
"doc_id": "4", | ||
"elements": [ | ||
{"id": 1, "properties": {"_element_index": 1}}, | ||
], | ||
}, | ||
{ | ||
"doc_id": "5", | ||
"elements": [ | ||
{ | ||
"properties": {"_element_index": 1}, | ||
"text_representation": "the number of pages in this document are 253", | ||
} | ||
], | ||
}, | ||
{ | ||
"doc_id": "6", | ||
"elements": [ | ||
{"id": 1, "properties": {"_element_index": 1}}, | ||
], | ||
}, | ||
] | ||
docs = [Document(item) for item in dicts] | ||
|
||
original_docs = ( | ||
context.read.document(docs) | ||
.materialize(path={"root": cache_dir, "name": doc_to_name}) | ||
.explode() | ||
.write.opensearch( | ||
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, | ||
index_name=TestOpenSearchRead.INDEX, | ||
index_settings=TestOpenSearchRead.INDEX_SETTINGS, | ||
execute=False, | ||
) | ||
.take_all() | ||
) | ||
|
||
os_client.indices.refresh(TestOpenSearchRead.INDEX) | ||
|
||
expected_count = len(original_docs) | ||
actual_count = get_doc_count(os_client, TestOpenSearchRead.INDEX) | ||
# refresh should have made all ingested docs immediately available for search | ||
assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}" | ||
|
||
retrieved_docs_reconstructed = ( | ||
context.read.opensearch( | ||
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, | ||
index_name=TestOpenSearchRead.INDEX, | ||
reconstruct_document=True, | ||
doc_reconstructor=DocumentReconstructor(TestOpenSearchRead.INDEX, doc_reconstructor), | ||
) | ||
# .filter(lambda doc: doc.doc_id == "1") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove obsolete line. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. removed |
||
.take_all() | ||
) | ||
|
||
assert len(retrieved_docs_reconstructed) == 6 | ||
retrieved_sorted = sorted(retrieved_docs_reconstructed, key=lambda d: d.doc_id) | ||
assert retrieved_sorted[0].data["hidden"] == hidden | ||
assert docs == retrieved_sorted | ||
|
||
# Clean slate between Execution Modes | ||
os_client.indices.delete(TestOpenSearchRead.INDEX) | ||
os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS) | ||
os_client.indices.refresh(TestOpenSearchRead.INDEX) | ||
|
||
def test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client): | ||
with tempfile.TemporaryDirectory() as cache_dir: | ||
self._test_ingest_and_read_via_docid_reconstructor(setup_index, os_client, cache_dir) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
data: dict here as well.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok