Skip to content

Commit

Permalink
OpenSearch improvements (#974)
Browse files Browse the repository at this point in the history
  • Loading branch information
baitsguy authored Oct 24, 2024
1 parent 1a613aa commit 5d17274
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def _get_all_elements_for_doc_ids(self, doc_ids: list[str], index: str) -> list[
"""
Returns all records in OpenSearch belonging to a list of Document ids (element.parent_id)
"""
batch_size = 1000
page_size = 1000
batch_size = 100
page_size = 500

all_elements = []
for i in range(0, len(doc_ids), batch_size):
Expand Down
33 changes: 21 additions & 12 deletions lib/sycamore/sycamore/connectors/opensearch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,29 @@


@context_params("opensearch")
def get_knn_query(
    text_embedder: Embedder,
    query_phrase: str,
    k: Optional[int] = None,
    min_score: Optional[float] = None,
    context: Optional[Context] = None,
):
    """
    Given a query string and an Embedder, create a simple OpenSearch Knn query.

    Exactly one search mode is used:
      - 'k': retrieve the k approximate nearest neighbours, or
      - 'min_score': return all records whose score is at least min_score.
    When neither is supplied, the query falls back to k = 500.
    This is only the base query, if you need to add filters or other specifics you can extend the object.

    Args:
        text_embedder: Embedder whose generate_text_embedding() turns the phrase into the query vector.
        query_phrase: Text to embed and search for.
        k: Number of neighbours to retrieve. Mutually exclusive with min_score.
        min_score: Score threshold for the search. Mutually exclusive with k.
        context: Not referenced in the function body; presumably consumed by the
            context_params decorator -- confirm against its implementation.

    Returns:
        A dict of the form {"query": {"knn": {"embedding": {"vector": ..., "k" | "min_score": ...}}}}.

    Raises:
        ValueError: If both k and min_score are provided.
    """
    # Reject the ambiguous request before paying for an embedding call.
    if k is not None and min_score is not None:
        raise ValueError("Only one of `k` or `min_score` should be populated")

    embedding = text_embedder.generate_text_embedding(query_phrase)

    # Assemble the per-field knn parameters up front so the nested query dict
    # never has to be mutated (and type-ignored) after construction.
    knn_params: dict[str, object] = {"vector": embedding}
    if min_score is not None:
        knn_params["min_score"] = min_score
    else:
        # Neither argument given -> default to the 500 nearest neighbours.
        knn_params["k"] = 500 if k is None else k

    return {"query": {"knn": {"embedding": knn_params}}}
Original file line number Diff line number Diff line change
Expand Up @@ -431,3 +431,17 @@ def test_get_knn_query(self):

assert get_knn_query(query_phrase="test", k=1000, text_embedder=embedder) == expected_query
embedder.generate_text_embedding.assert_called_with("test")

# default
expected_query = {"query": {"knn": {"embedding": {"vector": embedding, "k": 500}}}}
assert get_knn_query(query_phrase="test", context=context) == expected_query
embedder.generate_text_embedding.assert_called_with("test")

# min_score
expected_query = {"query": {"knn": {"embedding": {"vector": embedding, "min_score": 0.5}}}}
assert get_knn_query(query_phrase="test", min_score=0.5, context=context) == expected_query
embedder.generate_text_embedding.assert_called_with("test")

def test_get_knn_query_validation(self):
    """Passing both `k` and `min_score` to get_knn_query must raise a ValueError."""
    embedder = Mock(spec=Embedder)
    expected_message = "Only one of `k` or `min_score` should be populated"

    with pytest.raises(ValueError, match=expected_message):
        get_knn_query(embedder, query_phrase="test", k=10, min_score=0.5)

0 comments on commit 5d17274

Please sign in to comment.