Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Doc reconstruct via function #1072

Merged
merged 10 commits into from
Dec 19, 2024
Merged
18 changes: 18 additions & 0 deletions lib/sycamore/sycamore/connectors/doc_reconstruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from typing import Callable

from sycamore.data import Document


class DocumentReconstructor:
    """Rebuilds full Documents from OpenSearch hits via a caller-supplied function.

    Instead of materializing a Document from each hit's raw ``_source`` payload,
    the reader only extracts a document id from the hit and delegates to
    ``reconstruct_fn`` to produce the complete Document (e.g. from a local
    cache keyed by index name and doc id).
    """

    def __init__(self, index_name: str, reconstruct_fn: Callable[[str, str], Document]):
        # index_name: OpenSearch index the hits come from; passed through to reconstruct_fn.
        # reconstruct_fn: maps (index_name, doc_id) -> fully reconstructed Document.
        self.index_name = index_name
        self.reconstruct_fn = reconstruct_fn

    def get_required_source_fields(self) -> list[str]:
        """Fields the reader must include in ``_source`` so get_doc_id() can work."""
        return ["parent_id"]

    def get_doc_id(self, data: dict) -> str:
        """Return the parent document id for a hit, falling back to the hit's own ``_id``."""
        return data["_source"]["parent_id"] or data["_id"]

    def reconstruct(self, data: dict) -> Document:
        """Build the full Document for a hit by delegating to reconstruct_fn."""
        return self.reconstruct_fn(self.index_name, self.get_doc_id(data))
20 changes: 17 additions & 3 deletions lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import logging

from sycamore.connectors.doc_reconstruct import DocumentReconstructor
from sycamore.data import Document, Element
from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data.document import DocumentPropertyTypes, DocumentSource
from sycamore.utils.import_utils import requires_modules
from dataclasses import dataclass, field
import typing
from typing import Dict
from typing import Dict, Optional

if typing.TYPE_CHECKING:
from opensearchpy import OpenSearch
Expand All @@ -20,9 +21,10 @@ class OpenSearchReaderClientParams(BaseDBReader.ClientParams):
@dataclass
class OpenSearchReaderQueryParams(BaseDBReader.QueryParams):
    """Parameters for a single OpenSearch read.

    ``query`` deliberately has no default: callers (e.g. DocSetReader.opensearch)
    supply the match_all fallback themselves, so the two code paths cannot drift
    out of sync.
    """

    # Name of the index to query.
    index_name: str
    # OpenSearch query body; required (no default -- see class docstring).
    query: Dict
    # Extra keyword arguments forwarded to the search call (e.g. "size").
    kwargs: Dict = field(default_factory=lambda: {})
    # When True, collapse exploded elements back into their parent documents.
    reconstruct_document: bool = False
    # Optional custom reconstructor; when set it takes precedence over
    # reconstruct_document in to_docs().
    doc_reconstructor: Optional[DocumentReconstructor] = None


class OpenSearchReaderClient(BaseDBReader.Client):
Expand All @@ -47,6 +49,10 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea
if "size" not in query_params.query and "size" not in query_params.kwargs:
query_params.kwargs["size"] = 200
result = []
if query_params.reconstruct_document:
eric-anderson marked this conversation as resolved.
Show resolved Hide resolved
query_params.kwargs["_source_includes"] = ["doc_id", "parent_id", "properties"]
if query_params.doc_reconstructor is not None:
query_params.kwargs["_source_includes"] = query_params.doc_reconstructor.get_required_source_fields()
# No pagination needed for knn queries
if "query" in query_params.query and "knn" in query_params.query["query"]:
response = self._client.search(
Expand Down Expand Up @@ -93,7 +99,15 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse):
def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
assert isinstance(query_params, OpenSearchReaderQueryParams)
result: list[Document] = []
if not query_params.reconstruct_document:
if query_params.doc_reconstructor is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are you planning to implement parallel reconstruction?
I think this is another place where making the reconstructor a class would help. It could have a member variable .parallel_reconstruct where the first stage does dict -> (simplified Doc); and the second stage is a Map to to simplified doc -> full doc.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can discuss this idea in #1029

logging.info("Using DocID to Document reconstructor")
unique = set()
for data in self.output:
doc_id = query_params.doc_reconstructor.get_doc_id(data)
if doc_id not in unique:
result.append(query_params.doc_reconstructor.reconstruct(data))
unique.add(doc_id)
elif not query_params.reconstruct_document:
for data in self.output:
doc = Document(
{
Expand Down
27 changes: 17 additions & 10 deletions lib/sycamore/sycamore/reader.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from typing import Optional, Union, Callable, Dict, Any
from typing import Optional, Union, Callable, Dict
from pathlib import Path

from pandas import DataFrame
from pyarrow import Table
from pyarrow.filesystem import FileSystem

from sycamore.connectors.doc_reconstruct import DocumentReconstructor
from sycamore.context import context_params
from sycamore.plan_nodes import Node
from sycamore import Context, DocSet
Expand Down Expand Up @@ -224,7 +225,8 @@ def opensearch(
index_name: str,
query: Optional[Dict] = None,
reconstruct_document: bool = False,
query_kwargs: dict[str, Any] = {},
doc_reconstructor: Optional[DocumentReconstructor] = None,
query_kwargs=None,
**kwargs,
) -> DocSet:
"""
Expand All @@ -240,6 +242,7 @@ def opensearch(
reconstruct_document: Used to decide whether the returned DocSet comprises reconstructed documents,
i.e. by collecting all elements belonging to a single parent document (parent_id). This requires OpenSearch
to be an index of docset.explode() type. Defaults to False.
doc_reconstructor: (Optional) A custom document reconstructor to use when reconstructing documents.
query_kwargs: (Optional) Parameters to configure the underlying OpenSearch search query.
**kwargs: (Optional) kwargs to pass undefined parameters around.

Expand Down Expand Up @@ -281,22 +284,26 @@ def opensearch(
os_client_args=OS_CLIENT_ARGS, index_name=INDEX, query=query
)
"""
if query_kwargs is None:
query_kwargs = {}
from sycamore.connectors.opensearch import (
OpenSearchReader,
OpenSearchReaderClientParams,
OpenSearchReaderQueryParams,
)

client_params = OpenSearchReaderClientParams(os_client_args=os_client_args)
query_params = (
OpenSearchReaderQueryParams(
index_name=index_name, query=query, reconstruct_document=reconstruct_document, kwargs=query_kwargs
)
if query is not None
else OpenSearchReaderQueryParams(
index_name=index_name, reconstruct_document=reconstruct_document, kwargs=query_kwargs
)
if query is None:
query = {"query": {"match_all": {}}}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove the default in OpenSearchReaderQueryParams
That change makes sure this code and the other code stay in sync.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok


query_params = OpenSearchReaderQueryParams(
index_name=index_name,
query=query,
reconstruct_document=reconstruct_document,
doc_reconstructor=doc_reconstructor,
kwargs=query_kwargs,
)

osr = OpenSearchReader(client_params=client_params, query_params=query_params)
return DocSet(self._context, osr)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import os
import time
import tempfile
import uuid

import pytest
from opensearchpy import OpenSearch

import sycamore
from sycamore import EXEC_LOCAL
from sycamore.connectors.doc_reconstruct import DocumentReconstructor
from sycamore.data import Document
from sycamore.tests.integration.connectors.common import compare_connector_docs
from sycamore.tests.config import TEST_DIR
from sycamore.transforms.partition import UnstructuredPdfPartitioner
Expand All @@ -13,17 +17,28 @@


@pytest.fixture(scope="class")
def os_client():
    """Class-scoped OpenSearch client shared by every test in the class."""
    client = OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS)
    yield client
    # Release the client's connection pool once the test class is done
    # (the original leaked it for the life of the process).
    client.close()


@pytest.fixture(scope="class")
def setup_index(os_client):
    """Create a fresh test index for the class and drop it afterwards."""
    # Recreate before: delete any leftover index from a previous run, then create.
    os_client.indices.delete(TestOpenSearchRead.INDEX, ignore_unavailable=True)
    os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS)

    yield TestOpenSearchRead.INDEX

    # Delete after
    os_client.indices.delete(TestOpenSearchRead.INDEX, ignore_unavailable=True)


def get_doc_count(os_client, index_name: str) -> int:
    """Return the document count the `_cat/indices` API reports for *index_name*."""
    cat_rows = os_client.cat.indices(format="json", index=index_name)
    first_row = cat_rows[0]
    return int(first_row["docs.count"])


class TestOpenSearchRead:
Expand Down Expand Up @@ -56,7 +71,7 @@ class TestOpenSearchRead:
"timeout": 120,
}

def test_ingest_and_read(self, setup_index, exec_mode):
def test_ingest_and_read(self, setup_index, os_client, exec_mode):
"""
Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents.
"""
Expand All @@ -76,6 +91,13 @@ def test_ingest_and_read(self, setup_index, exec_mode):
.take_all()
)

os_client.indices.refresh(TestOpenSearchRead.INDEX)

expected_count = len(original_docs)
actual_count = get_doc_count(os_client, TestOpenSearchRead.INDEX)
# refresh should have made all ingested docs immediately available for search
assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI: Golang taught me to use got & want, which are shorter and still clear. Fine to keep as is.


retrieved_docs = context.read.opensearch(
os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS, index_name=TestOpenSearchRead.INDEX
)
Expand All @@ -92,15 +114,11 @@ def test_ingest_and_read(self, setup_index, exec_mode):
)
original_materialized = sorted(original_docs, key=lambda d: d.doc_id)

# hack to allow opensearch time to index the data. Without this it's possible we try to query the index before
# all the records are available
time.sleep(1)
retrieved_materialized = sorted(retrieved_docs.take_all(), key=lambda d: d.doc_id)
query_materialized = query_docs.take_all()
retrieved_materialized_reconstructed = sorted(retrieved_docs_reconstructed.take_all(), key=lambda d: d.doc_id)

with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client:
os_client.indices.delete(TestOpenSearchRead.INDEX)
os_client.indices.delete(TestOpenSearchRead.INDEX)
assert len(query_materialized) == 1 # exactly one doc should be returned
compare_connector_docs(original_materialized, retrieved_materialized)

Expand All @@ -110,3 +128,116 @@ def test_ingest_and_read(self, setup_index, exec_mode):

for i in range(len(doc.elements) - 1):
assert doc.elements[i].element_index < doc.elements[i + 1].element_index

def _test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client, cache_dir):
    """
    Validates data is readable from OpenSearch, and that we can rebuild processed Sycamore documents
    from a local pickle cache (keyed by doc_id) via a custom DocumentReconstructor.
    """

    print(f"Using cache dir: {cache_dir}")

    def doc_reconstructor(index_name: str, doc_id: str) -> Document:
        import pickle

        # Load the full pre-explode Document back from the materialize cache.
        # Use a context manager so the file handle is not leaked.
        with open(f"{cache_dir}/{TestOpenSearchRead.INDEX}-{doc_id}", "rb") as f:
            data = pickle.load(f)
        return Document(**data)

    def doc_to_name(doc: Document, bin: bytes) -> str:
        # Cache file name for a materialized Document: "<index>-<doc_id>".
        return f"{TestOpenSearchRead.INDEX}-{doc.doc_id}"

    context = sycamore.init(exec_mode=EXEC_LOCAL)
    hidden = str(uuid.uuid4())
    # make sure we read from pickle files -- this part won't be written into opensearch.
    dicts = [
        {
            "doc_id": "1",
            "hidden": hidden,
            "elements": [
                {"properties": {"_element_index": 1}, "text_representation": "here is an animal that meows"},
            ],
        },
        {
            "doc_id": "2",
            "elements": [
                {"id": 7, "properties": {"_element_index": 7}, "text_representation": "this is a cat"},
                {
                    "id": 1,
                    "properties": {"_element_index": 1},
                    "text_representation": "here is an animal that moos",
                },
            ],
        },
        {
            "doc_id": "3",
            "elements": [
                {"properties": {"_element_index": 1}, "text_representation": "here is an animal that moos"},
            ],
        },
        {
            "doc_id": "4",
            "elements": [
                {"id": 1, "properties": {"_element_index": 1}},
            ],
        },
        {
            "doc_id": "5",
            "elements": [
                {
                    "properties": {"_element_index": 1},
                    "text_representation": "the number of pages in this document are 253",
                }
            ],
        },
        {
            "doc_id": "6",
            "elements": [
                {"id": 1, "properties": {"_element_index": 1}},
            ],
        },
    ]
    docs = [Document(item) for item in dicts]

    original_docs = (
        context.read.document(docs)
        .materialize(path={"root": cache_dir, "name": doc_to_name})
        .explode()
        .write.opensearch(
            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
            index_name=TestOpenSearchRead.INDEX,
            index_settings=TestOpenSearchRead.INDEX_SETTINGS,
            execute=False,
        )
        .take_all()
    )

    os_client.indices.refresh(TestOpenSearchRead.INDEX)

    expected_count = len(original_docs)
    actual_count = get_doc_count(os_client, TestOpenSearchRead.INDEX)
    # refresh should have made all ingested docs immediately available for search
    assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}"

    retrieved_docs_reconstructed = (
        context.read.opensearch(
            os_client_args=TestOpenSearchRead.OS_CLIENT_ARGS,
            index_name=TestOpenSearchRead.INDEX,
            reconstruct_document=True,
            doc_reconstructor=DocumentReconstructor(TestOpenSearchRead.INDEX, doc_reconstructor),
        ).take_all()
    )

    assert len(retrieved_docs_reconstructed) == 6
    retrieved_sorted = sorted(retrieved_docs_reconstructed, key=lambda d: d.doc_id)
    # "hidden" only survives if reconstruction really came from the pickle cache,
    # since it was never written to OpenSearch.
    assert retrieved_sorted[0].data["hidden"] == hidden
    assert docs == retrieved_sorted

    # Clean slate between Execution Modes
    os_client.indices.delete(TestOpenSearchRead.INDEX)
    os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS)
    os_client.indices.refresh(TestOpenSearchRead.INDEX)

def test_ingest_and_read_via_docid_reconstructor(self, setup_index, os_client):
    # Run the reconstruction test against a throwaway materialize cache so the
    # pickle files written by the helper never outlive the test run.
    with tempfile.TemporaryDirectory() as cache_dir:
        self._test_ingest_and_read_via_docid_reconstructor(setup_index, os_client, cache_dir)
Loading