From 3906f1758196a0ffa9b474062983fe7b6691e42e Mon Sep 17 00:00:00 2001
From: Bohou Li
Date: Tue, 17 Dec 2024 14:56:13 -0800
Subject: [PATCH] Add kmeans clustering based on Ray

This generally includes three steps:
1. materialize each document's embedding
2. initialize the centroids randomly
3. iterate the k-means update until convergence, built on the Ray
   Dataset map, groupby, and aggregate operators

The resulting centroids can be used for downstream work.
---
 lib/sycamore/sycamore/docset.py               | 13 ++++
 .../tests/unit/transforms/test_clustering.py  | 56 +++++++++++++++++
 .../sycamore/transforms/clustering.py         | 59 ++++++++++++++++++
 3 files changed, 128 insertions(+)
 create mode 100644 lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
 create mode 100644 lib/sycamore/sycamore/transforms/clustering.py

diff --git a/lib/sycamore/sycamore/docset.py b/lib/sycamore/sycamore/docset.py
index 581a90693..e16a152b3 100644
--- a/lib/sycamore/sycamore/docset.py
+++ b/lib/sycamore/sycamore/docset.py
@@ -16,6 +16,7 @@
 )
 from sycamore.plan_nodes import Node, Transform
 from sycamore.transforms.augment_text import TextAugmentor
+from sycamore.transforms.clustering import KMeans
 from sycamore.transforms.embed import Embedder
 from sycamore.transforms import DocumentStructure, Sort
 from sycamore.transforms.extract_entity import EntityExtractor, OpenAIEntityExtractor
@@ -903,6 +904,18 @@ def map(self, f: Callable[[Document], Document], **resource_args) -> "DocSet":
         mapping = Map(self.plan, f=f, **resource_args)
         return DocSet(self.context, mapping)
 
+    def kmeans(self, K: int, iterations: int, epsilon: float = 1e-4):
+        # TODO: raise an exception if there is no embedding column
+        def init_embedding(row):
+            doc = Document.from_row(row)
+            return {"vector": doc.embedding, "cluster": -1}
+
+        embeddings = self.plan.execute().map(init_embedding).materialize()
+        initial_centroids = KMeans.init(embeddings, K)
+        centroids = KMeans.update(embeddings, initial_centroids, iterations, epsilon)
+        del embeddings
+        return centroids
+
     def flat_map(self, f: Callable[[Document], list[Document]], **resource_args) -> "DocSet":
         """
         Applies the FlatMap transformation on the Docset.
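
A minimal usage sketch of the new DocSet.kmeans API (mirroring the unit
test below; the Document fields are illustrative):

    import sycamore
    from sycamore.data import Document

    docs = [
        Document(
            doc_id=i,
            text_representation=f"Document {i}",
            embedding=[1.1 * i, 2.2 * i, 3.3 * i, 4.4 * i, 5.5 * i],
        )
        for i in range(100)
    ]
    docset = sycamore.init().read.document(docs)

    # Cluster the embeddings into K=3 groups, running at most 4 iterations.
    centroids = docset.kmeans(3, 4)
    assert len(centroids) == 3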
diff --git a/lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py b/lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
new file mode 100644
index 000000000..b03dd32c2
--- /dev/null
+++ b/lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
@@ -0,0 +1,56 @@
+import pytest
+import ray.data
+
+import sycamore
+from sycamore import DocSet
+from sycamore.data import Document
+from sycamore.transforms.clustering import KMeans
+
+
+class TestKMeans:
+    @pytest.fixture()
+    def docs(self) -> list[Document]:
+        print("Generating docs")
+        return [
+            Document(
+                text_representation=f"Document {i}",
+                doc_id=i,
+                embedding=[1.1 * i, 2.2 * i, 3.3 * i, 4.4 * i, 5.5 * i],
+                properties={"document_number": i},
+            )
+            for i in range(100)
+        ]
+
+    @pytest.fixture()
+    def docset(self, docs: list[Document]) -> DocSet:
+        context = sycamore.init()
+        return context.read.document(docs)
+
+    def test_kmeans(self, docset: DocSet):
+        centroids = docset.kmeans(3, 4)
+        assert len(centroids) == 3
+
+    def test_closest(self):
+        row = [0, 0, 0, 0]
+        centroids = [
+            [1, 1, 1, 1],
+            [2, 2, 2, 2],
+            [-1, -1, -1, -1],
+        ]
+        assert KMeans.closest(row, centroids) == 0
+
+    def test_converged(self):
+        last_ones = [[1.0, 1.0], [10.0, 10.0]]
+        next_ones = [[2.0, 2.0], [12.0, 12.0]]
+        assert KMeans.converged(last_ones, next_ones, 10)
+        assert not KMeans.converged(last_ones, next_ones, 1)
+
+    def test_converge(self):
+        import numpy as np
+
+        points = np.random.uniform(0, 10, (20, 4))
+        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
+        embeddings = ray.data.from_items(embeddings)
+        centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
+        new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
+        assert len(new_centroids) == 2

diff --git a/lib/sycamore/sycamore/transforms/clustering.py b/lib/sycamore/sycamore/transforms/clustering.py
new file mode 100644
index 000000000..6c6a8c31d
--- /dev/null
+++ b/lib/sycamore/sycamore/transforms/clustering.py
@@ -0,0 +1,59 @@
+import numpy as np
+import torch
+from ray.data.aggregate import AggregateFn
+
+
+class KMeans:
+
+    @staticmethod
+    def closest(row, centroids):
+        row = torch.Tensor([row])
+        centroids = torch.Tensor(centroids)
+        distance = torch.cdist(row, centroids)
+        # Return a plain int so Ray can group rows by cluster id.
+        return int(torch.argmin(distance))
+
+    @staticmethod
+    def converged(last_ones, next_ones, epsilon):
+        # TODO: also accumulate the cost; converged when every centroid moves < epsilon
+        distance = torch.diagonal(torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones)))
+        return bool(torch.all(distance < epsilon))
+
+    @staticmethod
+    def init(embeddings, K):
+        # TODO:
+        # 1. make this properly random and guarantee K distinct samples
+        # 2. use k-means|| for initialization
+        sampled = embeddings.take(K)
+        centroids = [s["vector"] for s in sampled]
+        return centroids
+
+    @staticmethod
+    def update(embeddings, centroids, iterations, epsilon):
+        i = 0
+        d = len(centroids[0])
+
+        update_centroids = AggregateFn(
+            init=lambda v: ([0] * d, 0),  # per-cluster running (vector sum, count)
+            accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
+            merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
+            name="centroids",
+        )
+
+        while i < iterations:
+
+            def _find_cluster(row):
+                idx = KMeans.closest(row["vector"], centroids)
+                return {"vector": row["vector"], "cluster": idx}
+
+            aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
+            # The new centroid of each cluster is its mean: vector sum / count.
+            new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]
+
+            if KMeans.converged(centroids, new_centroids, epsilon):
+                return new_centroids
+            else:
+                i += 1
+                centroids = new_centroids
+
+        return centroids
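
For reference, a minimal standalone sketch of the arithmetic performed by
the AggregateFn and the convergence check above: each cluster accumulates
a (vector sum, count) pair, the new centroid is sum / count, and the loop
stops once every centroid moves less than epsilon. The numbers here are
made up for illustration:

    import numpy as np
    import torch

    # Two rows assigned to the same cluster; the accumulator starts at ([0, 0], 0).
    state = ([0.0, 0.0], 0)
    for vector in ([1.0, 3.0], [3.0, 5.0]):
        state = ([x + y for x, y in zip(state[0], vector)], state[1] + 1)

    # New centroid = vector sum / count = [2.0, 4.0].
    new_centroid = list(np.array(state[0]) / state[1])

    # Convergence: distance between each old centroid and its update.
    old, new = torch.Tensor([[2.0, 4.0]]), torch.Tensor([new_centroid])
    moved = torch.diagonal(torch.cdist(old, new))
    assert bool(torch.all(moved < 1e-4))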