Skip to content

Commit

Permalink
Merge pull request #1650 from tursodatabase/vector-search-compression
Browse files Browse the repository at this point in the history
vector search: neighbors compression (1bit quantization)
  • Loading branch information
sivukhin authored Aug 11, 2024
2 parents e4c2afc + 4c38e5f commit 8262d23
Show file tree
Hide file tree
Showing 13 changed files with 1,804 additions and 931 deletions.
907 changes: 597 additions & 310 deletions libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Large diffs are not rendered by default.

907 changes: 597 additions & 310 deletions libsql-ffi/bundled/src/sqlite3.c

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion libsql-sqlite3/Makefile.in
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
vector.lo vectorfloat32.lo vectorfloat64.lo \
vector.lo vectorfloat32.lo vectorfloat64.lo vector1bit.lo \
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
vdbetrace.lo vdbevtab.lo \
Expand Down Expand Up @@ -302,6 +302,7 @@ SRC = \
$(TOP)/src/util.c \
$(TOP)/src/vacuum.c \
$(TOP)/src/vector.c \
$(TOP)/src/vector1bit.c \
$(TOP)/src/vectorInt.h \
$(TOP)/src/vectorfloat32.c \
$(TOP)/src/vectorfloat64.c \
Expand Down Expand Up @@ -1138,6 +1139,9 @@ vacuum.lo: $(TOP)/src/vacuum.c $(HDR)
vector.lo: $(TOP)/src/vector.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector.c

vector1bit.lo: $(TOP)/src/vector1bit.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vector1bit.c

vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c

Expand Down
123 changes: 95 additions & 28 deletions libsql-sqlite3/src/vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
return dims * sizeof(float);
case VECTOR_TYPE_FLOAT64:
return dims * sizeof(double);
case VECTOR_TYPE_1BIT:
assert( dims > 0 );
return (dims + 7) / 8;
default:
assert(0);
}
Expand Down Expand Up @@ -72,10 +75,11 @@ Vector *vectorAlloc(VectorType type, VectorDims dims){
** Note that the vector object points to the blob so if
** you free the blob, the vector becomes invalid.
**/
void vectorInitStatic(Vector *pVector, VectorType type, const unsigned char *pBlob, size_t nBlobSize){
pVector->type = type;
void vectorInitStatic(Vector *pVector, VectorType type, VectorDims dims, void *pBlob){
pVector->flags = VECTOR_FLAGS_STATIC;
vectorInitFromBlob(pVector, pBlob, nBlobSize);
pVector->type = type;
pVector->dims = dims;
pVector->data = pBlob;
}

/*
Expand Down Expand Up @@ -111,6 +115,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
return vectorF32DistanceCos(pVector1, pVector2);
case VECTOR_TYPE_FLOAT64:
return vectorF64DistanceCos(pVector1, pVector2);
case VECTOR_TYPE_1BIT:
return vector1BitDistanceHamming(pVector1, pVector2);
default:
assert(0);
}
Expand Down Expand Up @@ -247,16 +253,34 @@ static int vectorParseSqliteText(
return -1;
}

int vectorParseSqliteBlob(
int vectorParseSqliteBlobWithType(
sqlite3_value *arg,
Vector *pVector,
char **pzErrMsg
){
const unsigned char *pBlob;
size_t nBlobSize;

assert( sqlite3_value_type(arg) == SQLITE_BLOB );

pBlob = sqlite3_value_blob(arg);
nBlobSize = sqlite3_value_bytes(arg);
if( nBlobSize % 2 == 1 ){
nBlobSize--;
}

if( nBlobSize < vectorDataSize(pVector->type, pVector->dims) ){
*pzErrMsg = sqlite3_mprintf("invalid vector: not enough bytes: type=%d, dims=%d, size=%ull", pVector->type, pVector->dims, nBlobSize);
return SQLITE_ERROR;
}

switch (pVector->type) {
case VECTOR_TYPE_FLOAT32:
return vectorF32ParseSqliteBlob(arg, pVector, pzErrMsg);
vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
return 0;
case VECTOR_TYPE_FLOAT64:
return vectorF64ParseSqliteBlob(arg, pVector, pzErrMsg);
vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
return 0;
default:
assert(0);
}
Expand Down Expand Up @@ -339,14 +363,14 @@ int detectVectorParameters(sqlite3_value *arg, int typeHint, int *pType, int *pD
}
}

int vectorParse(
int vectorParseWithType(
sqlite3_value *arg,
Vector *pVector,
char **pzErrMsg
){
switch( sqlite3_value_type(arg) ){
case SQLITE_BLOB:
return vectorParseSqliteBlob(arg, pVector, pzErrMsg);
return vectorParseSqliteBlobWithType(arg, pVector, pzErrMsg);
case SQLITE_TEXT:
return vectorParseSqliteText(arg, pVector, pzErrMsg);
default:
Expand All @@ -363,6 +387,9 @@ void vectorDump(const Vector *pVector){
case VECTOR_TYPE_FLOAT64:
vectorF64Dump(pVector);
break;
case VECTOR_TYPE_1BIT:
vector1BitDump(pVector);
break;
default:
assert(0);
}
Expand All @@ -384,20 +411,47 @@ void vectorMarshalToText(
}
}

void vectorSerialize(
void vectorSerializeWithType(
sqlite3_context *context,
const Vector *pVector
){
unsigned char *pBlob;
size_t nBlobSize, nDataSize;

assert( pVector->dims <= MAX_VECTOR_SZ );

nDataSize = vectorDataSize(pVector->type, pVector->dims);
nBlobSize = nDataSize;
if( pVector->type != VECTOR_TYPE_FLOAT32 ){
nBlobSize += (nBlobSize % 2 == 0 ? 1 : 2);
}

if( nBlobSize == 0 ){
sqlite3_result_zeroblob(context, 0);
return;
}

pBlob = sqlite3_malloc64(nBlobSize);
if( pBlob == NULL ){
sqlite3_result_error_nomem(context);
return;
}

if( pVector->type != VECTOR_TYPE_FLOAT32 ){
pBlob[nBlobSize - 1] = pVector->type;
}

switch (pVector->type) {
case VECTOR_TYPE_FLOAT32:
vectorF32Serialize(context, pVector);
vectorF32SerializeToBlob(pVector, pBlob, nDataSize);
break;
case VECTOR_TYPE_FLOAT64:
vectorF64Serialize(context, pVector);
vectorF64SerializeToBlob(pVector, pBlob, nDataSize);
break;
default:
assert(0);
}
sqlite3_result_blob(context, (char*)pBlob, nBlobSize, sqlite3_free);
}

size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t nBlobSize){
Expand All @@ -406,18 +460,8 @@ size_t vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t
return vectorF32SerializeToBlob(pVector, pBlob, nBlobSize);
case VECTOR_TYPE_FLOAT64:
return vectorF64SerializeToBlob(pVector, pBlob, nBlobSize);
default:
assert(0);
}
return 0;
}

size_t vectorDeserializeFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlobSize){
switch (pVector->type) {
case VECTOR_TYPE_FLOAT32:
return vectorF32DeserializeFromBlob(pVector, pBlob, nBlobSize);
case VECTOR_TYPE_FLOAT64:
return vectorF64DeserializeFromBlob(pVector, pBlob, nBlobSize);
case VECTOR_TYPE_1BIT:
return vector1BitSerializeToBlob(pVector, pBlob, nBlobSize);
default:
assert(0);
}
Expand All @@ -437,6 +481,29 @@ void vectorInitFromBlob(Vector *pVector, const unsigned char *pBlob, size_t nBlo
}
}

/*
** Convert vector pFrom into the representation used by pTo, writing the
** result into the buffer pTo->data already points at.
**
** Both vectors must have the same number of dimensions. The only conversion
** implemented is FLOAT32 -> 1BIT: bit i of the output is set when component
** i of the input is strictly positive and cleared otherwise (zero, negative
** and NaN components all produce a cleared bit). Any other type pair trips
** an assert.
*/
void vectorConvert(const Vector *pFrom, Vector *pTo){
  const float *src;
  u8 *dst;
  int iDim, iByte, nBytes;

  assert( pFrom->dims == pTo->dims );

  if( pFrom->type != VECTOR_TYPE_FLOAT32 || pTo->type != VECTOR_TYPE_1BIT ){
    /* unsupported conversion */
    assert(0);
    return;
  }

  src = pFrom->data;
  dst = pTo->data;

  /* Clear every output byte first (including the padding bits of the last
  ** byte) so the loop below only ever has to set bits. */
  nBytes = (pFrom->dims + 7) / 8;
  for(iByte = 0; iByte < nBytes; iByte++){
    dst[iByte] = 0;
  }

  /* Bits are packed least-significant-bit first inside each byte. */
  for(iDim = 0; iDim < pFrom->dims; iDim++){
    if( src[iDim] > 0 ){
      dst[iDim >> 3] |= (1 << (iDim % 8));
    }
  }
}

/**************************************************************************
** SQL function implementations
****************************************************************************/
Expand Down Expand Up @@ -465,12 +532,12 @@ static void vectorFuncHintedType(
if( pVector==NULL ){
return;
}
if( vectorParse(argv[0], pVector, &pzErrMsg) != 0 ){
if( vectorParseWithType(argv[0], pVector, &pzErrMsg) != 0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free_vec;
}
vectorSerialize(context, pVector);
vectorSerializeWithType(context, pVector);
out_free_vec:
vectorFree(pVector);
}
Expand Down Expand Up @@ -515,7 +582,7 @@ static void vectorExtractFunc(
if( pVector==NULL ){
return;
}
if( vectorParse(argv[0], pVector, &pzErrMsg)<0 ){
if( vectorParseWithType(argv[0], pVector, &pzErrMsg)<0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
Expand Down Expand Up @@ -570,12 +637,12 @@ static void vectorDistanceCosFunc(
if( pVector2==NULL ){
goto out_free;
}
if( vectorParse(argv[0], pVector1, &pzErrMsg)<0 ){
if( vectorParseWithType(argv[0], pVector1, &pzErrMsg)<0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
}
if( vectorParse(argv[1], pVector2, &pzErrMsg)<0 ){
if( vectorParseWithType(argv[1], pVector2, &pzErrMsg)<0 ){
sqlite3_result_error(context, pzErrMsg, -1);
sqlite3_free(pzErrMsg);
goto out_free;
Expand Down
127 changes: 127 additions & 0 deletions libsql-sqlite3/src/vector1bit.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
/*
** 2024-07-04
**
** Copyright 2024 the libSQL authors
**
** Permission is hereby granted, free of charge, to any person obtaining a copy of
** this software and associated documentation files (the "Software"), to deal in
** the Software without restriction, including without limitation the rights to
** use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
** the Software, and to permit persons to whom the Software is furnished to do so,
** subject to the following conditions:
**
** The above copyright notice and this permission notice shall be included in all
** copies or substantial portions of the Software.
**
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
** FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
** COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
** IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
** CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
**
******************************************************************************
**
** 1-bit vector format utilities.
*/
#ifndef SQLITE_OMIT_VECTOR
#include "sqliteInt.h"

#include "vectorInt.h"

#include <math.h>

/**************************************************************************
** Utility routines for debugging
**************************************************************************/

/*
** Print a 1-bit vector to stdout for debugging purposes.
**
** Every dimension is printed as 1 when its bit is set and as -1 when it is
** clear, in the form "f1bit: [1, -1, ...]" followed by a newline.
*/
void vector1BitDump(const Vector *pVec){
  const u8 *bits = pVec->data;
  unsigned iDim;

  assert( pVec->type == VECTOR_TYPE_1BIT );

  printf("f1bit: [");
  for(iDim = 0; iDim < pVec->dims; iDim++){
    /* bit (iDim mod 8) of byte (iDim / 8), least significant bit first */
    int isSet = (bits[iDim >> 3] >> (iDim % 8)) & 1;
    const char *zSep = "";
    if( iDim != 0 ){
      zSep = ", ";
    }
    printf("%s%d", zSep, isSet ? +1 : -1);
  }
  printf("]\n");
}

/**************************************************************************
** Utility routines for vector serialization and deserialization
**************************************************************************/

/*
** Copy the packed bits of a 1-bit vector into pBlob.
**
** pVector   - source vector; must have type VECTOR_TYPE_1BIT
** pBlob     - destination buffer
** nBlobSize - capacity of pBlob in bytes; must be at least (dims + 7) / 8
**
** Returns the number of bytes written, which is (dims + 7) / 8.
*/
size_t vector1BitSerializeToBlob(
  const Vector *pVector,
  unsigned char *pBlob,
  size_t nBlobSize
){
  const u8 *src = pVector->data;
  size_t nBytes;
  size_t k;

  assert( pVector->type == VECTOR_TYPE_1BIT );
  assert( pVector->dims <= MAX_VECTOR_SZ );

  /* one bit per dimension, rounded up to whole bytes */
  nBytes = (pVector->dims + 7) / 8;
  assert( nBlobSize >= nBytes );

  for(k = 0; k < nBytes; k++){
    pBlob[k] = src[k];
  }
  return nBytes;
}

/*
** Lookup table: BitsCount[b] is the number of set bits in the byte value b.
**
** Generated with the Python expression
**   [sum(map(int, bin(i)[2:])) for i in range(256)]
**
** The table is never modified, so it is declared const; every entry is in
** the range 0..8 and therefore fits in a single byte.
*/
static const unsigned char BitsCount[256] = {
  0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
  1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
  2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
  3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
  4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8,
};

/*
** Return the number of set bits in the 32-bit value a.
**
** When GCC_VERSION reports 5.4.0 or newer (and the compiler is not the
** Intel compiler) the __builtin_popcount() intrinsic is used; otherwise the
** count is accumulated one byte at a time from the BitsCount[] table.
*/
static inline int sqlite3PopCount32(u32 a){
#if GCC_VERSION>=5004000 && !defined(__INTEL_COMPILER)
  return __builtin_popcount(a);
#else
  int nBits = 0;
  int iShift;
  for(iShift = 0; iShift < 32; iShift += 8){
    nBits += BitsCount[(a >> iShift) & 0xff];
  }
  return nBits;
#endif
}

int vector1BitDistanceHamming(const Vector *v1, const Vector *v2){
int diff = 0;
u8 *e1U8 = v1->data;
u32 *e1U32 = v1->data;
u8 *e2U8 = v2->data;
u32 *e2U32 = v2->data;
int i, len8, len32, offset8;

assert( v1->dims == v2->dims );
assert( v1->type == VECTOR_TYPE_1BIT );
assert( v2->type == VECTOR_TYPE_1BIT );

len8 = (v1->dims + 7) / 8;
len32 = v1->dims / 32;
offset8 = len32 * 4;

for(i = 0; i < len32; i++){
diff += sqlite3PopCount32(e1U32[i] ^ e2U32[i]);
}
for(i = offset8; i < len8; i++){
diff += sqlite3PopCount32(e1U8[i] ^ e2U8[i]);
}
return diff;
}

#endif /* !defined(SQLITE_OMIT_VECTOR) */
Loading

0 comments on commit 8262d23

Please sign in to comment.