Add kmeans clustering based on ray #1080

Open · wants to merge 1 commit into base: main
36 changes: 36 additions & 0 deletions lib/sycamore/sycamore/docset.py
@@ -15,6 +15,7 @@
)
from sycamore.plan_nodes import Node, Transform
from sycamore.transforms.augment_text import TextAugmentor
from sycamore.transforms.clustering import KMeans
from sycamore.transforms.embed import Embedder
from sycamore.transforms import DocumentStructure, Sort
from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor
@@ -915,6 +916,41 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
        mapping = Map(self.plan, f=f, **resource_args)
        return DocSet(self.context, mapping)

    def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilon: float = 1e-4):
        """
        Apply k-means clustering over the embedding field.

        Args:
            K: the number of centroids
            iterations: the maximum number of iterations to run if the centroids do not converge first

[Inline review comment]
Contributor: Is there a reasonable default for this? I at least wouldn't know what a good value to pick would be.

Collaborator (Author): Spark uses 20; we could follow the same, but it should really be a tuning process.

            init_mode: how the initial centroids are selected
            epsilon: the distance threshold used to decide convergence

        Returns:
            a list of at most K centroids
        """

        def init_embedding(row):
            doc = Document.from_row(row)
            return {"vector": doc.embedding, "cluster": -1}

        embeddings = self.plan.execute().map(init_embedding).materialize()

        initial_centroids = KMeans.init(embeddings, K, init_mode)
        centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon)
        return centroids

    def clustering(self, centroids, cluster_field_name, **resource_args) -> "DocSet":
        """
        Assign each document to its closest centroid and record the index in the given property field.
        """

        def cluster(doc: Document) -> Document:
            idx = KMeans.closest(doc.embedding, centroids)
            properties = doc.properties
            properties[cluster_field_name] = idx
            doc.properties = properties
            return doc

        from sycamore.transforms import Map

        mapping = Map(self.plan, f=cluster, **resource_args)
        return DocSet(self.context, mapping)

    def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet":
        """
        Applies the FlatMap transformation on the Docset.
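For reference, a minimal usage sketch of the two new DocSet methods, pieced together from the unit tests below; the toy data and parameter values are illustrative, and in practice the embeddings would come from an Embedder:

import numpy as np

import sycamore
from sycamore.data import Document

# Toy embeddings standing in for real embedding vectors.
points = np.random.uniform(0, 40, (20, 4))
docs = [Document(doc_id=i, text_representation=f"Document {i}", embedding=point) for i, point in enumerate(points)]

context = sycamore.init()
docset = context.read.document(docs)

# Step 1: learn at most K centroids from the embedding field.
centroids = docset.kmeans(K=3, iterations=20)

# Step 2: tag each document with the index of its closest centroid.
clustered = docset.clustering(centroids, "cluster").take_all()
cluster_ids = [doc.properties["cluster"] for doc in clustered]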
65 changes: 65 additions & 0 deletions lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
@@ -0,0 +1,65 @@
import numpy as np
import ray.data

import sycamore
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:

    def test_kmeans(self):
        points = np.random.uniform(0, 40, (20, 4))
        docs = [
            Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
            for i, point in enumerate(points)
        ]
        context = sycamore.init()
        docset = context.read.document(docs)
        centroids = docset.kmeans(3, 4)
        assert len(centroids) == 3

    def test_closest(self):
        row = [[0, 0, 0, 0]]
        centroids = [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [-1, -1, -1, -1],
        ]
        assert KMeans.closest(row, centroids) == 0

    def test_random(self):
        points = np.random.uniform(0, 40, (20, 4))
        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
        embeddings = ray.data.from_items(embeddings)
        centroids = KMeans.random_init(embeddings, 10)
        assert len(centroids) == 10

    def test_converged(self):
        last_ones = [[1.0, 1.0], [10.0, 10.0]]
        next_ones = [[2.0, 2.0], [12.0, 12.0]]
        assert KMeans.converged(last_ones, next_ones, 10).item() is True
        assert KMeans.converged(last_ones, next_ones, 1).item() is False

    def test_converge(self):
        points = np.random.uniform(0, 10, (20, 4))
        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
        embeddings = ray.data.from_items(embeddings)
        centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
        new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
        assert len(new_centroids) == 2

    def test_clustering(self):
        np.random.seed(2024)
        points = np.random.uniform(0, 40, (20, 4))
        docs = [
            Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
            for i, point in enumerate(points)
        ]
        context = sycamore.init()
        docset = context.read.document(docs)
        centroids = docset.kmeans(3, 4)

        clustered_docs = docset.clustering(centroids, "cluster").take_all()
        ids = [doc.properties["cluster"] for doc in clustered_docs]
        assert all(0 <= idx < 3 for idx in ids)
74 changes: 74 additions & 0 deletions lib/sycamore/sycamore/transforms/clustering.py
@@ -0,0 +1,74 @@
import random

import torch
from ray.data.aggregate import AggregateFn


class KMeans:

    @staticmethod
    def closest(row, centroids):
        row = torch.Tensor([row])
        centroids = torch.Tensor(centroids)
        distance = torch.cdist(row, centroids)
        idx = torch.argmin(distance)
        return idx

    @staticmethod
    def converged(last_ones, next_ones, epsilon):
        distance = torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones))
        return len(last_ones) == torch.sum(distance < epsilon)

    @staticmethod
    def random_init(embeddings, K):
        count = embeddings.count()
        assert count > 0 and K < count
        fraction = min(2 * K / count, 1.0)

        candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()]
        candidates.sort()
        from itertools import groupby

        uniques = [key for key, _ in groupby(candidates)]
        assert len(uniques) >= K

        centroids = random.sample(uniques, K)
        return centroids

    @staticmethod
    def init(embeddings, K, init_mode):
        if init_mode == "random":
            return KMeans.random_init(embeddings, K)
        else:
            raise Exception("Unknown init mode")

    @staticmethod
    def update(embeddings, centroids, iterations, epsilon):
        i = 0
        d = len(centroids[0])

        update_centroids = AggregateFn(
            init=lambda v: ([0] * d, 0),
            accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
            merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
            name="centroids",
        )

        while i < iterations:

            def _find_cluster(row):
                idx = KMeans.closest(row["vector"], centroids)
                return {"vector": row["vector"], "cluster": idx}

            aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
            import numpy as np

            new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

            if KMeans.converged(centroids, new_centroids, epsilon):
                return new_centroids
            else:
                i += 1
                centroids = new_centroids

        return centroids
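
For reference, a minimal sketch of exercising the KMeans helpers directly on a Ray dataset, mirroring the unit tests above; the data and parameter values are illustrative:

import numpy as np
import ray.data

from sycamore.transforms.clustering import KMeans

# Toy vectors wrapped in the {"vector": ..., "cluster": ...} rows the helpers expect.
points = np.random.uniform(0, 10, (100, 4))
embeddings = ray.data.from_items([{"vector": list(p), "cluster": -1} for p in points])

# init() samples distinct vectors as starting centroids; update() then repeatedly
# assigns every vector to its closest centroid, aggregates a per-cluster (sum, count),
# and replaces each centroid with the cluster mean until movement drops below epsilon
# or the iteration budget is exhausted.
centroids = KMeans.init(embeddings, K=4, init_mode="random")
centroids = KMeans.update(embeddings, centroids, iterations=20, epsilon=1e-4)

assignments = [int(KMeans.closest(list(p), centroids)) for p in points]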