This generally includes three steps: 1. materialize each document's embedding; 2. initialize centroids randomly; 3. iterate the k-means process until convergence, built on the Ray Dataset map, groupby, and aggregate operators. The resulting centroids can be used for downstream work.
Commit 3906f17 (parent: 9ddcaef)
Showing 3 changed files with 130 additions and 0 deletions.
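For orientation, the three steps read end to end like this; a minimal sketch against the helpers added below, with toy vectors and K=2 chosen purely for illustration:

import ray.data

from sycamore.transforms.clustering import KMeans

# Step 1: materialize embeddings as {"vector": ...} rows (toy 2-d vectors here).
embeddings = ray.data.from_items([{"vector": [float(i), float(i)], "cluster": -1} for i in range(10)])

# Step 2: pick K initial centroids from the dataset.
centroids = KMeans.init(embeddings, 2)

# Step 3: alternate assignment and re-centering until the centroids stop moving.
centroids = KMeans.update(embeddings, centroids, 10, 1e-4)
print(centroids)  # two centroids, one per cluster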
56 changes: 56 additions & 0 deletions
lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
@@ -0,0 +1,56 @@
import pytest
import ray.data

import sycamore
from sycamore import DocSet
from sycamore.data import Document
from sycamore.transforms.clustering import KMeans


class TestKMeans:
    @pytest.fixture()
    def docs(self) -> list[Document]:
        print("Generating docs")
        return [
            Document(
                text_representation=f"Document {i}",
                doc_id=i,
                embedding=[1.1, 2.2, 3.3, 4.4, 5.5],
                properties={"document_number": i},
            )
            for i in range(100)
        ]

    @pytest.fixture()
    def docset(self, docs: list[Document]) -> DocSet:
        context = sycamore.init()
        return context.read.document(docs)

    def test_kmeans(self, docset: DocSet):
        centroids = docset.kmeans(3, 4)
        assert len(centroids) == 3

    def test_closest(self):
        row = [0, 0, 0, 0]
        centroids = [
            [1, 1, 1, 1],
            [2, 2, 2, 2],
            [-1, -1, -1, -1],
        ]
        # The row is equidistant from the first and third centroids;
        # argmin breaks the tie by returning the first (index 0).
        assert KMeans.closest(row, centroids) == 0

    def test_converged(self):
        last_ones = [[1.0, 1.0], [10.0, 10.0]]
        next_ones = [[2.0, 2.0], [12.0, 12.0]]
        assert KMeans.converged(last_ones, next_ones, 10)
        assert not KMeans.converged(last_ones, next_ones, 1)

    def test_converge(self):
        import numpy as np

        points = np.random.uniform(0, 10, (20, 4))
        embeddings = [{"vector": list(point), "cluster": -1} for point in points]
        embeddings = ray.data.from_items(embeddings)
        centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]
        new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4)
        assert len(new_centroids) == 2
59 changes: 59 additions & 0 deletions
lib/sycamore/sycamore/transforms/clustering.py
@@ -0,0 +1,59 @@
import numpy as np
import torch
from ray.data.aggregate import AggregateFn


class KMeans:

    @staticmethod
    def closest(row, centroids):
        # Index of the centroid nearest to the row vector (Euclidean distance).
        row = torch.Tensor([row])
        centroids = torch.Tensor(centroids)
        distance = torch.cdist(row, centroids)
        idx = torch.argmin(distance)
        return int(idx)

    @staticmethod
    def converged(last_ones, next_ones, epsilon):
        # TODO: accumulate the cost as well
        # Converged when every centroid moved less than epsilon since the
        # last iteration (element-wise, not all-pairs distances).
        distance = torch.norm(torch.Tensor(last_ones) - torch.Tensor(next_ones), dim=1)
        return bool(torch.all(distance < epsilon))

    @staticmethod
    def init(embeddings, K):
        # TODO:
        # 1. fix this random, guarantee K different samples
        # 2. take the k-means|| as initialization
        sampled = embeddings.take(K)
        centroids = [s["vector"] for s in sampled]
        return centroids

    @staticmethod
    def update(embeddings, centroids, iterations, epsilon):
        i = 0
        d = len(centroids[0])

        # Per-cluster aggregation: keep a running (vector sum, count) pair so
        # the new centroid is the element-wise mean of the assigned vectors.
        update_centroids = AggregateFn(
            init=lambda v: ([0] * d, 0),
            accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1),
            merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]),
            name="centroids",
        )

        while i < iterations:

            def _find_cluster(row):
                idx = KMeans.closest(row["vector"], centroids)
                return {"vector": row["vector"], "cluster": idx}

            # Assign every vector to its nearest centroid, then reduce each
            # cluster to a (sum, count) pair and take the mean.
            aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take()
            new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated]

            if KMeans.converged(centroids, new_centroids, epsilon):
                return new_centroids
            else:
                i += 1
                centroids = new_centroids

        return centroids
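For downstream work, one simple use of the returned centroids is assigning a cluster id to each embedding with KMeans.closest. A minimal sketch, with illustrative vectors that are not from this commit:

from sycamore.transforms.clustering import KMeans

centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]]  # e.g. the output of KMeans.update
vectors = [[1.0, 1.0, 1.0, 1.0], [9.0, 9.0, 9.0, 9.0]]

# Each vector gets the id of its nearest centroid: here 0, then 1.
assignments = [KMeans.closest(v, centroids) for v in vectors]
print(assignments)  # [0, 1]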