diff --git a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c index 69dd8e24a6..82db050d36 100644 --- a/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c +++ b/libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c @@ -216033,6 +216033,7 @@ int vectorIndexSearch( char **pzErrMsg ) { int type, dims, k, rc; + double kDouble; const char *zIdxName; const char *zErrMsg; Vector *pVector = NULL; @@ -216063,17 +216064,32 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer"); - rc = SQLITE_ERROR; - goto out; - } - k = sqlite3_value_int(argv[2]); - if( k < 0 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer"); + if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){ + k = sqlite3_value_int(argv[2]); + if( k < 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided"); + rc = SQLITE_ERROR; + goto out; + } + }else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) { + kDouble = sqlite3_value_double(argv[2]); + k = (int)kDouble; + if( (double)k != kDouble ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided"); + rc = SQLITE_ERROR; + goto out; + } + if( k < 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided"); + rc = SQLITE_ERROR; + goto out; + } + }else{ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided"); rc = SQLITE_ERROR; goto out; } + if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){ *pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string"); rc = SQLITE_ERROR; diff --git a/libsql-ffi/bundled/src/sqlite3.c b/libsql-ffi/bundled/src/sqlite3.c index 69dd8e24a6..82db050d36 100644 --- a/libsql-ffi/bundled/src/sqlite3.c +++ b/libsql-ffi/bundled/src/sqlite3.c @@ -216033,6 +216033,7 @@ int vectorIndexSearch( char **pzErrMsg ) { int type, dims, k, rc; + double kDouble; const char *zIdxName; const char *zErrMsg; Vector *pVector = NULL; @@ -216063,17 +216064,32 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer"); - rc = SQLITE_ERROR; - goto out; - } - k = sqlite3_value_int(argv[2]); - if( k < 0 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer"); + if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){ + k = sqlite3_value_int(argv[2]); + if( k < 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided"); + rc = SQLITE_ERROR; + goto out; + } + }else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) { + kDouble = sqlite3_value_double(argv[2]); + k = (int)kDouble; + if( (double)k != kDouble ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided"); + rc = SQLITE_ERROR; + goto out; + } + if( k < 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided"); + rc = SQLITE_ERROR; + goto out; + } + }else{ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided"); rc = SQLITE_ERROR; goto out; } + if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){ *pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string"); rc = SQLITE_ERROR; diff --git a/libsql-sqlite3/src/vectorIndex.c b/libsql-sqlite3/src/vectorIndex.c index a278db72d5..45b3eeb5a9 100644 --- a/libsql-sqlite3/src/vectorIndex.c +++ b/libsql-sqlite3/src/vectorIndex.c @@ -951,6 +951,7 @@ int vectorIndexSearch( char **pzErrMsg ) { int type, dims, k, rc; + double kDouble; const char *zIdxName; const char *zErrMsg; Vector *pVector = NULL; @@ -981,17 +982,32 @@ int vectorIndexSearch( rc = SQLITE_ERROR; goto out; } - if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer"); - rc = SQLITE_ERROR; - goto out; - } - k = sqlite3_value_int(argv[2]); - if( k < 0 ){ - *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer"); + if( sqlite3_value_type(argv[2]) == SQLITE_INTEGER ){ + k = sqlite3_value_int(argv[2]); + if( k < 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided"); + rc = SQLITE_ERROR; + goto out; + } + }else if( sqlite3_value_type(argv[2]) == SQLITE_FLOAT ) { + kDouble = sqlite3_value_double(argv[2]); + k = (int)kDouble; + if( (double)k != kDouble ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but float value were provided"); + rc = SQLITE_ERROR; + goto out; + } + if( k < 0 ){ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be a non-negative integer, but negative value were provided"); + rc = SQLITE_ERROR; + goto out; + } + }else{ + *pzErrMsg = sqlite3_mprintf("vector index(search): third parameter (k) must be an integer, but unexpected type of value were provided"); rc = SQLITE_ERROR; goto out; } + if( sqlite3_value_type(argv[0]) != SQLITE_TEXT ){ *pzErrMsg = sqlite3_mprintf("vector index(search): first parameter (index) must be a string"); rc = SQLITE_ERROR; diff --git a/libsql-sqlite3/test/libsql_vector_index.test b/libsql-sqlite3/test/libsql_vector_index.test index a88e8643c8..951c9086b0 100644 --- a/libsql-sqlite3/test/libsql_vector_index.test +++ b/libsql-sqlite3/test/libsql_vector_index.test @@ -117,7 +117,8 @@ do_execsql_test vector-simple { SELECT * FROM vector_top_k('t_simple_idx', '[1,2,3]', 1); SELECT * FROM vector_top_k('t_simple_idx', '[5,6,7]', 1); SELECT * FROM vector_top_k('t_simple_idx', vector('[1,2,3]'), 1); -} {{1} {3} {1}} + SELECT * FROM vector_top_k('t_simple_idx', vector('[1,2,3]'), CAST(1 as REAL)); +} {{1} {3} {1} {1}} do_execsql_test vector-empty { CREATE TABLE t_empty( v FLOAT32(3)); @@ -484,6 +485,7 @@ do_test vector-errors { lappend ret [error_messages {INSERT INTO t_err3 VALUES (vector('[1, 2, 3, 4, 5]'))}] lappend ret [error_messages {INSERT INTO t_err3 VALUES (vector64('[1,2,3,4]'))}] lappend ret [error_messages {SELECT * FROM vector_top_k('t_err3_idx', vector('[1,2]'), 2)}] + lappend ret [error_messages {SELECT * FROM vector_top_k('t_err3_idx', vector('[1,2,3,4]'), 2.5)}] sqlite3_exec db { CREATE TABLE t_mixed_t( v FLOAT32(3)); } sqlite3_exec db { INSERT INTO t_mixed_t VALUES('[1]'); } lappend ret [error_messages {CREATE INDEX t_mixed_t_idx ON t_mixed_t( libsql_vector_idx(v) )}] @@ -503,5 +505,6 @@ do_test vector-errors { {vector index(insert): dimensions are different: 5 != 4} {vector index(insert): vector type differs from column type: 2 != 1} {vector index(search): dimensions are different: 2 != 4} + {vector index(search): third parameter (k) must be an integer, but float value were provided} {vector index(insert): dimensions are different: 1 != 3} }]