Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions sqlite-vec-rescore.c
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,55 @@ static int rescore_knn(vec0_vtab *p, vec0_cursor *pCur,
sqlite3_blob_close(blobFloat);
sqlite3_free(fBuf);

// Apply distance constraints (e.g. `distance <= ?`) to the rescored float
// distances. Candidates failing any constraint are dropped before the
// top-k selection. The constraint targets the final float distance, not
// the coarse quantized distance from phase 1. (issue #308)
for (int c = 0; c < argc; c++) {
int idx = 1 + (c * 4);
if (idxStr[idx + 0] != VEC0_IDXSTR_KIND_KNN_DISTANCE_CONSTRAINT)
continue;
vec0_distance_constraint_operator op = idxStr[idx + 1];
f32 target = (f32)sqlite3_value_double(argv[c]);
i64 kept = 0;
for (i64 j = 0; j < cand_used; j++) {
int pass;
switch (op) {
case VEC0_DISTANCE_CONSTRAINT_GE:
pass = float_distances[j] >= target;
break;
case VEC0_DISTANCE_CONSTRAINT_GT:
pass = float_distances[j] > target;
break;
case VEC0_DISTANCE_CONSTRAINT_LE:
pass = float_distances[j] <= target;
break;
case VEC0_DISTANCE_CONSTRAINT_LT:
pass = float_distances[j] < target;
break;
default:
pass = 1;
break;
}
if (pass) {
cand_rowids[kept] = cand_rowids[j];
float_distances[kept] = float_distances[j];
kept++;
}
}
cand_used = kept;
}

if (cand_used == 0) {
knn_data->current_idx = 0;
knn_data->k = 0;
knn_data->rowids = NULL;
knn_data->distances = NULL;
knn_data->k_used = 0;
sqlite3_free(float_distances);
goto cleanup;
}

// Sort by float distance
for (i64 a = 0; a + 1 < cand_used; a++) {
i64 minIdx = a;
Expand Down
55 changes: 55 additions & 0 deletions tests/test-rescore.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,3 +725,58 @@ def test_unknown_command_errors(db):
)
with pytest.raises(sqlite3.OperationalError, match="unknown vec0 command"):
db.execute("INSERT INTO t(t) VALUES ('not_a_real_command')")


# ============================================================================
# Distance constraint tests (issue #308)
# ============================================================================


def test_knn_distance_constraint_le(db):
    """Rescore KNN must drop candidates that fail `distance <= threshold`."""
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )

    # `near` is identical to the query vector (distance 0); `far` only has
    # its last component set, which puts it at distance 1.5 from the query.
    near = [1.0, 0.5] + [0.0] * 6
    far = [0.0] * 7 + [1.0]
    for statement, vector in (
        ("INSERT INTO t(rowid, embedding) VALUES (1, ?)", near),
        ("INSERT INTO t(rowid, embedding) VALUES (2, ?)", far),
    ):
        db.execute(statement, [float_vec(vector)])

    cursor = db.execute(
        "SELECT rowid, distance FROM t "
        "WHERE embedding MATCH ? AND k = 10 AND distance <= 0.5 "
        "ORDER BY distance",
        [float_vec(list(near))],
    )
    matches = cursor.fetchall()

    # Only rowid 1 lies within the 0.5 threshold; rowid 2 must be
    # filtered out even though k = 10 leaves room for it.
    returned_rowids = [row["rowid"] for row in matches]
    assert returned_rowids == [1]
    for row in matches:
        assert row["distance"] <= 0.5


def test_knn_distance_constraint_lt_gt(db):
    """`distance < x` and `distance > x` constraints must be honored.

    Two vectors are inserted: rowid 1 equals the query vector (distance 0)
    and rowid 2 is at distance 1.5 from it. With a threshold of 0.5 the
    strict `>` constraint must return only rowid 2 and the strict `<`
    constraint must return only rowid 1, so both operators are exercised
    against the rescored float distances.
    """
    db.execute(
        "CREATE VIRTUAL TABLE t USING vec0("
        " embedding float[8] indexed by rescore(quantizer=bit)"
        ")"
    )
    v1 = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    v2 = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
    db.execute("INSERT INTO t(rowid, embedding) VALUES (1, ?)", [float_vec(v1)])
    db.execute("INSERT INTO t(rowid, embedding) VALUES (2, ?)", [float_vec(v2)])

    query = [1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    def knn(constraint):
        # Run the same KNN query, varying only the distance constraint.
        return db.execute(
            "SELECT rowid, distance FROM t "
            f"WHERE embedding MATCH ? AND k = 10 AND distance {constraint} "
            "ORDER BY distance",
            [float_vec(query)],
        ).fetchall()

    # `>`: exclude the exact match, keep only the far vector.
    rows = knn("> 0.5")
    assert [r["rowid"] for r in rows] == [2]
    assert all(r["distance"] > 0.5 for r in rows)

    # `<`: keep only the exact match, exclude the far vector.
    rows = knn("< 0.5")
    assert [r["rowid"] for r in rows] == [1]
    assert all(r["distance"] < 0.5 for r in rows)