Skip to content

Commit 0ec52ea

Browse files
committed
Improved BIT support
1 parent 6abfc23 commit 0ec52ea

File tree

1 file changed

+107
-70
lines changed

1 file changed

+107
-70
lines changed

src/sqlite-vector.c

Lines changed: 107 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -846,7 +846,7 @@ static void quantize_binary_i8 (const int8_t *input, uint8_t *output, int dim) {
846846

847847
// MARK: - General Utils -
848848

849-
static size_t vector_type_to_size (vector_type type) {
849+
static int vector_type_to_size (vector_type type) {
850850
switch (type) {
851851
case VECTOR_TYPE_F32: return sizeof(float); // 4 bytes
852852
case VECTOR_TYPE_F16: return sizeof(uint16_t); // 2 bytes
@@ -855,7 +855,7 @@ static size_t vector_type_to_size (vector_type type) {
855855
case VECTOR_TYPE_I8: return sizeof(int8_t); // 1 byte
856856
case VECTOR_TYPE_BIT: return 0; // Special: use vector_bytes_for_dim()
857857
}
858-
return SIZE_T_MAX; // error
858+
return -1; // error
859859
}
860860

861861
static vector_type vector_name_to_type (const char *vname) {
@@ -904,6 +904,7 @@ static vector_distance distance_name_to_type (const char *dname) {
904904
if (strcasecmp(dname, "INNER") == 0) return VECTOR_DISTANCE_DOT;
905905
if (strcasecmp(dname, "L1") == 0) return VECTOR_DISTANCE_L1;
906906
if (strcasecmp(dname, "MANHATTAN") == 0) return VECTOR_DISTANCE_L1;
907+
if (strcasecmp(dname, "HAMMING") == 0) return VECTOR_DISTANCE_HAMMING;
907908
return 0;
908909
}
909910

@@ -1302,63 +1303,65 @@ static int vector_rebuild_quantization (sqlite3_context *context, const char *ta
13021303
if (rc != SQLITE_OK) goto vector_rebuild_quantization_cleanup;
13031304

13041305
// STEP 1
1305-
// find global min/max across ALL vectors
1306+
// find global min/max across ALL vectors (skip for 1BIT quantization which uses fixed threshold)
13061307
#if defined(_WIN32) || defined(__linux__)
13071308
float min_val = FLT_MAX;
13081309
float max_val = -FLT_MAX;
1309-
#else
1310+
#else
13101311
float min_val = MAXFLOAT;
13111312
float max_val = -MAXFLOAT;
13121313
#endif
13131314
bool contains_negative = false;
1314-
1315-
while (1) {
1316-
rc = sqlite3_step(vm);
1317-
if (rc == SQLITE_DONE) {rc = SQLITE_OK; break;}
1318-
else if (rc != SQLITE_ROW) break;
1319-
if (sqlite3_column_type(vm, 1) == SQLITE_NULL) continue;
1320-
1321-
const void *blob = (float *)sqlite3_column_blob(vm, 1);
1322-
if (!blob) continue;
1323-
1324-
int blob_size = sqlite3_column_bytes(vm, 1);
1325-
size_t need_bytes = vector_bytes_for_dim(type, dim);
1326-
if (blob_size < need_bytes) {
1327-
context_result_error(context, SQLITE_ERROR, "Invalid vector blob found at rowid %lld", (long long)sqlite3_column_int64(vm, 0));
1328-
rc = SQLITE_ERROR;
1329-
goto vector_rebuild_quantization_cleanup;
1330-
}
1331-
1332-
for (int i = 0; i < dim; ++i) {
1333-
float val = 0.0f;
1334-
switch (type) {
1335-
case VECTOR_TYPE_F32:
1336-
val = ((float *)blob)[i];
1337-
break;
1338-
case VECTOR_TYPE_F16:
1339-
val = float16_to_float32(((uint16_t *)blob)[i]);
1340-
break;
1341-
case VECTOR_TYPE_BF16:
1342-
val = bfloat16_to_float32(((uint16_t *)blob)[i]);
1343-
break;
1344-
case VECTOR_TYPE_U8:
1345-
val = (float)(((uint8_t *)blob)[i]);
1346-
break;
1347-
case VECTOR_TYPE_I8:
1348-
val = (float)(((int8_t *)blob)[i]);
1349-
break;
1350-
default:
1351-
context_result_error(context, SQLITE_ERROR, "Unsupported vector type");
1352-
rc = SQLITE_ERROR;
1353-
goto vector_rebuild_quantization_cleanup;
1315+
1316+
if (qtype != VECTOR_QUANT_1BIT) {
1317+
while (1) {
1318+
rc = sqlite3_step(vm);
1319+
if (rc == SQLITE_DONE) {rc = SQLITE_OK; break;}
1320+
else if (rc != SQLITE_ROW) break;
1321+
if (sqlite3_column_type(vm, 1) == SQLITE_NULL) continue;
1322+
1323+
const void *blob = (float *)sqlite3_column_blob(vm, 1);
1324+
if (!blob) continue;
1325+
1326+
int blob_size = sqlite3_column_bytes(vm, 1);
1327+
size_t need_bytes = vector_bytes_for_dim(type, dim);
1328+
if (blob_size < need_bytes) {
1329+
context_result_error(context, SQLITE_ERROR, "Invalid vector blob found at rowid %lld", (long long)sqlite3_column_int64(vm, 0));
1330+
rc = SQLITE_ERROR;
1331+
goto vector_rebuild_quantization_cleanup;
13541332
}
13551333

1356-
if (val < min_val) min_val = val;
1357-
if (val > max_val) max_val = val;
1358-
if (val < 0.0) contains_negative = true;
1334+
for (int i = 0; i < dim; ++i) {
1335+
float val = 0.0f;
1336+
switch (type) {
1337+
case VECTOR_TYPE_F32:
1338+
val = ((float *)blob)[i];
1339+
break;
1340+
case VECTOR_TYPE_F16:
1341+
val = float16_to_float32(((uint16_t *)blob)[i]);
1342+
break;
1343+
case VECTOR_TYPE_BF16:
1344+
val = bfloat16_to_float32(((uint16_t *)blob)[i]);
1345+
break;
1346+
case VECTOR_TYPE_U8:
1347+
val = (float)(((uint8_t *)blob)[i]);
1348+
break;
1349+
case VECTOR_TYPE_I8:
1350+
val = (float)(((int8_t *)blob)[i]);
1351+
break;
1352+
default:
1353+
context_result_error(context, SQLITE_ERROR, "Unsupported vector type for 8-bit quantization");
1354+
rc = SQLITE_ERROR;
1355+
goto vector_rebuild_quantization_cleanup;
1356+
}
1357+
1358+
if (val < min_val) min_val = val;
1359+
if (val > max_val) max_val = val;
1360+
if (val < 0.0) contains_negative = true;
1361+
}
13591362
}
13601363
}
1361-
1364+
13621365
// set proper format
13631366
if (qtype == VECTOR_QUANT_AUTO) {
13641367
if (contains_negative == true) qtype = VECTOR_QUANT_S8BIT;
@@ -1694,13 +1697,17 @@ static void *vector_from_json (sqlite3_context *context, sqlite3_vtab *vtab, vec
16941697
}
16951698

16961699
// allocate blob
1700+
// For BIT type, each JSON element is a single bit, pack 8 per byte
16971701
size_t item_size = vector_type_to_size(type);
1698-
size_t alloc = (estimated_count + 1) * item_size;
1702+
size_t alloc = (type == VECTOR_TYPE_BIT) ? ((estimated_count + 1) + 7) / 8 : (estimated_count + 1) * item_size;
16991703
blob = sqlite3_malloc((int)alloc);
17001704
if (!blob) {
17011705
return sqlite_common_set_error(context, vtab, SQLITE_NOMEM, "Out of memory: unable to allocate %lld bytes for BLOB buffer", (long long)alloc);
17021706
}
1703-
1707+
if (type == VECTOR_TYPE_BIT) {
1708+
memset(blob, 0, alloc); // Initialize to zero for bit packing
1709+
}
1710+
17041711
// typed pointers
17051712
float *float_blob = (float *)blob;
17061713
uint8_t *uint8_blob = (uint8_t *)blob;
@@ -1728,41 +1735,54 @@ static void *vector_from_json (sqlite3_context *context, sqlite3_vtab *vtab, vec
17281735
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Malformed JSON: expected a number at position %d (found '%c')", (int)(p - json) + 1, *p ? *p : '?');
17291736
}
17301737

1731-
if (count >= (int)(alloc / item_size)) {
1738+
// check bounds
1739+
int max_count = (type == VECTOR_TYPE_BIT) ? (int)(alloc * 8) : (int)(alloc / item_size);
1740+
if (count >= max_count) {
17321741
sqlite3_free(blob);
17331742
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Too many elements in JSON array");
17341743
}
1735-
1744+
17361745
// convert to proper type
17371746
switch (type) {
17381747
case VECTOR_TYPE_F32:
17391748
float_blob[count++] = (float)value;
17401749
break;
1741-
1750+
17421751
case VECTOR_TYPE_F16:
17431752
uint16_blob[count++] = float32_to_float16((float)value);
17441753
break;
17451754

17461755
case VECTOR_TYPE_BF16:
17471756
bfloat16_blob[count++] = float32_to_bfloat16((float)value);
17481757
break;
1749-
1758+
17501759
case VECTOR_TYPE_U8:
17511760
if (value < 0 || value > 255) {
17521761
sqlite3_free(blob);
17531762
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Value out of range for uint8_t");
17541763
}
17551764
uint8_blob[count++] = (uint8_t)value;
17561765
break;
1757-
1766+
17581767
case VECTOR_TYPE_I8:
17591768
if (value < -128 || value > 127) {
17601769
sqlite3_free(blob);
17611770
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Value out of range for int8_t");
17621771
}
17631772
int8_blob[count++] = (int8_t)value;
17641773
break;
1765-
1774+
1775+
case VECTOR_TYPE_BIT:
1776+
if (value != 0 && value != 1) {
1777+
sqlite3_free(blob);
1778+
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Value out of range for BIT: expected 0 or 1");
1779+
}
1780+
if ((int)value == 1) {
1781+
uint8_blob[count / 8] |= (1 << (count % 8));
1782+
}
1783+
count++;
1784+
break;
1785+
17661786
default:
17671787
sqlite3_free(blob);
17681788
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Unsupported vector type");
@@ -1796,8 +1816,8 @@ static void *vector_from_json (sqlite3_context *context, sqlite3_vtab *vtab, vec
17961816
sqlite3_free(blob);
17971817
return sqlite_common_set_error(context, vtab, SQLITE_ERROR, "Invalid JSON vector dimension: expected %d but found %d", dimension, count);
17981818
}
1799-
1800-
if (size) *size = (int)(count * item_size);
1819+
1820+
if (size) *size = (type == VECTOR_TYPE_BIT) ? (int)((count + 7) / 8) : (int)(count * item_size);
18011821
return blob;
18021822
}
18031823

@@ -1846,8 +1866,10 @@ static void vector_as_type (sqlite3_context *context, vector_type type, int argc
18461866

18471867
char *blob = vector_from_json(context, NULL, type, json, &value_size, dimension);
18481868
if (!blob) return; // error is set in the context
1849-
1850-
VECTOR_PRINT((void *)blob, type, (dimension == 0) ? (value_size / vector_type_to_size(type)) : dimension);
1869+
1870+
int print_dim = dimension;
1871+
if (print_dim == 0) print_dim = (type == VECTOR_TYPE_BIT) ? value_size * 8 : value_size / vector_type_to_size(type);
1872+
VECTOR_PRINT((void *)blob, type, print_dim);
18511873

18521874
sqlite3_result_blob(context, (const void *)blob, value_size, sqlite3_free);
18531875
return;
@@ -1876,6 +1898,10 @@ static void vector_as_i8 (sqlite3_context *context, int argc, sqlite3_value **ar
18761898
vector_as_type(context, VECTOR_TYPE_I8, argc, argv);
18771899
}
18781900

1901+
static void vector_as_bit (sqlite3_context *context, int argc, sqlite3_value **argv) {
1902+
vector_as_type(context, VECTOR_TYPE_BIT, argc, argv);
1903+
}
1904+
18791905
// MARK: - Modules -
18801906
static int vFullScanCursorNext (sqlite3_vtab_cursor *cur);
18811907

@@ -2073,18 +2099,21 @@ static int vFullScanCursorNext (sqlite3_vtab_cursor *cur){
20732099

20742100
// FULL-SCAN
20752101
if (!c->is_quantized) {
2102+
// For BIT type, use byte count instead of dimension
2103+
vector_type vt = c->table->options.v_type;
2104+
int dist_size = (vt == VECTOR_TYPE_BIT) ? ((dimension + 7) / 8) : dimension;
20762105
while (1) {
20772106
int rc = sqlite3_step(vm);
20782107
if (rc == SQLITE_DONE) { c->stream.is_eof = 1; return SQLITE_OK; }
20792108
else if (rc != SQLITE_ROW) return rc;
2080-
2109+
20812110
// skip NULL values
20822111
if (sqlite3_column_type(vm, 1) == SQLITE_NULL) continue;
20832112

20842113
const float *v2 = (const float *)sqlite3_column_blob(vm, 1);
20852114
if (v2 == NULL) continue;
20862115

2087-
float distance = distance_fn((const void *)v1, (const void *)v2, dimension);
2116+
float distance = distance_fn((const void *)v1, (const void *)v2, dist_size);
20882117
if (nearly_zero_float32(distance)) distance = 0.0f;
20892118

20902119
c->stream.distance = distance;
@@ -2245,18 +2274,20 @@ static int vFullScanRun (sqlite3 *db, vFullScanCursor *c, const void *v1, int v1
22452274
// compute distance function
22462275
vector_distance vd = c->table->options.v_distance;
22472276
vector_type vt = c->table->options.v_type;
2277+
if (vt == VECTOR_TYPE_BIT) vd = VECTOR_DISTANCE_HAMMING; // Force Hamming for BIT type
22482278
distance_function_t distance_fn = dispatch_distance_table[vd][vt];
2249-
2279+
int dist_size = (vt == VECTOR_TYPE_BIT) ? ((dimension + 7) / 8) : dimension;
2280+
22502281
while (1) {
22512282
rc = sqlite3_step(vm);
22522283
if (rc == SQLITE_DONE) {rc = SQLITE_OK; goto cleanup;}
22532284
if (rc != SQLITE_ROW) goto cleanup;
22542285
if (sqlite3_column_type(vm, 1) == SQLITE_NULL) continue;
2255-
2286+
22562287
float *v2 = (float *)sqlite3_column_blob(vm, 1);
22572288
if (v2 == NULL) continue;
2258-
2259-
float distance = distance_fn((const void *)v1, (const void *)v2, dimension);
2289+
2290+
float distance = distance_fn((const void *)v1, (const void *)v2, dist_size);
22602291
if (nearly_zero_float32(distance)) distance = 0.0;
22612292
VECTOR_PRINT((void*)v2, vt, dimension);
22622293

@@ -2486,14 +2517,15 @@ static int vStreamScanCursorRun (sqlite3 *db, vFullScanCursor *c, const void *v1
24862517
// compute distance function
24872518
vector_distance vd = c->table->options.v_distance;
24882519
vector_type vt = c->table->options.v_type;
2520+
if (vt == VECTOR_TYPE_BIT) vd = VECTOR_DISTANCE_HAMMING; // Force Hamming for BIT type
24892521
distance_function_t distance_fn = dispatch_distance_table[vd][vt];
2490-
2522+
24912523
c->stream.distance_fn = distance_fn;
24922524
c->stream.vm = vm;
2493-
2525+
24942526
if (sql) sqlite3_free(sql);
24952527
return SQLITE_OK;
2496-
2528+
24972529
cleanup:
24982530
if (sql) sqlite3_free(sql);
24992531
if (vm) sqlite3_finalize(vm);
@@ -2842,7 +2874,12 @@ SQLITE_VECTOR_API int sqlite3_vector_init (sqlite3 *db, char **pzErrMsg, const s
28422874
if (rc != SQLITE_OK) goto cleanup;
28432875
rc = sqlite3_create_function(db, "vector_as_u8", 2, SQLITE_UTF8, ctx, vector_as_u8, NULL, NULL);
28442876
if (rc != SQLITE_OK) goto cleanup;
2845-
2877+
2878+
rc = sqlite3_create_function(db, "vector_as_bit", 1, SQLITE_UTF8, ctx, vector_as_bit, NULL, NULL);
2879+
if (rc != SQLITE_OK) goto cleanup;
2880+
rc = sqlite3_create_function(db, "vector_as_bit", 2, SQLITE_UTF8, ctx, vector_as_bit, NULL, NULL);
2881+
if (rc != SQLITE_OK) goto cleanup;
2882+
28462883
rc = sqlite3_create_module(db, "vector_full_scan", &vFullScanModule, ctx);
28472884
if (rc != SQLITE_OK) goto cleanup;
28482885

0 commit comments

Comments
 (0)