Skip to content
This repository has been archived by the owner on Jul 16, 2024. It is now read-only.

Support more index types besides ivfflat #224

Merged
merged 21 commits into from
Nov 23, 2023
Merged
38 changes: 24 additions & 14 deletions greenplumpython/experimental/embedding.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, cast
from typing import Any, Callable, Literal, Optional, cast
from uuid import uuid4

import greenplumpython as gp
Expand Down Expand Up @@ -75,7 +75,13 @@ class Embedding:
def __init__(self, dataframe: gp.DataFrame) -> None:
self._dataframe = dataframe

def create_index(self, column: str, model_name: str) -> gp.DataFrame:
def create_index(
self,
column: str,
model_name: str,
embedding_dimension: Optional[int] = None,
beeender marked this conversation as resolved.
Show resolved Hide resolved
method: Optional[Literal["ivfflat", "hnsw"]] = "hnsw",
) -> gp.DataFrame:
"""
Generate embeddings and create index for a column of unstructured data.

Expand All @@ -96,6 +102,8 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
Args:
column: name of column to create index on.
model_name: name of model to generate embedding.
embedding_dimension: dimension of the embedding.
method: name of the index access method (i.e. index type) in `pgvector <https://github.com/pgvector/pgvector>`_.

Returns:
Dataframe with target column indexed based on embeddings.
Expand All @@ -105,17 +113,17 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:

"""

import sentence_transformers # type: ignore reportMissingImports

model = sentence_transformers.SentenceTransformer(model_name) # type: ignore reportUnknownVariableType

assert self._dataframe.unique_key is not None, "Unique key is required to create index."
try:
word_embedding_dimension: int = model[1].word_embedding_dimension # From models.Pooling
except:
raise NotImplementedError(
"Model '{model_name}' doesn't provide embedding dimension information"
)
if embedding_dimension is None:
try:
import sentence_transformers # type: ignore reportMissingImports

model = sentence_transformers.SentenceTransformer(model_name) # type: ignore reportUnknownVariableType
embedding_dimension: int = model[1].word_embedding_dimension # From models.Pooling
except:
raise NotImplementedError(
"Model '{model_name}' doesn't provide embedding dimension information"
)
Comment on lines +119 to +125
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe import error here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think import error is for modules, not models.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make sense


embedding_col_name = "_emb_" + uuid4().hex
embedding_df_cols = list(self._dataframe.unique_key) + [embedding_col_name]
Expand All @@ -126,7 +134,7 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
Callable[[gp.DataFrame], TypeCast],
# FIXME: Modifier must be adapted to all types of model.
# Can this be done with transformers.AutoConfig?
lambda t: gp.type_("vector", modifier=word_embedding_dimension)(_generate_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType
lambda t: gp.type_("vector", modifier=embedding_dimension)(_generate_embedding(t[column], model_name)), # type: ignore reportUnknownLambdaType
)
},
)[embedding_df_cols]
Expand All @@ -136,8 +144,10 @@ def create_index(self, column: str, model_name: str) -> gp.DataFrame:
distribution_type="hash",
)
.check_unique(self._dataframe.unique_key)
.create_index(columns={embedding_col_name}, method="ivfflat")
)
if method is not None:
assert method == "hnsw" or method == "ivfflat"
embedding_df = embedding_df.create_index(columns={embedding_col_name}, method=method)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since there may be more method , assert method in ["ivfflat", "hnsw"] may be bettor?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! That's easier to read and write indeed. Changed.

assert self._dataframe._db is not None
_record_dependency._create_in_db(self._dataframe._db)
sql_add_relationship = f"""
Expand Down
Loading