From 9c5c509b4df5708ea0c4dfe5b23cec7f53049983 Mon Sep 17 00:00:00 2001 From: Bohou Li Date: Tue, 17 Dec 2024 14:56:13 -0800 Subject: [PATCH] Add kmeans clustering based on ray This includes generally three steps: 1. materialize a document's embedding 2. initialize centroids randomly 2. iterate the kmeans process until converge, this is based on ray dataset map group and aggregate operators. The result centroids could be used for downstream work. --- lib/sycamore/sycamore/docset.py | 35 +++++++++ .../tests/unit/transforms/test_clustering.py | 65 ++++++++++++++++ .../sycamore/transforms/clustering.py | 74 +++++++++++++++++++ 3 files changed, 174 insertions(+) create mode 100644 lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py create mode 100644 lib/sycamore/sycamore/transforms/clustering.py diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py index a2480854c..5340cc85e 100644 --- a/lib/sycamore/sycamore/docset.py +++ b/lib/sycamore/sycamore/docset.py @@ -15,6 +15,7 @@ ) from sycamore.plan_nodes import Node, Transform from sycamore.transforms.augment_text import TextAugmentor +from sycamore.transforms.clustering import KMeans from sycamore.transforms.embed import Embedder from sycamore.transforms import DocumentStructure, Sort from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor @@ -915,6 +916,40 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet": mapping = Map(self.plan, f=f, **resource_args) return DocSet(self.context, mapping) + def kmeans(self, K: int, iterations: int = 20, init_mode: str = "random", epsilon: float = 1e-4): + """ + Apply kmeans over embedding field + + Args: + K: the count of centroids + iterations: the max iteration runs before converge + init_mode: how the initial centroids are select + epsilon: the condition for determining if it's converged + Return a list of max K centroids + """ + + def init_embedding(row): + doc = Document.from_row(row) + return {"vector": doc.embedding, "cluster": -1} + + embeddings = self.plan.execute().map(init_embedding).materialize() + + initial_centroids = KMeans.init(embeddings, K, init_mode) + centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon) + return centroids + + def clustering(self, centroids, cluster_field_name, **resource_args) -> "DocSet": + def cluster(doc: Document) -> Document: + idx = KMeans.closest(doc.embedding, centroids) + properties = doc.properties + properties[cluster_field_name] = idx + doc.properties = properties + return doc + + from sycamore.transforms import Map + mapping = Map(self.plan, f=cluster, **resource_args) + return DocSet(self.context, mapping) + def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet": """ Applies the FlatMap transformation on the Docset. diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py b/lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py new file mode 100644 index 000000000..4305dd0de --- /dev/null +++ b/lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py @@ -0,0 +1,65 @@ +import numpy as np +import ray.data + +import sycamore +from sycamore.data import Document +from sycamore.transforms.clustering import KMeans + + +class TestKMeans: + + def test_kmeans(self): + points = np.random.uniform(0, 40, (20, 4)) + docs = [ + Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i}) + for i, point in enumerate(points) + ] + context = sycamore.init() + docset = context.read.document(docs) + centroids = docset.kmeans(3, 4) + assert len(centroids) == 3 + + def test_closest(self): + row = [[0, 0, 0, 0]] + centroids = [ + [1, 1, 1, 1], + [2, 2, 2, 2], + [-1, -1, -1, -1], + ] + assert KMeans.closest(row, centroids) == 0 + + def test_random(self): + points = np.random.uniform(0, 40, (20, 4)) + embeddings = [{"vector": list(point), "cluster": -1} for point in points] + embeddings = ray.data.from_items(embeddings) + centroids = KMeans.random_init(embeddings, 10) + assert len(centroids) == 10 + + def test_converged(self): + last_ones = [[1.0, 1.0], [10.0, 10.0]] + next_ones = [[2.0, 2.0], [12.0, 12.0]] + assert KMeans.converged(last_ones, next_ones, 10).item() is True + assert KMeans.converged(last_ones, next_ones, 1).item() is False + + def test_converge(self): + points = np.random.uniform(0, 10, (20, 4)) + embeddings = [{"vector": list(point), "cluster": -1} for point in points] + embeddings = ray.data.from_items(embeddings) + centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]] + new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4) + assert len(new_centroids) == 2 + + def test_clustering(self): + np.random.seed(2024) + points = np.random.uniform(0, 40, (20, 4)) + docs = [ + Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i}) + for i, point in enumerate(points) + ] + context = sycamore.init() + docset = context.read.document(docs) + centroids = docset.kmeans(3, 4) + + clustered_docs = docset.clustering(centroids, "cluster").take_all() + ids = [doc.properties["cluster"] for doc in clustered_docs] + assert all(0 <= idx < 3 for idx in ids) diff --git a/lib/sycamore/sycamore/transforms/clustering.py b/lib/sycamore/sycamore/transforms/clustering.py new file mode 100644 index 000000000..7da63fcca --- /dev/null +++ b/lib/sycamore/sycamore/transforms/clustering.py @@ -0,0 +1,74 @@ +import random + +import torch +from ray.data.aggregate import AggregateFn + + +class KMeans: + + @staticmethod + def closest(row, centroids): + row = torch.Tensor([row]) + centroids = torch.Tensor(centroids) + distance = torch.cdist(row, centroids) + idx = torch.argmin(distance) + return idx + + @staticmethod + def converged(last_ones, next_ones, epsilon): + distance = torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones)) + return len(last_ones) == torch.sum(distance < epsilon) + + @staticmethod + def random_init(embeddings, K): + count = embeddings.count() + assert count > 0 and K < count + fraction = min(2 * K / count, 1.0) + + candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()] + candidates.sort() + from itertools import groupby + + uniques = [key for key, _ in groupby(candidates)] + assert len(uniques) >= K + + centroids = random.sample(uniques, K) + return centroids + + @staticmethod + def init(embeddings, K, init_mode): + if init_mode == "random": + return KMeans.random_init(embeddings, K) + else: + raise Exception("Unknown init mode") + + @staticmethod + def update(embeddings, centroids, iterations, epsilon): + i = 0 + d = len(centroids[0]) + + update_centroids = AggregateFn( + init=lambda v: ([0] * d, 0), + accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1), + merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]), + name="centroids", + ) + + while i < iterations: + + def _find_cluster(row): + idx = KMeans.closest(row["vector"], centroids) + return {"vector": row["vector"], "cluster": idx} + + aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take() + import numpy as np + + new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated] + + if KMeans.converged(centroids, new_centroids, epsilon): + return new_centroids + else: + i += 1 + centroids = new_centroids + + return centroids