Skip to content

Commit

Permalink
Merge pull request #1603 from tursodatabase/vector-search-more-tests
Browse files Browse the repository at this point in the history
add more tests for vector feature
  • Loading branch information
sivukhin authored Jul 25, 2024
2 parents b233a11 + da55b5f commit ce25cac
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 20 additions & 0 deletions libsql-sqlite3/test/libsql_vector_index.test
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,26 @@ do_execsql_test vector-vacuum {
SELECT COUNT(*) FROM t_vacuum_idx_shadow;
} {2 2}

do_execsql_test vector-many-columns {
CREATE TABLE t_many ( i INTEGER PRIMARY KEY, e1 FLOAT32(2), e2 FLOAT32(2) );
CREATE INDEX t_many_1_idx ON t_many(libsql_vector_idx(e1));
CREATE INDEX t_many_2_idx ON t_many(libsql_vector_idx(e2));
INSERT INTO t_many VALUES (1, vector('[1,1]'), vector('[-1,-1]')), (2, vector('[-1,-1]'), vector('[1,1]'));
SELECT * FROM vector_top_k('t_many_1_idx', vector('[1,1]'), 2);
SELECT * FROM vector_top_k('t_many_2_idx', vector('[1,1]'), 2);
} {1 2 2 1}

do_execsql_test vector-transaction {
CREATE TABLE t_transaction ( i INTEGER PRIMARY KEY, e FLOAT32(2) );
CREATE INDEX t_transaction_idx ON t_transaction(libsql_vector_idx(e));
INSERT INTO t_transaction VALUES (1, vector('[1,2]')), (2, vector('[3,4]'));
BEGIN;
INSERT INTO t_transaction VALUES (3, vector('[4,5]')), (4, vector('[5,6]'));
SELECT * FROM vector_top_k('t_transaction_idx', vector('[4,5]'), 2);
ROLLBACK;
SELECT * FROM vector_top_k('t_transaction_idx', vector('[1,2]'), 2);
} {3 4 1 2}

proc error_messages {sql} {
set ret ""
catch {
Expand Down
1 change: 1 addition & 0 deletions libsql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ tokio = { version = "1.29.1", features = ["full"] }
tokio-test = "0.4"
tracing-subscriber = "0.3"
tempfile = { version = "3.7.0" }
rand = "0.8.5"

[features]
default = ["core", "replication", "remote"]
Expand Down
102 changes: 102 additions & 0 deletions libsql/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ use libsql::{
params::{IntoParams, IntoValue},
Connection, Database, Value,
};
use rand::distributions::Uniform;
use rand::prelude::*;
use std::collections::HashSet;

async fn setup() -> Connection {
let db = Database::open(":memory:").unwrap();
Expand Down Expand Up @@ -650,3 +653,102 @@ async fn deserialize_row() {
assert_eq!(data.status, Status::Draft);
assert_eq!(data.wrapper, Wrapper(Status::Published));
}

#[tokio::test]
#[ignore]
// fuzz test can be run explicitly with following command:
// cargo test vector_fuzz_test -- --nocapture --include-ignored
async fn vector_fuzz_test() {
let mut global_rng = rand::thread_rng();
for attempt in 0..10000 {
let seed = global_rng.next_u64();

let mut rng =
rand::rngs::StdRng::from_seed(unsafe { std::mem::transmute([seed, seed, seed, seed]) });
let db = Database::open(":memory:").unwrap();
let conn = db.connect().unwrap();
let dim = rng.gen_range(1..=1536);
let operations = rng.gen_range(1..128);
println!(
"============== ATTEMPT {} (seed {}u64, dim {}, operations {}) ================",
attempt, seed, dim, operations
);

let _ = conn
.execute(
&format!(
"CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) )",
dim
),
(),
)
.await;
// println!("CREATE TABLE users (id INTEGER PRIMARY KEY, v FLOAT32({}) );", dim);
let _ = conn
.execute(
"CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );",
(),
)
.await;
// println!("CREATE INDEX users_idx ON users ( libsql_vector_idx(v) );");

let mut next_id = 1;
let mut alive = HashSet::new();
let uniform = Uniform::new(-1.0, 1.0);
for _ in 0..operations {
let operation = rng.gen_range(0..4);
let vector: Vec<f32> = (0..dim).map(|_| rng.sample(uniform)).collect();
let vector_str = format!(
"[{}]",
vector
.iter()
.map(|x| format!("{}", x))
.collect::<Vec<String>>()
.join(",")
);
if operation == 0 {
// println!("INSERT INTO users VALUES ({}, vector('{}') );", next_id, vector_str);
conn.execute(
"INSERT INTO users VALUES (?, vector(?) )",
libsql::params![next_id, vector_str],
)
.await
.unwrap();
alive.insert(next_id);
next_id += 1;
} else if operation == 1 {
let id = rng.gen_range(0..next_id);
// println!("DELETE FROM users WHERE id = {};", id);
conn.execute("DELETE FROM users WHERE id = ?", libsql::params![id])
.await
.unwrap();
alive.remove(&id);
} else if operation == 2 && !alive.is_empty() {
let id = alive.iter().collect::<Vec<_>>()[rng.gen_range(0..alive.len())];
// println!("UPDATE users SET v = vector('{}') WHERE id = {};", vector_str, id);
conn.execute(
"UPDATE users SET v = vector(?) WHERE id = ?",
libsql::params![vector_str, id],
)
.await
.unwrap();
} else if operation == 3 {
let k = rng.gen_range(1..200);
// println!("SELECT * FROM vector_top_k('users_idx', '{}', {});", vector_str, k);
let result = conn
.query(
"SELECT * FROM vector_top_k('users_idx', ?, ?)",
libsql::params![vector_str, k],
)
.await
.unwrap();
let count = result.into_stream().count().await;
assert!(count <= alive.len());
if alive.len() > 0 {
assert!(count > 0);
}
}
}
let _ = conn.execute("REINDEX users;", ()).await.unwrap();
}
}

0 comments on commit ce25cac

Please sign in to comment.