Allow OpenSearchReader to output to MaterializedDataset consisting of refs #1029

Draft · wants to merge 2 commits into main
16 changes: 13 additions & 3 deletions lib/sycamore/sycamore/connectors/base_reader.py
@@ -3,6 +3,9 @@

from abc import ABC, abstractmethod

from ray.data.dataset import MaterializedDataset

from sycamore.data.document import Document
from sycamore.plan_nodes import Scan
from sycamore.utils.time_trace import TimeTrace
@@ -34,7 +37,7 @@ def close(self):
# Type param for the objects that are read from the db
class QueryResponse(ABC):
@abstractmethod
def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
def to_docs(self, query_params: "BaseDBReader.QueryParams", use_refs: bool = False) -> list[Document]:
pass

# Type param for the object used to establish the read target
@@ -58,6 +61,7 @@ def __init__(
super().__init__(**kwargs)
self._client_params = client_params
self._query_params = query_params
self._use_refs = kwargs.get('use_refs', False)

def read_docs(self) -> list[Document]:
try:
@@ -66,7 +70,7 @@ def read_docs(self) -> list[Document]:
if not client.check_target_presence(self._query_params):
raise ValueError("Target is not present\n" f"Parameters: {self._query_params}\n")
records = client.read_records(query_params=self._query_params)
docs = records.to_docs(query_params=self._query_params)
docs = records.to_docs(query_params=self._query_params, use_refs=self._use_refs)
except Exception as e:
raise ValueError(f"Error reading from target: {e}")
finally:
@@ -80,10 +84,16 @@ def execute(self, **kwargs) -> "Dataset":
from ray.data import from_items

with TimeTrace("Reader"):
return from_items(items=[{"doc": doc.serialize()} for doc in self.read_docs()])
ds: MaterializedDataset = from_items(items=[{"doc": doc.serialize()} for doc in self.read_docs()])
if self._use_refs:
return self.read_datasource(ds)
return ds

def local_source(self) -> list[Document]:
return self.read_docs()

def format(self):
return "reader"

def read_datasource(self, ds: "Dataset") -> "Dataset":
return ds  # default: no ref-based wrapping; subclasses override
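
To make the new contract concrete: when use_refs is set, execute() still materializes a Ray dataset of serialized documents, but hands it to read_datasource(), which a subclass overrides to wrap the refs in a dataset that hydrates full documents on iteration. A minimal sketch of that contract, with MyReader and MyMaterializedDataset as hypothetical stand-ins for the OpenSearch implementations later in this diff:

from ray.data import Dataset

from sycamore.connectors.base_reader import BaseDBReader


class MyReader(BaseDBReader):
    def read_datasource(self, ds: Dataset) -> Dataset:
        # Wrap the dataset of pickled refs in a MaterializedDataset
        # subclass whose iter_rows() fetches the full records from the
        # backing store (see OpenSearchMaterializedDataset below).
        return MyMaterializedDataset(ds, self._client_params, self._query_params)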
207 changes: 202 additions & 5 deletions lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
@@ -1,5 +1,9 @@
import logging

from ray.data import Dataset
from ray.data.dataset import MaterializedDataset

from sycamore.data import Document, Element
from sycamore.connectors.base_reader import BaseDBReader
from sycamore.data.document import DocumentPropertyTypes, DocumentSource
@@ -47,6 +51,8 @@ def read_records(self, query_params: BaseDBReader.QueryParams) -> "OpenSearchRea
query_params.kwargs["scroll"] = "10m"
if "size" not in query_params.query and "size" not in query_params.kwargs:
query_params.kwargs["size"] = 200

query_params.kwargs["_source_includes"] = ["doc_id", "parent_id", "properties"]  # note: applied to every read, not only when use_refs is set
logging.debug(f"OpenSearch query on {query_params.index_name}: {query_params.query}")
response = self._client.search(index=query_params.index_name, body=query_params.query, **query_params.kwargs)
scroll_id = response["_scroll_id"]
@@ -62,6 +68,9 @@
response = self._client.scroll(scroll_id=scroll_id, scroll=query_params.kwargs["scroll"])
finally:
self._client.clear_scroll(scroll_id=scroll_id)

logging.debug(f"Got {len(result)} records")

return OpenSearchReaderQueryResponse(result, self._client)

def check_target_presence(self, query_params: BaseDBReader.QueryParams):
Expand All @@ -78,16 +87,22 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse):
"""
client: typing.Optional["OpenSearch"] = None

def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
use_refs: bool = False

def to_docs(self, query_params: "BaseDBReader.QueryParams", use_refs: bool = False) -> list[Document]:
assert isinstance(query_params, OpenSearchReaderQueryParams)
result: list[Document] = []
if not query_params.reconstruct_document:
for data in self.output:
doc = Document(
if use_refs:
    doc = OpenSearchDocument(query_params.index_name, **data.get("_source", {}))
else:
    doc = Document({**data.get("_source", {})})

doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
doc.properties["score"] = data["_score"]
result.append(doc)
@@ -136,7 +151,11 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:

# Batched retrieval of all elements belonging to unique docs
doc_ids = list(unique_docs.keys())
all_elements_for_docs = self._get_all_elements_for_doc_ids(doc_ids, query_params.index_name)

if use_refs:
return [OpenSearchDocument(query_params.index_name, **doc.data) for doc in unique_docs.values()]

all_elements_for_docs = self._get_all_elements_for_doc_ids(doc_ids, query_params.index_name, use_refs)

"""
Add elements to unique docs. If they were not part of the original result,
@@ -148,6 +167,7 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
**element.get("_source", {}),
}
)

assert doc.parent_id, "Got non-element record from OpenSearch reconstruction query"
if doc.doc_id not in query_result_elements_per_doc.get(doc.parent_id, {}):
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DOCUMENT_RECONSTRUCTION_RETRIEVAL
@@ -162,16 +182,22 @@ def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
for doc in result:
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))

logging.debug(f"Returning {len(result)} documents")

return result

def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[typing.Any]:
def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str, use_refs: bool = False) -> list[typing.Any]:
"""
Returns all records in OpenSearch belonging to a list of Document ids (element.parent_id)
"""
assert self.client, "_get_all_elements_for_doc_ids requires an OpenSearch client instance in this class"
batch_size = 100
page_size = 500

query_params = {}
if use_refs:
query_params["_source_includes"] = ["doc_id", "parent_id"]

all_elements = []
for i in range(0, len(doc_ids), batch_size):
doc_ids_batch = doc_ids[i : i + batch_size]
@@ -182,7 +208,7 @@ def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[
"size": page_size,
"from": from_offset,
}
response = self.client.search(index=index, body=query)
response = self.client.search(index=index, body=query, **query_params)
hits = response["hits"]["hits"]
all_elements.extend(hits)
if len(hits) < page_size:
@@ -196,3 +222,174 @@ class OpenSearchReader(BaseDBReader):
Record = OpenSearchReaderQueryResponse
ClientParams = OpenSearchReaderClientParams
QueryParams = OpenSearchReaderQueryParams

def read_datasource(self, ds) -> "Dataset":
return OpenSearchMaterializedDataset(ds, self._client_params, self._query_params)

class OpenSearchDocument(Document):
def __init__(self, index_name: str, **data):
super().__init__(**data)
self.index_name = index_name

@classmethod
def from_document(cls, doc: Document) -> "OpenSearchDocument":
raise NotImplementedError  # TODO: requires the source index name, which a plain Document does not carry

def serialize(self) -> bytes:
"""Serialize this document to bytes."""
from pickle import dumps

return dumps({"index_name": self.index_name, "data": self.data})

@staticmethod
def deserialize(raw: bytes) -> "OpenSearchDocument":
"""Deserialize from bytes to a OpenSearchDocument."""
from pickle import loads

data = loads(raw)
# print(f"Deserialized data: {data}")
return OpenSearchDocument(data["index_name"], **data["data"])


class OpenSearchMaterializedDataset(MaterializedDataset):

def __init__(self, ds: MaterializedDataset, os_client_params: BaseDBReader.ClientParams, os_query_params: BaseDBReader.QueryParams):
self.os_client_params = os_client_params
self.os_query_params = os_query_params
super().__init__(ds._plan, ds._logical_plan)

@staticmethod
def _get_all_elements_for_doc_ids(os_client: "OpenSearch", doc_ids: list[str], index: str, exclude_embedding=True) -> list[typing.Any]:
"""
Returns all records in OpenSearch belonging to a list of Document ids (element.parent_id)
"""
batch_size = 100
page_size = 500

query_params = {}
if exclude_embedding:
query_params["_source_excludes"] = ["embedding"]

all_elements = []
for i in range(0, len(doc_ids), batch_size):
doc_ids_batch = doc_ids[i : i + batch_size]
from_offset = 0
while True:
query = {
"query": {"terms": {"parent_id.keyword": doc_ids_batch}},
"size": page_size,
"from": from_offset,
}
response = os_client.search(index=index, body=query, **query_params)
hits = response["hits"]["hits"]
all_elements.extend(hits)
if len(hits) < page_size:
break
from_offset += page_size

logging.debug(f"Got all elements: {len(all_elements)}")
return all_elements

def to_docs(self, client, index_name, docs: list[OpenSearchDocument], reconstruct_document: bool = True) -> list[Document]:
"""

"""
logging.debug(f"In to_docs, reconstruct_document: {reconstruct_document}")

result: list[Document] = []
if not reconstruct_document:
doc_ids = [doc.data['doc_id'] for doc in docs]
res = client.mget(index=index_name, body={"ids": doc_ids})
for data in res['docs']:
doc = Document(
{
**data.get("_source", {}),
}
)
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
# doc.properties["score"] = data["_score"]
result.append(doc)
else:
"""
Document reconstruction:
1. Construct a map of all unique parent Documents (i.e. no parent_id field)
1.1 If we find doc_ids without parent documents, we create empty parent Documents
2. Perform a terms query to retrieve all (including non-matched) other records for that parent_id
3. Add elements to unique parent Documents
"""
# Get unique documents
unique_docs: dict[str, Document] = {}
query_result_elements_per_doc: dict[str, set[str]] = {}
for data in docs:
doc = Document(
{
**data.data,
}
)
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
assert doc.doc_id, "Retrieved invalid doc with missing doc_id"
if not doc.parent_id:
# Always use retrieved doc as the unique parent doc - override any empty parent doc created below
unique_docs[doc.doc_id] = doc
else:
# Create empty parent documents if no parent document was in result set
unique_docs[doc.parent_id] = unique_docs.get(
doc.parent_id,
Document(
{
"doc_id": doc.parent_id,
"properties": {
**doc.properties,
DocumentPropertyTypes.SOURCE: DocumentSource.DOCUMENT_RECONSTRUCTION_PARENT,
},
}
),
)
elements = query_result_elements_per_doc.get(doc.parent_id, set())
elements.add(doc.doc_id)
query_result_elements_per_doc[doc.parent_id] = elements

# Batched retrieval of all elements belonging to unique docs
doc_ids = list(unique_docs.keys())
all_elements_for_docs = self._get_all_elements_for_doc_ids(client, doc_ids, index_name)

"""
Add elements to unique docs. If they were not part of the original result,
we set properties.DocumentPropertyTypes.SOURCE = DOCUMENT_RECONSTRUCTION_RETRIEVAL
"""
for element in all_elements_for_docs:
doc = Document(
{
**element.get("_source", {}),
}
)

assert doc.parent_id, "Got non-element record from OpenSearch reconstruction query"
if doc.doc_id not in query_result_elements_per_doc.get(doc.parent_id, {}):
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DOCUMENT_RECONSTRUCTION_RETRIEVAL
else:
doc.properties[DocumentPropertyTypes.SOURCE] = DocumentSource.DB_QUERY
parent = unique_docs[doc.parent_id]
parent.elements.append(Element(doc.data))

result = list(unique_docs.values())

# sort elements per doc
for doc in result:
doc.elements.sort(key=lambda e: e.element_index if e.element_index is not None else float("inf"))

return result

def iter_rows(
    self, *, prefetch_batches: int = 1, prefetch_blocks: int = 0
) -> typing.Iterable[typing.Dict[str, typing.Any]]:
    # prefetch args are accepted for Dataset API compatibility but are not forwarded
    os_client = OpenSearchReaderClient.from_client_params(self.os_client_params)._client

for row in super().iter_rows():
os_doc = OpenSearchDocument.deserialize(row["doc"])
index_name = os_doc.index_name
docs = self.to_docs(os_client, index_name, [os_doc], self.os_query_params.reconstruct_document)
for doc in docs:
# print(f"Yielding doc {counter}: {doc}")
yield {"doc": doc.serialize()}
2 changes: 1 addition & 1 deletion lib/sycamore/sycamore/reader.py
@@ -297,7 +297,7 @@ def opensearch(
index_name=index_name, reconstruct_document=reconstruct_document, kwargs=query_kwargs
)
)
osr = OpenSearchReader(client_params=client_params, query_params=query_params)
osr = OpenSearchReader(client_params=client_params, query_params=query_params, **kwargs)
return DocSet(self._context, osr)

@requires_modules("duckdb", extra="duckdb")
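
End to end, the flag now flows from the DocSet API down to the reader. A hedged usage sketch; the client args and index name are placeholders, and use_refs rides along in **kwargs as wired up above:

import sycamore

ctx = sycamore.init()
docset = ctx.read.opensearch(
    os_client_args={"hosts": [{"host": "localhost", "port": 9200}]},
    index_name="my-index",
    reconstruct_document=True,
    use_refs=True,  # forwarded to OpenSearchReader by this PR
)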