Add k-means clustering based on Ray

This generally involves three steps:
1. materialize each document's embedding
2. initialize the centroids randomly
3. iterate the k-means process until convergence, using Ray
   Dataset map, groupby, and aggregate operators.

The resulting centroids can be used for downstream work.
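
As a quick usage sketch, the following mirrors the unit test added in this
commit (the 100-document fixture and the fixed placeholder embedding come
straight from the test):

    import sycamore
    from sycamore.data import Document

    docs = [
        Document(
            text_representation=f"Document {i}",
            doc_id=i,
            embedding=[1.1, 2.2, 3.3, 4.4, 5.5],
        )
        for i in range(100)
    ]

    ctx = sycamore.init()
    docset = ctx.read.document(docs)
    centroids = docset.kmeans(3, 4)  # K=3 clusters, at most 4 iterations
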
bohou-aryn committed Dec 18, 2024
1 parent 9ddcaef commit 3906f17
Showing 3 changed files with 130 additions and 0 deletions.
15 changes: 15 additions & 0 deletions lib/sycamore/sycamore/docset.py
@@ -6,6 +6,8 @@
from typing import Callable, Optional, Any, Iterable, Type, Union, TYPE_CHECKING
import re

from ray.data.aggregate import AggregateFn

from sycamore.context import Context, context_params, OperationTypes
from sycamore.data import Document, Element, MetadataDocument
from sycamore.functions.tokenizer import Tokenizer
@@ -16,6 +18,7 @@
)
from sycamore.plan_nodes import Node, Transform
from sycamore.transforms.augment_text import TextAugmentor
from sycamore.transforms.clustering import KMeans
from sycamore.transforms.embed import Embedder
from sycamore.transforms import DocumentStructure, Sort
from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor
@@ -903,6 +906,18 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
        mapping = Map(self.plan, f=f, **resource_args)
        return DocSet(self.context, mapping)

    def kmeans(self, K: int, iterations: int, epsilon: float = 1e-4) -> list:
        """Clusters the document embeddings into K groups and returns the final centroids."""

        # TODO: raise an exception if the documents have no embedding
        def init_embedding(row):
            doc = Document.from_row(row)
            return {"vector": doc.embedding, "cluster": -1}

        embeddings = self.plan.execute().map(init_embedding).materialize()
        initial_centroids = KMeans.init(embeddings, K)
        centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon)
        del embeddings
        return centroids

    def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet":
        """
        Applies the FlatMap transformation on the Docset.
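
The returned centroids are plain vectors, so one downstream use is tagging each
document with its nearest cluster. A minimal sketch reusing KMeans.closest from
this commit; assign_cluster is a hypothetical helper, not part of the change:

    from sycamore.transforms.clustering import KMeans

    def assign_cluster(doc, centroids):
        # Hypothetical helper: store the nearest-centroid index on the document.
        doc.properties["cluster"] = int(KMeans.closest(doc.embedding, centroids))
        return doc

    centroids = docset.kmeans(3, 4)
    clustered = docset.map(lambda doc: assign_cluster(doc, centroids))
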
56 changes: 56 additions & 0 deletions lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
@@ -0,0 +1,56 @@
import numpy as np
import pytest
import ray.data

import sycamore
from sycamore import DocSet
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:
    @pytest.fixture()
    def docs(self) -> list[Document]:
        print("Generating docs")
        return [
            Document(
                text_representation=f"Document {i}",
                doc_id=i,
                embedding=[1.1, 2.2, 3.3, 4.4, 5.5],
                properties={"document_number": i},
            )
            for i in range(100)
        ]

    @pytest.fixture()
    def docset(self, docs: list[Document]) -> DocSet:
        context = sycamore.init()
        return context.read.document(docs)

    def test_kmeans(self, docset: DocSet):
        centroids = docset.kmeans(3, 4)
        assert len(centroids) == 3

    def test_closest(self):
        row = [0, 0, 0, 0]
        centroids = [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [-1, -1, -1, -1],
        ]
        assert KMeans.closest(row, centroids) == 0

    def test_converged(self):
        last_ones = [[1.0, 1.0], [10.0, 10.0]]
        next_ones = [[2.0, 2.0], [12.0, 12.0]]
        assert KMeans.converged(last_ones, next_ones, 10)
        assert not KMeans.converged(last_ones, next_ones, 1)

    def test_converge(self):
        points = np.random.uniform(0, 10, (20, 4))
        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
        embeddings = ray.data.from_items(embeddings)
        centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
        new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
        assert len(new_centroids) == 2
59 changes: 59 additions & 0 deletions lib/sycamore/sycamore/transforms/clustering.py
@@ -0,0 +1,59 @@
import numpy as np
import torch
from ray.data.aggregate import AggregateFn


class KMeans:

    @staticmethod
    def closest(row, centroids):
        """Returns the index of the centroid nearest to the given vector."""
        row = torch.Tensor([row])
        centroids = torch.Tensor(centroids)
        distance = torch.cdist(row, centroids)
        idx = torch.argmin(distance)
        return int(idx)

    @staticmethod
    def converged(last_ones, next_ones, epsilon):
        # TODO: accumulate the cost as well
        # Converged when every centroid has moved less than epsilon from its
        # previous position, i.e. the diagonal of the pairwise distance matrix.
        distance = torch.diagonal(torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones)))
        return bool(torch.all(distance < epsilon))

    @staticmethod
    def init(embeddings, K):
        # TODO:
        # 1. fix the random sampling to guarantee K distinct samples
        # 2. use k-means|| as the initialization
        sampled = embeddings.take(K)
        centroids = [s["vector"] for s in sampled]
        return centroids

    @staticmethod
    def update(embeddings, centroids, iterations, epsilon):
        i = 0
        d = len(centroids[0])

        # Per-cluster aggregation: accumulate a (vector sum, count) pair so the
        # new centroid is the mean of all vectors assigned to that cluster.
        update_centroids = AggregateFn(
            init=lambda v: ([0] * d, 0),
            accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
            merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
            name="centroids",
        )

        while i < iterations:

            def _find_cluster(row):
                idx = KMeans.closest(row["vector"], centroids)
                return {"vector": row["vector"], "cluster": idx}

            aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
            new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

            if KMeans.converged(centroids, new_centroids, epsilon):
                return new_centroids
            else:
                i += 1
                centroids = new_centroids

        return centroids
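
For intuition, here is a NumPy-only sketch of the single update step that the
Ray map/groupby/aggregate pipeline computes in distributed form. The names are
illustrative; note that this sketch keeps the previous centroid for an empty
cluster, whereas the Ray version simply omits empty groups from its output:

    import numpy as np

    def kmeans_step(vectors: np.ndarray, centroids: np.ndarray) -> np.ndarray:
        # Assign each vector to its nearest centroid (the map + closest step).
        distances = np.linalg.norm(vectors[:, None, :] - centroids[None, :, :], axis=2)
        assignments = distances.argmin(axis=1)
        # Average the members of each cluster (the groupby + aggregate step).
        return np.array([
            vectors[assignments == k].mean(axis=0) if np.any(assignments == k) else centroids[k]
            for k in range(len(centroids))
        ])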
