Skip to content

Commit

Permalink
Add kmeans clustering based on ray
Browse files Browse the repository at this point in the history
This generally involves three steps:
1. materialize each document's embedding
2. initialize the centroids randomly
3. iterate the k-means process until it converges; this is based on Ray
   Dataset map, groupby, and aggregate operators.

The resulting centroids can be used for downstream work.
  • Loading branch information
bohou-aryn committed Dec 19, 2024
1 parent 9ddcaef commit 006b4fd
Show file tree
Hide file tree
Showing 3 changed files with 151 additions and 0 deletions.
25 changes: 25 additions & 0 deletions lib/sycamore/sycamore/docset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from typing import Callable, Optional, Any, Iterable, Type, Union, TYPE_CHECKING
import re

from ray.data.aggregate import AggregateFn

from sycamore.context import Context, context_params, OperationTypes
from sycamore.data import Document, Element, MetadataDocument
from sycamore.functions.tokenizer import Tokenizer
Expand All @@ -16,6 +18,7 @@
)
from sycamore.plan_nodes import Node, Transform
from sycamore.transforms.augment_text import TextAugmentor
from sycamore.transforms.clustering import KMeans
from sycamore.transforms.embed import Embedder
from sycamore.transforms import DocumentStructure, Sort
from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor
Expand Down Expand Up @@ -903,6 +906,28 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
mapping = Map(self.plan, f=f, **resource_args)
return DocSet(self.context, mapping)

def kmeans(self, K: int, iterations: int, init_mode: str = "random", epsilon: float = 1e-4) -> list[list[float]]:
    """
    Run k-means clustering over the embedding field of the documents in this DocSet.

    The plan is executed, every document's embedding is materialized into a
    dataset of ``{"vector": ..., "cluster": -1}`` rows, initial centroids are
    chosen with ``KMeans.init`` and then refined with ``KMeans.update``.

    Args:
        K: The number of centroids to initialize.
        iterations: The maximum number of update rounds to run when convergence
            is not reached earlier.
        init_mode: How the initial centroids are selected; passed to ``KMeans.init``.
        epsilon: The threshold used to decide whether the centroids have converged;
            passed to ``KMeans.update``.

    Returns:
        A list of at most K centroids.
    """

    def init_embedding(row):
        doc = Document.from_row(row)
        # Metadata rows and documents that were never embedded have no vector to
        # cluster; emitting them would only fail later inside the distance
        # computation, so they are dropped here.
        if isinstance(doc, MetadataDocument) or doc.embedding is None:
            return []
        return [{"vector": doc.embedding, "cluster": -1}]

    # Materialize once: the dataset is re-read on every k-means iteration.
    embeddings = self.plan.execute().flat_map(init_embedding).materialize()

    initial_centroids = KMeans.init(embeddings, K, init_mode)
    return KMeans.update(embeddings, initial_centroids, iterations, epsilon)

def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet":
"""
Applies the FlatMap transformation on the Docset.
Expand Down
52 changes: 52 additions & 0 deletions lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import numpy as np
import pytest
import ray.data

import sycamore
from sycamore import DocSet
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:
    """Unit tests for the KMeans helpers and the DocSet.kmeans entry point."""

    def test_kmeans(self):
        # Seeded generator so the input points are reproducible between runs.
        points = np.random.default_rng(0).uniform(0, 40, (20, 4))
        docs = [
            Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i})
            for i, point in enumerate(points)
        ]
        context = sycamore.init()
        docset = context.read.document(docs)
        centroids = docset.kmeans(3, 4)
        assert len(centroids) == 3

    def test_closest(self):
        # A flat vector, which is how KMeans.update passes row["vector"].
        row = [0, 0, 0, 0]
        # The first centroid is strictly the nearest one, so the expected index
        # does not depend on how ties are broken.
        centroids = [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [-3, -3, -3, -3],
        ]
        assert KMeans.closest(row, centroids) == 0

    def test_random(self):
        points = np.random.default_rng(1).uniform(0, 40, (20, 4))
        embeddings = ray.data.from_items([{"vector": list(point), "cluster": -1} for point in points])
        centroids = KMeans.random_init(embeddings, 10)
        assert len(centroids) == 10

    def test_converged(self):
        last_ones = [[1.0, 1.0], [10.0, 10.0]]
        next_ones = [[2.0, 2.0], [12.0, 12.0]]
        # Each centroid moved by less than 10 but by more than 1.
        assert KMeans.converged(last_ones, next_ones, 10)
        assert not KMeans.converged(last_ones, next_ones, 1)

    def test_converge(self):
        points = np.random.default_rng(2).uniform(0, 10, (20, 4))
        embeddings = ray.data.from_items([{"vector": list(point), "cluster": -1} for point in points])
        centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
        new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
        assert len(new_centroids) == 2
74 changes: 74 additions & 0 deletions lib/sycamore/sycamore/transforms/clustering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import random

import torch
from ray.data.aggregate import AggregateFn


class KMeans:
    """
    K-means clustering helpers that operate on a Ray Dataset of embedding rows.

    Each dataset row is a dict with a ``"vector"`` entry (the embedding) and a
    ``"cluster"`` entry (the index of the assigned centroid).
    """

    @staticmethod
    def closest(row, centroids):
        """
        Return the index of the centroid nearest to ``row`` by Euclidean distance.

        Args:
            row: A single embedding vector. A vector wrapped in an extra list
                is accepted as well, since it is flattened to one point.
            centroids: A sequence of centroid vectors with the same dimension as ``row``.

        Returns:
            The position in ``centroids`` of the nearest centroid, as a plain ``int``.
        """
        point = torch.as_tensor(row, dtype=torch.float32).reshape(1, -1)
        centers = torch.as_tensor(centroids, dtype=torch.float32)
        distances = torch.cdist(point, centers)
        # A plain int (not a 0-dim tensor) so the value can be used directly as
        # a groupby key and compared or serialized without torch.
        return int(torch.argmin(distances))

    @staticmethod
    def converged(last_ones, next_ones, epsilon):
        """
        Tell whether every centroid moved by less than ``epsilon``.

        Centroids are compared position by position: ``last_ones[i]`` against
        ``next_ones[i]``.

        Args:
            last_ones: The centroids from the previous iteration.
            next_ones: The centroids from the current iteration.
            epsilon: The maximum shift (exclusive) for a centroid to count as stable.

        Returns:
            ``True`` when both lists have the same length and each centroid
            shifted by less than ``epsilon``; ``False`` otherwise.
        """
        # A change in the number of centroids (e.g. a cluster lost all of its
        # points) means the clustering is still changing.
        if len(last_ones) != len(next_ones):
            return False
        if len(last_ones) == 0:
            return True

        previous = torch.as_tensor(last_ones, dtype=torch.float32)
        current = torch.as_tensor(next_ones, dtype=torch.float32)
        # Per-centroid shift only; distances between *different* centroids say
        # nothing about convergence and must not be counted.
        shifts = torch.linalg.norm(current - previous, dim=1)
        return bool(torch.all(shifts < epsilon))

    @staticmethod
    def random_init(embeddings, K):
        """
        Pick ``K`` distinct embeddings at random to serve as initial centroids.

        Args:
            embeddings: The dataset of embedding rows.
            K: The number of centroids to pick; must satisfy ``1 <= K < count``.

        Returns:
            A list of ``K`` distinct vectors, each a list.

        Raises:
            ValueError: If the dataset is empty, ``K`` is out of range, or the
                dataset holds fewer than ``K`` distinct vectors.
        """
        count = embeddings.count()
        if count <= 0 or K < 1 or K >= count:
            raise ValueError(f"K must satisfy 1 <= K < number of embeddings; got K={K}, count={count}")

        # Sample roughly 2 * K rows. Sampling is probabilistic, so the draw may
        # come back short; widen the fraction and retry until there are enough
        # distinct candidates or the whole dataset has been considered.
        fraction = min(2 * K / count, 1.0)
        while True:
            # take_all(): take() would silently stop at its default row limit.
            sampled = embeddings.random_sample(fraction).take_all()
            uniques = sorted({tuple(c["vector"]) for c in sampled})
            if len(uniques) >= K or fraction >= 1.0:
                break
            fraction = min(2 * fraction, 1.0)

        if len(uniques) < K:
            raise ValueError(f"Only {len(uniques)} distinct embeddings available; cannot pick {K} centroids")

        return [list(vector) for vector in random.sample(uniques, K)]

    @staticmethod
    def init(embeddings, K, init_mode):
        """
        Choose the initial centroids according to ``init_mode``.

        Args:
            embeddings: The dataset of embedding rows.
            K: The number of centroids to pick.
            init_mode: The initialization strategy; only ``"random"`` is supported.

        Returns:
            The list of initial centroids.

        Raises:
            ValueError: If ``init_mode`` is not a supported strategy.
        """
        if init_mode == "random":
            return KMeans.random_init(embeddings, K)
        raise ValueError("Unknown init mode")

    @staticmethod
    def update(embeddings, centroids, iterations, epsilon):
        """
        Refine ``centroids`` by repeated assign-and-average rounds.

        Each round assigns every embedding to its nearest centroid, then
        replaces each centroid by the mean of the embeddings assigned to it.
        Clusters that receive no embeddings drop out of the result.

        Args:
            embeddings: The dataset of embedding rows.
            centroids: The starting centroids.
            iterations: The maximum number of rounds to run.
            epsilon: The convergence threshold passed to ``KMeans.converged``.

        Returns:
            The centroids of the first converged round, or the centroids after
            ``iterations`` rounds when convergence was not reached.
        """
        import numpy as np

        if len(centroids) == 0:
            return centroids

        dim = len(centroids[0])

        # Per-cluster accumulator: (element-wise sum of vectors, row count).
        update_centroids = AggregateFn(
            init=lambda key: ([0] * dim, 0),
            accumulate_row=lambda acc, row: ([x + y for x, y in zip(acc[0], row["vector"])], acc[1] + 1),
            merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
            name="centroids",
        )

        for _ in range(iterations):

            # The closure reads the current ``centroids``; the pipeline below is
            # executed before ``centroids`` is rebound, so the binding is stable.
            def _find_cluster(row):
                return {"vector": row["vector"], "cluster": KMeans.closest(row["vector"], centroids)}

            # take_all(): take() would truncate the result at its default row
            # limit and silently drop centroids for large K.
            aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take_all()
            # Keep positions aligned with cluster ids for the convergence check.
            aggregated = sorted(aggregated, key=lambda c: c["cluster"])

            new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

            if KMeans.converged(centroids, new_centroids, epsilon):
                return new_centroids
            centroids = new_centroids

        return centroids

0 comments on commit 006b4fd

Please sign in to comment.