Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SurrealDB MTree implementation #479

Closed
7 changes: 7 additions & 0 deletions ann_benchmarks/algorithms/surreal_bruteforce/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FROM ann-benchmarks

RUN apt-get -y install curl
RUN curl --proto '=https' --tlsv1.2 -sSf https://install.surrealdb.com | sh -s -- --nightly

RUN pip install requests

13 changes: 13 additions & 0 deletions ann_benchmarks/algorithms/surreal_bruteforce/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
float:
any:
- base_args: ['@metric']
constructor: SurrealBruteForce
disabled: false
docker_tag: ann-benchmarks-surreal_bruteforce
module: ann_benchmarks.algorithms.surreal_bruteforce
name: surreal_bruteforce
run_groups:
SURREAL-BRUTEFORCE:
arg_groups: [{}]
args: {}
query_args: [['PARALLEL', '']]
93 changes: 93 additions & 0 deletions ann_benchmarks/algorithms/surreal_bruteforce/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import subprocess
import sys
import requests

from ..base.module import BaseANN
from time import sleep

class SurrealBruteForce(BaseANN):

def __init__(self, metric, method_param):
if metric == "euclidean":
self._metric = 'EUCLIDEAN'
elif metric == 'manhattan':
self._metric = 'MANHATTAN'
elif metric == 'angular':
self._metric = 'COSINE'
elif metric == 'hamming':
self._metric = 'HAMMING'
elif metric == 'jaccard':
self._metric = 'JACCARD'
else:
raise RuntimeError(f"unknown metric {metric}")
subprocess.run(f"surreal start --allow-all -u ann -p ann -b 127.0.0.1:8000 memory &", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
print("wait for the server to be up...")
sleep(5)
self._session = requests.Session()
self._session.auth = ('ann', 'ann')
headers={
"surreal-ns": 'ann',
"surreal-db": 'ann',
"Accept": "application/json",
}
self._session.headers.update(headers)

def _sql(self, q):
r = self._session.post('http://127.0.0.1:8000/sql', q)
if r.status_code != 200:
raise RuntimeError(f"{r.text}")
return r

def _ingest(self, dim, X):
# Fit the database per batch
print("Ingesting vectors...")
batch = max(20000 // dim, 1)
q = ""
l = 0
t = 0
for i, embedding in enumerate(X):
v = embedding.tolist()
l += 1
q += f"CREATE items:{i} SET r={v} RETURN NONE;"
if l == batch:
self._checked_sql(q)
q = ''
t += l
l = 0
print(f"\r{t} vectors ingested", end = '')
if l > 0:
self._checked_sql(q)
t += l
print(f"\r{t} vectors ingested", end = '')

def fit(self, X):
dim = X.shape[1]
self._ingest(dim, X)
print("\nIndex construction done")

def _checked_sql(self, q):
res = self._sql(q).json()
for r in res:
if r['status'] != 'OK':
raise RuntimeError(f"Error: {r}")
return res

def set_query_arguments(self, parallel):
self._parallel = parallel
print("parallel = " + self._parallel)

def query(self, v, n):
v = v.tolist()
j = self._checked_sql(f"SELECT id FROM items WHERE r <|{n},{self._metric}|> {v} {self._parallel};")
items = []
for item in j[0]['result']:
id = item['id']
items.append(int(id[6:]))
return items

def __str__(self):
return f"SurrealBruteForce(parallel={self._parallel})"

def done(self) -> None:
self._session.close()
subprocess.run("pkill surreal", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
7 changes: 7 additions & 0 deletions ann_benchmarks/algorithms/surreal_hnsw/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FROM ann-benchmarks

RUN apt-get -y install curl
RUN curl --proto '=https' --tlsv1.2 -sSf https://install.surrealdb.com | sh -s -- --nightly

RUN pip install requests

13 changes: 13 additions & 0 deletions ann_benchmarks/algorithms/surreal_hnsw/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
float:
euclidean:
- base_args: ['@metric']
constructor: SurrealHnsw
disabled: false
docker_tag: ann-benchmarks-surreal_hnsw
module: ann_benchmarks.algorithms.surreal_hnsw
name: surreal_hnsw
run_groups:
M-24-500:
arg_groups: [{M: 24, efConstruction: 150}]
args: {}
query_args: [[5, 10, 20, 40, 80]]
99 changes: 99 additions & 0 deletions ann_benchmarks/algorithms/surreal_hnsw/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import subprocess
import sys
import requests

from ..base.module import BaseANN
from time import sleep

class SurrealHnsw(BaseANN):

def __init__(self, metric, method_param):
if metric == "euclidean":
self._metric = 'EUCLIDEAN'
elif metric == 'manhattan':
self._metric = 'MANHATTAN'
elif metric == 'angular':
self._metric = 'COSINE'
elif metric == 'jaccard':
self._metric = 'JACCARD'
else:
raise RuntimeError(f"unknown metric {metric}")
self._m = method_param['M']
self._efc = method_param['efConstruction']
subprocess.run(f"surreal start --allow-all -u ann -p ann -b 127.0.0.1:8000 memory &", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
print("wait for the server to be up...")
sleep(5)
self._session = requests.Session()
self._session.auth = ('ann', 'ann')
headers={
"surreal-ns": 'ann',
"surreal-db": 'ann',
"Accept": "application/json",
}
self._session.headers.update(headers)

def _sql(self, q):
r = self._session.post('http://127.0.0.1:8000/sql', q)
if r.status_code != 200:
raise RuntimeError(f"{r.text}")
return r

def _create_index(self, dim):
s = f"DEFINE INDEX ix ON items FIELDS r HNSW DIMENSION {dim} DIST {self._metric} TYPE F32 EFC {self._efc} M {self._m}"
self._checked_sql(s)


def _ingest(self, dim, X):
# Fit the database per batch
print("Ingesting vectors...")
batch = max(20000 // dim, 1)
q = ""
l = 0
t = 0
for i, embedding in enumerate(X):
v = embedding.tolist()
l += 1
q += f"CREATE items:{i} SET r={v} RETURN NONE;"
if l == batch:
self._checked_sql(q)
q = ''
t += l
l = 0
print(f"\r{t} vectors ingested", end = '')
if l > 0:
self._checked_sql(q)
t += l
print(f"\r{t} vectors ingested", end = '')

def fit(self, X):
dim = X.shape[1]
self._create_index(dim)
self._ingest(dim, X)
print("\nIndex construction done")

def _checked_sql(self, q):
res = self._sql(q).json()
for r in res:
if r['status'] != 'OK':
raise RuntimeError(f"Error: {r}")
return res

def set_query_arguments(self, ef_search):
self._efs = ef_search
print("ef = " + str(self._efs))

def query(self, v, n):
v = v.tolist()
j = self._checked_sql(f"SELECT id FROM items WHERE r <|{n},{self._efs}|> {v};")
items = []
for item in j[0]['result']:
id = item['id']
items.append(int(id[6:]))
return items

def __str__(self):
return f"SurrealHnsw(M={self._m}, efc={self._efc}, efs={self._efs})"

def done(self) -> None:
self._session.close()
subprocess.run("pkill surreal", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
7 changes: 7 additions & 0 deletions ann_benchmarks/algorithms/surreal_mtree/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
FROM ann-benchmarks

RUN apt-get -y install curl
RUN curl --proto '=https' --tlsv1.2 -sSf https://install.surrealdb.com | sh -s -- --nightly

RUN pip install requests

21 changes: 21 additions & 0 deletions ann_benchmarks/algorithms/surreal_mtree/config.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
float:
any:
- base_args: ['@metric']
constructor: SurrealMtree
disabled: false
docker_tag: ann-benchmarks-surreal_mtree
module: ann_benchmarks.algorithms.surreal_mtree
name: surreal_mtree
run_groups:
# memory_large_cache:
# args: ['memory', [20, 40, 80], 5000, 5000]
# memory_full_cache:
# args: ['memory', [20, 40, 80], 0, 0]
# memory_small_cache:
# args: ['memory', [20, 40, 80], 10, 10]
# disk_large_cache:
# args: ['file:mydata/ann.db', [40], 5000, 5000]
disk_full_cache:
args: ['file:mydata/ann.db', [40], 0, 0]
# disk_small_cache:
# args: ['file:mydata/ann.db', [40], 10, 10]
96 changes: 96 additions & 0 deletions ann_benchmarks/algorithms/surreal_mtree/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import subprocess
import sys
import requests

from ..base.module import BaseANN
from time import sleep

class SurrealMtree(BaseANN):

def __init__(self, metric, path = 'memory', capacity = 40, doc_ids_cache = 100, mtree_cache = 100):
self._metric = metric
self._path = path
self._capacity = capacity
self._doc_ids_cache = doc_ids_cache
self._mtree_cache = mtree_cache
subprocess.run(f"surreal start --allow-all -u ann -p ann -b 127.0.0.1:8000 {path} &", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
print("wait for the server to be up...")
sleep(5)
self._session = requests.Session()
self._session.auth = ('ann', 'ann')
headers={
"surreal-ns": 'ann',
"surreal-db": 'ann',
"Accept": "application/json",
}
self._session.headers.update(headers)

def _sql(self, q):
r = self._session.post('http://127.0.0.1:8000/sql', q)
if r.status_code != 200:
raise RuntimeError(f"{r.text}")
return r

def _create_index(self, dim):
if self._metric == "euclidean":
dist = 'EUCLIDEAN'
elif self._metric == 'manhattan':
dist = 'MANHATTAN'
else:
raise RuntimeError(f"unknown metric {self.metric}")
sql = "REMOVE INDEX IF EXISTS ix ON items;\nREMOVE TABLE IF EXISTS items;\n"
sql += f"DEFINE INDEX ix ON items FIELDS r MTREE DIMENSION {dim} DIST {dist} TYPE F32 CAPACITY {self._capacity} DOC_IDS_CACHE {self._doc_ids_cache} MTREE_CACHE {self._mtree_cache};"
print(f"\r{sql}")
self._checked_sql(sql)


def _ingest(self, dim, X):
# Fit the database per batch
print("Ingesting vectors...")
batch = max(20000 // dim, 1)
q = ""
l = 0
t = 0
for i, embedding in enumerate(X):
v = embedding.tolist()
l += 1
q += f"CREATE items:{i} SET r={v} RETURN NONE;"
if l == batch:
self._checked_sql(q)
q = ''
t += l
l = 0
print(f"\r{t} vectors ingested", end = '')
if l > 0:
self._checked_sql(q)
t += l
print(f"\r{t} vectors ingested", end = '')

def fit(self, X):
dim = X.shape[1]
self._create_index(dim)
self._ingest(dim, X)
print("\nIndex construction done")

def _checked_sql(self, q):
res = self._sql(q).json()
for r in res:
if r['status'] != 'OK':
raise RuntimeError(f"Error: {r}")
return res

def query(self, v, n):
v = v.tolist()
j = self._checked_sql(f"SELECT id FROM items WHERE r <|{n}|> {v};")
items = []
for item in j[0]['result']:
id = item['id']
items.append(int(id[6:]))
return items

def __str__(self):
return f"SurrealMtree(path={self._path}, capacity={self._capacity}, doc_ids_cache={self._doc_ids_cache}, mtree_cache={self._mtree_cache})"

def done(self) -> None:
self._session.close()
subprocess.run("pkill surreal", shell=True, check=True, stdout=sys.stdout, stderr=sys.stderr)
Binary file removed results/fashion-mnist-784-euclidean.png
Binary file not shown.
Binary file removed results/gist-960-euclidean.png
Binary file not shown.
Binary file removed results/glove-100-angular.png
Binary file not shown.
Binary file removed results/glove-25-angular.png
Binary file not shown.
Binary file removed results/mnist-784-euclidean.png
Binary file not shown.
Binary file removed results/nytimes-256-angular.png
Binary file not shown.
Binary file removed results/sift-128-euclidean.png
Binary file not shown.