-
Notifications
You must be signed in to change notification settings - Fork 49
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This includes generally three steps: 1. materialize a document's embedding 2. initialize centroids randomly 2. iterate the kmeans process until converge, this is based on ray dataset map group and aggregate operators. The result centroids could be used for downstream work.
- Loading branch information
1 parent
a3ac64b
commit 9c5c509
Showing
3 changed files
with
174 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
65 changes: 65 additions & 0 deletions
65
lib/sycamore/sycamore/tests/unit/transforms/test_clustering.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import numpy as np | ||
import ray.data | ||
|
||
import sycamore | ||
from sycamore.data import Document | ||
from sycamore.transforms.clustering import KMeans | ||
|
||
|
||
class TestKMeans: | ||
|
||
def test_kmeans(self): | ||
points = np.random.uniform(0, 40, (20, 4)) | ||
docs = [ | ||
Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i}) | ||
for i, point in enumerate(points) | ||
] | ||
context = sycamore.init() | ||
docset = context.read.document(docs) | ||
centroids = docset.kmeans(3, 4) | ||
assert len(centroids) == 3 | ||
|
||
def test_closest(self): | ||
row = [[0, 0, 0, 0]] | ||
centroids = [ | ||
[1, 1, 1, 1], | ||
[2, 2, 2, 2], | ||
[-1, -1, -1, -1], | ||
] | ||
assert KMeans.closest(row, centroids) == 0 | ||
|
||
def test_random(self): | ||
points = np.random.uniform(0, 40, (20, 4)) | ||
embeddings = [{"vector": list(point), "cluster": -1} for point in points] | ||
embeddings = ray.data.from_items(embeddings) | ||
centroids = KMeans.random_init(embeddings, 10) | ||
assert len(centroids) == 10 | ||
|
||
def test_converged(self): | ||
last_ones = [[1.0, 1.0], [10.0, 10.0]] | ||
next_ones = [[2.0, 2.0], [12.0, 12.0]] | ||
assert KMeans.converged(last_ones, next_ones, 10).item() is True | ||
assert KMeans.converged(last_ones, next_ones, 1).item() is False | ||
|
||
def test_converge(self): | ||
points = np.random.uniform(0, 10, (20, 4)) | ||
embeddings = [{"vector": list(point), "cluster": -1} for point in points] | ||
embeddings = ray.data.from_items(embeddings) | ||
centroids = [[2.0, 2.0, 2.0, 2.0], [8.0, 8.0, 8.0, 8.0]] | ||
new_centroids = KMeans.update(embeddings, centroids, 2, 1e-4) | ||
assert len(new_centroids) == 2 | ||
|
||
def test_clustering(self): | ||
np.random.seed(2024) | ||
points = np.random.uniform(0, 40, (20, 4)) | ||
docs = [ | ||
Document(text_representation=f"Document {i}", doc_id=i, embedding=point, properties={"document_number": i}) | ||
for i, point in enumerate(points) | ||
] | ||
context = sycamore.init() | ||
docset = context.read.document(docs) | ||
centroids = docset.kmeans(3, 4) | ||
|
||
clustered_docs = docset.clustering(centroids, "cluster").take_all() | ||
ids = [doc.properties["cluster"] for doc in clustered_docs] | ||
assert all(0 <= idx < 3 for idx in ids) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import random | ||
|
||
import torch | ||
from ray.data.aggregate import AggregateFn | ||
|
||
|
||
class KMeans: | ||
|
||
@staticmethod | ||
def closest(row, centroids): | ||
row = torch.Tensor([row]) | ||
centroids = torch.Tensor(centroids) | ||
distance = torch.cdist(row, centroids) | ||
idx = torch.argmin(distance) | ||
return idx | ||
|
||
@staticmethod | ||
def converged(last_ones, next_ones, epsilon): | ||
distance = torch.cdist(torch.Tensor(last_ones), torch.Tensor(next_ones)) | ||
return len(last_ones) == torch.sum(distance < epsilon) | ||
|
||
@staticmethod | ||
def random_init(embeddings, K): | ||
count = embeddings.count() | ||
assert count > 0 and K < count | ||
fraction = min(2 * K / count, 1.0) | ||
|
||
candidates = [list(c["vector"]) for c in embeddings.random_sample(fraction).take()] | ||
candidates.sort() | ||
from itertools import groupby | ||
|
||
uniques = [key for key, _ in groupby(candidates)] | ||
assert len(uniques) >= K | ||
|
||
centroids = random.sample(uniques, K) | ||
return centroids | ||
|
||
@staticmethod | ||
def init(embeddings, K, init_mode): | ||
if init_mode == "random": | ||
return KMeans.random_init(embeddings, K) | ||
else: | ||
raise Exception("Unknown init mode") | ||
|
||
@staticmethod | ||
def update(embeddings, centroids, iterations, epsilon): | ||
i = 0 | ||
d = len(centroids[0]) | ||
|
||
update_centroids = AggregateFn( | ||
init=lambda v: ([0] * d, 0), | ||
accumulate_row=lambda a, row: ([x + y for x, y in zip(a[0], row["vector"])], a[1] + 1), | ||
merge=lambda a1, a2: ([x + y for x, y in zip(a1[0], a2[0])], a1[1] + a2[1]), | ||
name="centroids", | ||
) | ||
|
||
while i < iterations: | ||
|
||
def _find_cluster(row): | ||
idx = KMeans.closest(row["vector"], centroids) | ||
return {"vector": row["vector"], "cluster": idx} | ||
|
||
aggregated = embeddings.map(_find_cluster).groupby("cluster").aggregate(update_centroids).take() | ||
import numpy as np | ||
|
||
new_centroids = [list(np.array(c["centroids"][0]) / c["centroids"][1]) for c in aggregated] | ||
|
||
if KMeans.converged(centroids, new_centroids, epsilon): | ||
return new_centroids | ||
else: | ||
i += 1 | ||
centroids = new_centroids | ||
|
||
return centroids |