Allow Doc reconstruct via function #1072
Conversation
lib/sycamore/sycamore/connectors/opensearch/opensearch_reader.py
for data in self.output:
    doc_id = data["_source"]["parent_id"] or data["_id"]
    if doc_id not in unique:
        # The reconstructor may or may not fetch the document from a cache
I'd remove the comment; it's assuming a particular implementation.
removed.
@@ -93,7 +98,17 @@ class OpenSearchReaderQueryResponse(BaseDBReader.QueryResponse):
    def to_docs(self, query_params: "BaseDBReader.QueryParams") -> list[Document]:
        assert isinstance(query_params, OpenSearchReaderQueryParams)
        result: list[Document] = []
        if not query_params.reconstruct_document:
        if query_params.doc_reconstructor is not None:
How are you planning to implement parallel reconstruction?
I think this is another place where making the reconstructor a class would help. It could have a member variable .parallel_reconstruct, where the first stage does dict -> simplified Doc, and the second stage is a Map from simplified doc -> full doc.
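A minimal sketch of the two-stage idea, assuming a hypothetical `Reconstructor` class (the names `Reconstructor`, `StubDoc`, `parallel_reconstruct`, `to_stub`, and `to_full` are all illustrative, not sycamore's API):

```python
from dataclasses import dataclass, field

@dataclass
class StubDoc:
    """Stage-1 output: a lightweight stand-in for a full Document."""
    doc_id: str
    source: dict = field(default_factory=dict)

class Reconstructor:
    # A flag like this would let the reader run stage 2 as a parallel Map.
    parallel_reconstruct = True

    def to_stub(self, data: dict) -> StubDoc:
        # Stage 1: dict -> simplified doc (cheap, done while scanning hits).
        return StubDoc(
            doc_id=data["_source"].get("parent_id") or data["_id"],
            source=data["_source"],
        )

    def to_full(self, stub: StubDoc) -> dict:
        # Stage 2: simplified doc -> full doc (expensive, parallelizable).
        return {"doc_id": stub.doc_id, **stub.source}

hits = [
    {"_id": "a", "_source": {"parent_id": None, "text": "x"}},
    {"_id": "b", "_source": {"parent_id": "p", "text": "y"}},
]
r = Reconstructor()
stubs = [r.to_stub(h) for h in hits]     # stage 1, inline with the scan
full_docs = list(map(r.to_full, stubs))  # stage 2, a Map in the pipeline
```

Because stage 2 is a pure per-document map, it can be distributed without any coordination between documents.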
Maybe we can discuss this idea in #1029
original_docs = (
    context.read.binary(path, binary_format="pdf")
    .partition(partitioner=UnstructuredPdfPartitioner())
    .materialize(cache_dir)
This should not work; default materialize naming is content-hash+doc_id.
sycamore/lib/sycamore/sycamore/materialize.py, line 313 in f80ff0e:
    def doc_to_name(doc: Document, bin: bytes) -> str:
You need to override the naming function to match with line 137.
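A hedged sketch of the point being made (the function name `doc_id_to_name` is illustrative, not sycamore's API): the default cache-file name mixes a content hash with the doc_id, so a reader that only knows a doc_id cannot recompute the name. A naming function keyed on doc_id alone lets writer and reader agree by construction:

```python
import hashlib

def doc_id_to_name(doc_id: str) -> str:
    # Name depends ONLY on doc_id, so both the materialize writer and the
    # reconstructing reader can derive it independently.
    return f"doc-{hashlib.sha256(doc_id.encode()).hexdigest()[:16]}.pickle"

written = doc_id_to_name("abc123")    # what materialize would write
looked_up = doc_id_to_name("abc123")  # what the reconstructor would look for
```

With the default content-hash+doc_id scheme, `written` could not be recomputed on the read side without the original bytes.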
Ok
lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
    data = pickle.load(open(f"{cache_dir}/{found}", "rb"))
    return Document(**data)

path = str(TEST_DIR / "resources/data/pdfs/Ray.pdf")
I would recommend fake documents rather than running the partitioner; this test is testing far more than is necessary. As a side benefit, it becomes very easy to verify that you're reconstructing via pickle: the fake documents can do doc.hidden_var = ; and then you check that hidden_var pops out on the far side. Since it wouldn't be written to OpenSearch, you know it's doing pickle reconstruction.
Also, it seems better to test with multiple documents, which is easier if they're fake.
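A self-contained sketch of the fake-document trick under discussion (all names are illustrative): stash a field that is never written to the index, round-trip through pickle, and check that the field survives, which proves the data came from the pickle files:

```python
import os
import pickle
import tempfile

def make_fake_doc(doc_id: str, hidden: str) -> dict:
    # "hidden" stands in for a field that would NOT be written to OpenSearch.
    return {"doc_id": doc_id, "text": f"fake body {doc_id}", "hidden": hidden}

with tempfile.TemporaryDirectory() as cache_dir:
    docs = [make_fake_doc(str(i), hidden=f"secret-{i}") for i in range(3)]
    for d in docs:
        with open(os.path.join(cache_dir, f"doc-{d['doc_id']}"), "wb") as f:
            pickle.dump(d, f)

    def reconstruct(doc_id: str) -> dict:
        # Load by doc_id alone, as the reconstructing reader would.
        with open(os.path.join(cache_dir, f"doc-{doc_id}"), "rb") as f:
            return pickle.load(f)

    out = sorted((reconstruct(str(i)) for i in range(3)),
                 key=lambda d: d["doc_id"])
```

If `hidden` comes back intact, the reconstruction cannot have gone through the index, because the index never saw it.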
lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
retrieved_materialized = sorted(retrieved_docs.take_all(), key=lambda d: d.doc_id)
query_materialized = query_docs.take_all()
retrieved_materialized_reconstructed = sorted(retrieved_docs_reconstructed.take_all(), key=lambda d: d.doc_id)
I would have expected assert original_docs == retrieved_docs to verify they are completely the same. Then you can remove most of the other checks.
Once you add multiple docs per the earlier comment, you'll need to sort by doc_id.
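The suggested comparison, sketched with stand-in dict documents (retrieval order from the store is not guaranteed, hence the sort on both sides):

```python
# With multiple documents, sort both sides by doc_id and compare whole
# documents in one assertion instead of field-by-field checks.
original_docs = [{"doc_id": "b", "text": "two"}, {"doc_id": "a", "text": "one"}]
retrieved_docs = [{"doc_id": "a", "text": "one"}, {"doc_id": "b", "text": "two"}]

by_id = lambda d: d["doc_id"]
assert sorted(original_docs, key=by_id) == sorted(retrieved_docs, key=by_id)
```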
        pass

    @abstractmethod
    def get_doc_id(self, data) -> str:
For now, I'd get rid of the abstract base class. It's not clear to me this is the right abstraction, but we won't know until we implement this for a few more databases. E.g. it's assuming that there are source fields, and that the document comes back as a dictionary.
Sure
    def get_required_source_fields(self) -> list[str]:
        return ["parent_id"]

    def get_doc_id(self, data) -> str:
data: dict
ok
    doc_reconstructor=doc_reconstructor,
    kwargs=query_kwargs,
)
if query is None:
Remove the default in OpenSearchReaderQueryParams. That change makes sure this code and the other code stay in sync.
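A sketch of why dropping the default keeps call sites in sync (the class and field names here are illustrative, not the actual OpenSearchReaderQueryParams definition): without a default, forgetting to pass the field is a hard error rather than a silent None.

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class QueryParams:
    index_name: str
    doc_reconstructor: Optional[Callable]  # no "= None": callers must choose

# Every construction site is now forced to state its intent explicitly.
params = QueryParams("my-index", doc_reconstructor=None)

try:
    QueryParams("my-index")  # omitting the field raises TypeError
    forgot_was_allowed = True
except TypeError:
    forgot_was_allowed = False
```

The type stays Optional, so passing None is still legal; only the *implicit* None disappears.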
ok
expected_count = len(original_docs)
actual_count = get_doc_count(os_client, TestOpenSearchRead.INDEX)
# refresh should have made all ingested docs immediately available for search
assert actual_count == expected_count, f"Expected {expected_count} documents, found {actual_count}"
FYI: Golang taught me to use got & want, which are shorter and still clear. Fine to keep as is.
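The got/want convention applied to the count check above (the values stand in for get_doc_count(...) and len(original_docs)):

```python
got, want = 5, 5
assert got == want, f"got {got} documents, want {want}"
```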
retrieved_materialized = sorted(retrieved_docs.take_all(), key=lambda d: d.doc_id)
query_materialized = query_docs.take_all()
retrieved_materialized_reconstructed = sorted(retrieved_docs_reconstructed.take_all(), key=lambda d: d.doc_id)

with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client:
    os_client.indices.delete(TestOpenSearchRead.INDEX)
# with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client:
remove obsolete line.
ok
lib/sycamore/sycamore/tests/integration/connectors/opensearch/test_opensearch_read.py
dicts = [
    {
        "doc_id": "1",
        "hidden": hidden,
# make sure we read from pickle files -- this part won't be written into opensearch.
added
        reconstruct_document=True,
        doc_reconstructor=OpenSearchDocumentReconstructor(TestOpenSearchRead.INDEX, doc_reconstructor),
    )
    .filter(lambda doc: doc.doc_id == "1")
Why not read them all, sort and compare? The filter seems unnecessary. You could drop all the docs except for 1 and the test would remain the same.
ok
for i in range(len(doc.elements) - 1):
    assert doc.elements[i].element_index < doc.elements[i + 1].element_index
assert len(retrieved_docs_reconstructed) == 1
assert retrieved_docs_reconstructed[0].data["hidden"] == hidden
Does assert retrieved_docs_reconstructed == original_docs not work? That's the condition that we want to be true.
Yes, I will add this check (original_docs is after .explode(), but you meant docs).
os_client.indices.delete(TestOpenSearchRead.INDEX)
os_client.indices.create(TestOpenSearchRead.INDEX, **TestOpenSearchRead.INDEX_SETTINGS)
os_client.indices.refresh(TestOpenSearchRead.INDEX)
# with OpenSearch(**TestOpenSearchRead.OS_CLIENT_ARGS) as os_client:
remove obsolete line.
done
    def __init__(self, index_name: str, reconstruct_fn: Callable[[str, str], Document]):
        self.index_name = index_name
        self.reconstruct_fn = reconstruct_fn

    def get_required_source_fields(self) -> list[str]:
        return ["parent_id"]

    def get_doc_id(self, data) -> str:
    def get_doc_id(self, data: dict) -> str:
        return data["_source"]["parent_id"] or data["_id"]

    def reconstruct(self, data) -> Document:
data: dict here as well.
ok
)
.filter(lambda doc: doc.doc_id == "1")
# .filter(lambda doc: doc.doc_id == "1")
remove obsolete line.
removed