Skip to content

Commit

Permalink
Merge pull request #1688 from tursodatabase/vector-search-bfloat16
Browse files Browse the repository at this point in the history
vector search: implement and integrate bfloat16
  • Loading branch information
sivukhin authored Aug 23, 2024
2 parents 32037aa + 2d75b23 commit f76bc0a
Show file tree
Hide file tree
Showing 11 changed files with 871 additions and 43 deletions.
284 changes: 272 additions & 12 deletions libsql-ffi/bundled/SQLite3MultipleCiphers/src/sqlite3.c

Large diffs are not rendered by default.

284 changes: 272 additions & 12 deletions libsql-ffi/bundled/src/sqlite3.c

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion libsql-sqlite3/Makefile.in
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ LIBOBJS0 = alter.lo analyze.lo attach.lo auth.lo \
sqlite3session.lo select.lo sqlite3rbu.lo status.lo stmt.lo \
table.lo threads.lo tokenize.lo treeview.lo trigger.lo \
update.lo userauth.lo upsert.lo util.lo vacuum.lo \
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo vectorfloat16.lo \
vector.lo vectorfloat32.lo vectorfloat64.lo vectorfloat1bit.lo vectorfloat8.lo vectorfloat16.lo vectorfloatb16.lo \
vectorIndex.lo vectordiskann.lo vectorvtab.lo \
vdbe.lo vdbeapi.lo vdbeaux.lo vdbeblob.lo vdbemem.lo vdbesort.lo \
vdbetrace.lo vdbevtab.lo \
Expand Down Expand Up @@ -304,10 +304,12 @@ SRC = \
$(TOP)/src/vector.c \
$(TOP)/src/vectorInt.h \
$(TOP)/src/vectorfloat1bit.c \
$(TOP)/src/vectorfloat1bit.c \
$(TOP)/src/vectorfloat16.c \
$(TOP)/src/vectorfloat32.c \
$(TOP)/src/vectorfloat64.c \
$(TOP)/src/vectorfloat8.c \
$(TOP)/src/vectorfloatb16.c \
$(TOP)/src/vectorIndexInt.h \
$(TOP)/src/vectorIndex.c \
$(TOP)/src/vectordiskann.c \
Expand Down Expand Up @@ -1147,6 +1149,9 @@ vectorfloat1bit.lo: $(TOP)/src/vectorfloat1bit.c $(HDR)
vectorfloat16.lo: $(TOP)/src/vectorfloat16.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat16.c

vectorfloatb16.lo: $(TOP)/src/vectorfloatb16.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloatb16.c

vectorfloat32.lo: $(TOP)/src/vectorfloat32.c $(HDR)
$(LTCOMPILE) $(TEMP_STORE) -c $(TOP)/src/vectorfloat32.c

Expand Down
139 changes: 129 additions & 10 deletions libsql-sqlite3/src/vector.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ size_t vectorDataSize(VectorType type, VectorDims dims){
return ALIGN(dims, sizeof(float)) + sizeof(float) /* alpha */ + sizeof(float) /* shift */;
case VECTOR_TYPE_FLOAT16:
return dims * sizeof(u16);
case VECTOR_TYPE_FLOATB16:
return dims * sizeof(u16);
default:
assert(0);
}
Expand Down Expand Up @@ -124,6 +126,8 @@ float vectorDistanceCos(const Vector *pVector1, const Vector *pVector2){
return vectorF8DistanceCos(pVector1, pVector2);
case VECTOR_TYPE_FLOAT16:
return vectorF16DistanceCos(pVector1, pVector2);
case VECTOR_TYPE_FLOATB16:
return vectorFB16DistanceCos(pVector1, pVector2);
default:
assert(0);
}
Expand All @@ -141,6 +145,8 @@ float vectorDistanceL2(const Vector *pVector1, const Vector *pVector2){
return vectorF8DistanceL2(pVector1, pVector2);
case VECTOR_TYPE_FLOAT16:
return vectorF16DistanceL2(pVector1, pVector2);
case VECTOR_TYPE_FLOATB16:
return vectorFB16DistanceL2(pVector1, pVector2);
default:
assert(0);
}
Expand Down Expand Up @@ -314,6 +320,13 @@ static int vectorParseMeta(const unsigned char *pBlob, size_t nBlobSize, int *pT
}
*pDims = nBlobSize / sizeof(u16);
*pDataSize = nBlobSize;
}else if( *pType == VECTOR_TYPE_FLOATB16 ){
if( nBlobSize % 2 != 0 ){
*pzErrMsg = sqlite3_mprintf("vector: floatb16 vector blob length must be divisible by 2 (excluding 'type'-byte): length=%d", nBlobSize);
return SQLITE_ERROR;
}
*pDims = nBlobSize / sizeof(u16);
*pDataSize = nBlobSize;
}else{
*pzErrMsg = sqlite3_mprintf("vector: unexpected binary type: %d", *pType);
return SQLITE_ERROR;
Expand Down Expand Up @@ -365,6 +378,9 @@ int vectorParseSqliteBlobWithType(
case VECTOR_TYPE_FLOAT16:
vectorF16DeserializeFromBlob(pVector, pBlob, nDataSize);
return 0;
case VECTOR_TYPE_FLOATB16:
vectorFB16DeserializeFromBlob(pVector, pBlob, nDataSize);
return 0;
default:
assert(0);
}
Expand Down Expand Up @@ -469,6 +485,9 @@ void vectorDump(const Vector *pVector){
case VECTOR_TYPE_FLOAT16:
vectorF16Dump(pVector);
break;
case VECTOR_TYPE_FLOATB16:
vectorFB16Dump(pVector);
break;
default:
assert(0);
}
Expand All @@ -494,7 +513,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
int nDataSize;
if( type == VECTOR_TYPE_FLOAT32 ){
return 0;
}else if( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 ){
}else if( type == VECTOR_TYPE_FLOAT64 || type == VECTOR_TYPE_FLOAT16 || type == VECTOR_TYPE_FLOATB16 ){
return 1;
}else if( type == VECTOR_TYPE_FLOAT1BIT ){
nDataSize = vectorDataSize(type, dims);
Expand All @@ -513,7 +532,7 @@ static int vectorMetaSize(VectorType type, VectorDims dims){
static void vectorSerializeMeta(const Vector *pVector, size_t nDataSize, unsigned char *pBlob, size_t nBlobSize){
if( pVector->type == VECTOR_TYPE_FLOAT32 ){
// no meta for f32 type as this is "default" vector type
}else if( pVector->type == VECTOR_TYPE_FLOAT64 || pVector->type == VECTOR_TYPE_FLOAT16 ){
}else if( pVector->type == VECTOR_TYPE_FLOAT64 || pVector->type == VECTOR_TYPE_FLOAT16 || pVector->type == VECTOR_TYPE_FLOATB16 ){
assert( nDataSize % 2 == 0 );
assert( nBlobSize == nDataSize + 1 );
pBlob[nBlobSize - 1] = pVector->type;
Expand Down Expand Up @@ -582,6 +601,9 @@ void vectorSerializeToBlob(const Vector *pVector, unsigned char *pBlob, size_t n
case VECTOR_TYPE_FLOAT16:
vectorF16SerializeToBlob(pVector, pBlob, nBlobSize);
break;
case VECTOR_TYPE_FLOATB16:
vectorFB16SerializeToBlob(pVector, pBlob, nBlobSize);
break;
default:
assert(0);
}
Expand Down Expand Up @@ -624,6 +646,11 @@ static void vectorConvertFromF32(const Vector *pFrom, Vector *pTo){
for(i = 0; i < pFrom->dims; i++){
dstF16[i] = vectorF16FromFloat(src[i]);
}
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
dstF16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstF16[i] = vectorFB16FromFloat(src[i]);
}
}else{
assert( 0 );
}
Expand Down Expand Up @@ -662,6 +689,11 @@ static void vectorConvertFromF64(const Vector *pFrom, Vector *pTo){
for(i = 0; i < pFrom->dims; i++){
dstF16[i] = vectorF16FromFloat(src[i]);
}
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
dstF16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstF16[i] = vectorFB16FromFloat(src[i]);
}
}else{
assert( 0 );
}
Expand All @@ -673,7 +705,7 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){

float *dstF32;
double *dstF64;
u16 *dstF16;
u16 *dstU16;

assert( pFrom->dims == pTo->dims );
assert( pFrom->type != pTo->type );
Expand Down Expand Up @@ -701,12 +733,23 @@ static void vectorConvertFrom1Bit(const Vector *pFrom, Vector *pTo){
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
u16 positive = vectorF16FromFloat(+1);
u16 negative = vectorF16FromFloat(-1);
dstF16 = pTo->data;
dstU16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){
dstU16[i] = positive;
}else{
dstU16[i] = negative;
}
}
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
u16 positive = vectorFB16FromFloat(+1);
u16 negative = vectorFB16FromFloat(-1);
dstU16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
if( ((src[i / 8] >> (i & 7)) & 1) == 1 ){
dstF16[i] = positive;
dstU16[i] = positive;
}else{
dstF16[i] = negative;
dstU16[i] = negative;
}
}
}else{
Expand Down Expand Up @@ -756,6 +799,11 @@ static void vectorConvertFromF8(const Vector *pFrom, Vector *pTo){
for(i = 0; i < pFrom->dims; i++){
dstF16[i] = vectorF16FromFloat(alpha * src[i] + shift);
}
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
dstF16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstF16[i] = vectorFB16FromFloat(alpha * src[i] + shift);
}
}else{
assert( 0 );
}
Expand All @@ -768,6 +816,7 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
float *dstF32;
double *dstF64;
u8 *dst1Bit;
u16 *dstU16;

assert( pFrom->dims == pTo->dims );
assert( pFrom->type != pTo->type );
Expand All @@ -784,6 +833,11 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
for(i = 0; i < pFrom->dims; i++){
dstF64[i] = vectorF16ToFloat(src[i]);
}
}else if( pTo->type == VECTOR_TYPE_FLOATB16 ){
dstU16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstU16[i] = vectorFB16FromFloat(vectorF16ToFloat(src[i]));
}
}else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){
dst1Bit = pTo->data;
for(i = 0; i < pFrom->dims; i += 8){
Expand All @@ -799,6 +853,50 @@ static void vectorConvertFromF16(const Vector *pFrom, Vector *pTo){
}
}

static void vectorConvertFromFB16(const Vector *pFrom, Vector *pTo){
int i;
u16 *src;

float *dstF32;
double *dstF64;
u8 *dst1Bit;
u16 *dstU16;

assert( pFrom->dims == pTo->dims );
assert( pFrom->type != pTo->type );
assert( pFrom->type == VECTOR_TYPE_FLOATB16 );

src = pFrom->data;
if( pTo->type == VECTOR_TYPE_FLOAT32 ){
dstF32 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstF32[i] = vectorFB16ToFloat(src[i]);
}
}else if( pTo->type == VECTOR_TYPE_FLOAT64 ){
dstF64 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstF64[i] = vectorFB16ToFloat(src[i]);
}
}else if( pTo->type == VECTOR_TYPE_FLOAT16 ){
dstU16 = pTo->data;
for(i = 0; i < pFrom->dims; i++){
dstU16[i] = vectorF16FromFloat(vectorFB16ToFloat(src[i]));
}
}else if( pTo->type == VECTOR_TYPE_FLOAT1BIT ){
dst1Bit = pTo->data;
for(i = 0; i < pFrom->dims; i += 8){
dst1Bit[i / 8] = 0;
}
for(i = 0; i < pFrom->dims; i++){
if( vectorFB16ToFloat(src[i]) > 0 ){
dst1Bit[i / 8] |= (1 << (i & 7));
}
}
}else{
assert( 0 );
}
}

static inline int clip(float f, int minF, int maxF){
if( f < minF ){
return minF;
Expand All @@ -819,7 +917,7 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
float *srcF32;
double *srcF64;
u8 *src1Bit;
u16 *srcF16;
u16 *srcU16;

assert( pFrom->dims == pTo->dims );
assert( pFrom->type != pTo->type );
Expand Down Expand Up @@ -857,14 +955,24 @@ static void vectorConvertToF8(const Vector *pFrom, Vector *pTo){
dst[i] = clip(((((src1Bit[i / 8] >> (i & 7)) & 1) ? +1 : -1) - shift) / alpha, 0, 255);
}
}else if( pFrom->type == VECTOR_TYPE_FLOAT16 ){
srcF16 = pFrom->data;
srcU16 = pFrom->data;
for(i = 0; i < pFrom->dims; i++){
MINMAX(i, vectorF16ToFloat(srcU16[i]), minF, maxF);
}
shift = minF;
alpha = (maxF - minF) / 255;
for(i = 0; i < pFrom->dims; i++){
dst[i] = clip((vectorF16ToFloat(srcU16[i]) - shift) / alpha, 0, 255);
}
}else if( pFrom->type == VECTOR_TYPE_FLOATB16 ){
srcU16 = pFrom->data;
for(i = 0; i < pFrom->dims; i++){
MINMAX(i, vectorF16ToFloat(srcF16[i]), minF, maxF);
MINMAX(i, vectorFB16ToFloat(srcU16[i]), minF, maxF);
}
shift = minF;
alpha = (maxF - minF) / 255;
for(i = 0; i < pFrom->dims; i++){
dst[i] = clip((vectorF16ToFloat(srcF16[i]) - shift) / alpha, 0, 255);
dst[i] = clip((vectorFB16ToFloat(srcU16[i]) - shift) / alpha, 0, 255);
}
}else{
assert( 0 );
Expand Down Expand Up @@ -893,6 +1001,8 @@ void vectorConvert(const Vector *pFrom, Vector *pTo){
vectorConvertFromF8(pFrom, pTo);
}else if( pFrom->type == VECTOR_TYPE_FLOAT16 ){
vectorConvertFromF16(pFrom, pTo);
}else if( pFrom->type == VECTOR_TYPE_FLOATB16 ){
vectorConvertFromFB16(pFrom, pTo);
}else{
assert( 0 );
}
Expand Down Expand Up @@ -985,6 +1095,14 @@ static void vector16Func(
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOAT16);
}

static void vectorb16Func(
sqlite3_context *context,
int argc,
sqlite3_value **argv
){
vectorFuncHintedType(context, argc, argv, VECTOR_TYPE_FLOATB16);
}

static void vector1BitFunc(
sqlite3_context *context,
int argc,
Expand Down Expand Up @@ -1144,6 +1262,7 @@ void sqlite3RegisterVectorFunctions(void){
FUNCTION(vector1bit, 1, 0, 0, vector1BitFunc),
FUNCTION(vector8, 1, 0, 0, vector8Func),
FUNCTION(vector16, 1, 0, 0, vector16Func),
FUNCTION(vectorb16, 1, 0, 0, vectorb16Func),
FUNCTION(vector_extract, 1, 0, 0, vectorExtractFunc),
FUNCTION(vector_distance_cos, 2, 0, 0, vectorDistanceCosFunc),
FUNCTION(vector_distance_l2, 2, 0, 0, vectorDistanceL2Func),
Expand Down
3 changes: 3 additions & 0 deletions libsql-sqlite3/src/vectorIndex.c
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,8 @@ static struct VectorColumnType VECTOR_COLUMN_TYPES[] = {
{ "F8_BLOB", VECTOR_TYPE_FLOAT8 },
{ "FLOAT16", VECTOR_TYPE_FLOAT16 },
{ "F16_BLOB", VECTOR_TYPE_FLOAT16 },
{ "FLOATB16", VECTOR_TYPE_FLOATB16 },
{ "FB16_BLOB", VECTOR_TYPE_FLOATB16 },
};

/*
Expand All @@ -408,6 +410,7 @@ static struct VectorParamName VECTOR_PARAM_NAMES[] = {
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float1bit", VECTOR_TYPE_FLOAT1BIT },
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float8", VECTOR_TYPE_FLOAT8 },
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float16", VECTOR_TYPE_FLOAT16 },
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "floatb16", VECTOR_TYPE_FLOATB16 },
{ "compress_neighbors", VECTOR_COMPRESS_NEIGHBORS_PARAM_ID, 0, "float32", VECTOR_TYPE_FLOAT32 },
{ "alpha", VECTOR_PRUNING_ALPHA_PARAM_ID, 2, 0, 0 },
{ "search_l", VECTOR_SEARCH_L_PARAM_ID, 1, 0, 0 },
Expand Down
Loading

0 comments on commit f76bc0a

Please sign in to comment.