Skip to content

Commit

Permalink
Merge pull request #1761 from tursodatabase/vector-search-accept-k-float
Browse files Browse the repository at this point in the history
accept K parameter as float if there is no loss in the precision after rounding to the integer
  • Loading branch information
sivukhin authored Sep 30, 2024
2 parents 8abff7b + 80a10f9 commit 54ff421
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 25 deletions.
32 changes: 24 additions & 8 deletions libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -216033,6 +216033,7 @@ int vectorIndexSearch(
char **pzErrMsg
) {
int type, dims, k, rc;
double kDouble;
const char *zIdxName;
const char *zErrMsg;
Vector *pVector = NULL;
Expand Down Expand Up @@ -216063,17 +216064,32 @@ int vectorIndexSearch(
rc = SQLITE_ERROR;
goto out;
}
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
rc = SQLITE_ERROR;
goto out;
}
k = sqlite3_value_int(argv[2]);
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){
k = sqlite3_value_int(argv[2]);
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
rc = SQLITE_ERROR;
goto out;
}
}else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) {
kDouble = sqlite3_value_double(argv[2]);
k = (int)kDouble;
if( (double)k != kDouble ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided");
rc = SQLITE_ERROR;
goto out;
}
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
rc = SQLITE_ERROR;
goto out;
}
}else{
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided");
rc = SQLITE_ERROR;
goto out;
}

if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){
*pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string");
rc = SQLITE_ERROR;
Expand Down
32 changes: 24 additions & 8 deletions libsql-ffi/bundled/src/sqlite3.c
Original file line number Diff line number Diff line change
Expand Up @@ -216033,6 +216033,7 @@ int vectorIndexSearch(
char **pzErrMsg
) {
int type, dims, k, rc;
double kDouble;
const char *zIdxName;
const char *zErrMsg;
Vector *pVector = NULL;
Expand Down Expand Up @@ -216063,17 +216064,32 @@ int vectorIndexSearch(
rc = SQLITE_ERROR;
goto out;
}
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
rc = SQLITE_ERROR;
goto out;
}
k = sqlite3_value_int(argv[2]);
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){
k = sqlite3_value_int(argv[2]);
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
rc = SQLITE_ERROR;
goto out;
}
}else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) {
kDouble = sqlite3_value_double(argv[2]);
k = (int)kDouble;
if( (double)k != kDouble ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided");
rc = SQLITE_ERROR;
goto out;
}
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
rc = SQLITE_ERROR;
goto out;
}
}else{
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided");
rc = SQLITE_ERROR;
goto out;
}

if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){
*pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string");
rc = SQLITE_ERROR;
Expand Down
32 changes: 24 additions & 8 deletions libsql-sqlite3/src/vectorIndex.c
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,7 @@ int vectorIndexSearch(
char **pzErrMsg
) {
int type, dims, k, rc;
double kDouble;
const char *zIdxName;
const char *zErrMsg;
Vector *pVector = NULL;
Expand Down Expand Up @@ -981,17 +982,32 @@ int vectorIndexSearch(
rc = SQLITE_ERROR;
goto out;
}
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
rc = SQLITE_ERROR;
goto out;
}
k = sqlite3_value_int(argv[2]);
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer");
if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){
k = sqlite3_value_int(argv[2]);
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
rc = SQLITE_ERROR;
goto out;
}
}else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) {
kDouble = sqlite3_value_double(argv[2]);
k = (int)kDouble;
if( (double)k != kDouble ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided");
rc = SQLITE_ERROR;
goto out;
}
if( k < 0 ){
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided");
rc = SQLITE_ERROR;
goto out;
}
}else{
*pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided");
rc = SQLITE_ERROR;
goto out;
}

if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){
*pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string");
rc = SQLITE_ERROR;
Expand Down
5 changes: 4 additions & 1 deletion libsql-sqlite3/test/libsql_vector_index.test
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ do_execsql_test vector-simple {
SELECT * FROM vector_top_k('t_simple_idx', '[1,2,3]', 1);
SELECT * FROM vector_top_k('t_simple_idx', '[5,6,7]', 1);
SELECT * FROM vector_top_k('t_simple_idx', vector('[1,2,3]'), 1);
} {{1} {3} {1}}
SELECT * FROM vector_top_k('t_simple_idx', vector('[1,2,3]'), CAST(1 as REAL));
} {{1} {3} {1} {1}}

do_execsql_test vector-empty {
CREATE TABLE t_empty( v FLOAT32(3));
Expand Down Expand Up @@ -484,6 +485,7 @@ do_test vector-errors {
lappend ret [error_messages {INSERT INTO t_err3 VALUES (vector('[1, 2, 3, 4, 5]'))}]
lappend ret [error_messages {INSERT INTO t_err3 VALUES (vector64('[1,2,3,4]'))}]
lappend ret [error_messages {SELECT * FROM vector_top_k('t_err3_idx', vector('[1,2]'), 2)}]
lappend ret [error_messages {SELECT * FROM vector_top_k('t_err3_idx', vector('[1,2,3,4]'), 2.5)}]
sqlite3_exec db { CREATE TABLE t_mixed_t( v FLOAT32(3)); }
sqlite3_exec db { INSERT INTO t_mixed_t VALUES('[1]'); }
lappend ret [error_messages {CREATE INDEX t_mixed_t_idx ON t_mixed_t( libsql_vector_idx(v) )}]
Expand All @@ -503,5 +505,6 @@ do_test vector-errors {
{vector index(insert): dimensions are different: 5 != 4}
{vector index(insert): vector type differs from column type: 2 != 1}
{vector index(search): dimensions are different: 2 != 4}
{vector index(search): third parameter (k) must be an integer, but float value were provided}
{vector index(insert): dimensions are different: 1 != 3}
}]

0 comments on commit 54ff421

Please sign in to comment.