From 97b36efd47292a203ab442e7b7fe096ea454aedf Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Thu, 14 May 2026 17:56:54 +0000 Subject: [PATCH 01/10] finish up recall computation patch --- diskann-benchmark-core/src/recall.rs | 51 ++++++++----------- .../src/backend/index/benchmarks.rs | 28 ++++++++-- diskann-benchmark/src/utils/datafiles.rs | 13 ++++- diskann-benchmark/src/utils/recall.rs | 6 --- 4 files changed, 57 insertions(+), 41 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 0fa4d42c1..cfca474eb 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -22,10 +22,6 @@ pub struct RecallMetrics { pub num_queries: usize, /// The average recall across all queries. pub average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub maximum: usize, } #[derive(Debug, Error)] @@ -186,8 +182,8 @@ where } } - // The actual recall computation for fixed-size groundtruth - let mut recall_values: Vec = Vec::new(); + // The actual recall computation for groundtruth + let mut recall_values: Vec = Vec::new(); let mut this_groundtruth = HashSet::new(); let mut this_results = HashSet::new(); @@ -198,26 +194,22 @@ where } let gt_row = groundtruth.row(i); - if gt_row.len() < recall_k { - return Err(ComputeRecallError::NotEnoughGroundTruth( - gt_row.len(), - recall_k, - )); - } + // groundtruth does not have to be fixed-size, so we compute recall_k for this row based on its gt length + let this_recall_k = gt_row.len().min(recall_k); // Populate the groundtruth using the top-k this_groundtruth.clear(); - this_groundtruth.extend(gt_row.iter().take(recall_k).cloned()); + this_groundtruth.extend(gt_row.iter().take(this_recall_k).cloned()); // If we have distances, then continue to append distances as long as the distance // value is constant if let Some(distances) = groundtruth_distances - && recall_k > 0 + && this_recall_k > 0 { let distances_row = distances.row(i); - if distances_row.len() > recall_k - 1 && gt_row.len() > recall_k - 1 { - let last_distance = distances_row[recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(recall_k) { + if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 { + let last_distance = distances_row[this_recall_k - 1]; + for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { if *d == last_distance { this_groundtruth.insert(g.clone()); } else { @@ -235,27 +227,28 @@ where .iter() .filter(|i| this_results.contains(i)) .count() - .min(recall_k); + .min(this_recall_k); - recall_values.push(r); - } + // recall is the number of correct results in the top n, divided by k (not n), or 0 if there are no groundtruth results for this query + let recall = if this_recall_k > 0 { + (r as f64) / (this_recall_k as f64) + } else { + 0.0 + }; - // Perform post-processing - let total: usize = recall_values.iter().sum(); - let minimum = recall_values.iter().min().unwrap_or(&0); - let maximum = recall_values.iter().max().unwrap_or(&0); + recall_values.push(recall); + } - // We explicitly check that each groundtruth row has at least `recall_k` elements. - let div = recall_k * nrows; - let average = (total as f64) / (div as f64); + // Compute the average recall + let total: f64 = recall_values.iter().sum(); + let div = recall_values.len(); + let average = (total) / (div as f64); Ok(RecallMetrics { recall_k, recall_n, num_queries: nrows, average, - minimum: *minimum, - maximum: *maximum, }) } diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 57aafc8eb..6a0150489 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -471,7 +471,16 @@ where let queries: Arc> = Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); - let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth))?; + // compute the maximum value of k used in any search + let max_k = topk + .runs + .iter() + .map(|run| run.recall_k) + .max() + .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; + + let groundtruth = + datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; let knn = benchmark_core::search::graph::KNN::new( index.clone(), @@ -695,10 +704,19 @@ where let managed = Managed::new(max_points, consolidate_threshold, managed_stream); - let layered = bigann::WithData::new(managed, data, queries, |path| { - Ok(Box::new(datafiles::load_groundtruth(datafiles::BinFile( - path, - ))?)) + // compute the maximum value of k used in any search + let max_k = topk + .runs + .iter() + .map(|run| run.recall_k) + .max() + .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; + + let layered = bigann::WithData::new(managed, data, queries, move |path| { + Ok(Box::new(datafiles::load_groundtruth( + datafiles::BinFile(path), + Some(max_k), + )?)) }); Ok(layered) diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index 9c5057488..c6d43ccc2 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -95,7 +95,7 @@ impl ConvertingLoad for f32 { } /// Load a groundtruth set from disk and return the result as a row-major matrix. -pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> { +pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option) -> anyhow::Result> { let provider = diskann_providers::storage::FileStorageProvider; let mut file = provider .open_reader(&path.0.to_string_lossy()) @@ -114,6 +114,17 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>) -> anyhow::Result> let mut groundtruth = Matrix::::new(0, num_points, dim); let groundtruth_slice: &mut [u8] = bytemuck::cast_slice_mut(groundtruth.as_mut_slice()); file.read_exact(groundtruth_slice)?; + + if let Some(expected_k) = k { + if groundtruth.ncols() != expected_k { + return Err(anyhow::anyhow!( + "Each row of groundtruth must have length {} (got {})", + expected_k, + groundtruth.ncols() + )); + } + } + Ok(groundtruth) } diff --git a/diskann-benchmark/src/utils/recall.rs b/diskann-benchmark/src/utils/recall.rs index dcbe86d94..b6eebc72b 100644 --- a/diskann-benchmark/src/utils/recall.rs +++ b/diskann-benchmark/src/utils/recall.rs @@ -18,10 +18,6 @@ pub(crate) struct RecallMetrics { pub(crate) num_queries: usize, /// The average recall across all queries. pub(crate) average: f64, - /// The minimum observed recall (max possible value: `recall_n`). - pub(crate) minimum: usize, - /// The maximum observed recall (max possible value: `recall_k`). - pub(crate) maximum: usize, } impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { @@ -31,8 +27,6 @@ impl From<&benchmark_core::recall::RecallMetrics> for RecallMetrics { recall_n: m.recall_n, num_queries: m.num_queries, average: m.average, - minimum: m.minimum, - maximum: m.maximum, } } } From 43eb51742170c5dd0629289834b86c438ede43a1 Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 22 May 2026 15:11:36 +0000 Subject: [PATCH 02/10] fix conflict --- diskann-benchmark-core/src/recall.rs | 34 ------------------- .../src/backend/index/benchmarks.rs | 9 ----- diskann-benchmark/src/utils/datafiles.rs | 6 ---- 3 files changed, 49 deletions(-) diff --git a/diskann-benchmark-core/src/recall.rs b/diskann-benchmark-core/src/recall.rs index 400a475be..eb1b65868 100644 --- a/diskann-benchmark-core/src/recall.rs +++ b/diskann-benchmark-core/src/recall.rs @@ -205,10 +205,6 @@ where let result = results.row(i); let gt_row = groundtruth.row(i); -<<<<<<< HEAD - // groundtruth does not have to be fixed-size, so we compute recall_k for this row based on its gt length - let this_recall_k = gt_row.len().min(recall_k); -======= // `groundtruth` does not have to be fixed-size, // so we compute `recall_k` for this row based on its gt length let this_recall_k = gt_row.len().min(recall_k); @@ -216,7 +212,6 @@ where if this_recall_k == 0 { continue; } ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 // Populate the groundtruth using the top-k this_groundtruth.clear(); @@ -224,20 +219,6 @@ where // If we have distances, then continue to append distances as long as the distance // value is constant -<<<<<<< HEAD - if let Some(distances) = groundtruth_distances - && this_recall_k > 0 - { - let distances_row = distances.row(i); - if distances_row.len() > this_recall_k - 1 && gt_row.len() > this_recall_k - 1 { - let last_distance = distances_row[this_recall_k - 1]; - for (d, g) in distances_row.iter().zip(gt_row.iter()).skip(this_recall_k) { - if *d == last_distance { - this_groundtruth.insert(g.clone()); - } else { - break; - } -======= if let Some(distances) = groundtruth_distances { let distances_row = distances.row(i); @@ -249,7 +230,6 @@ where this_groundtruth.insert(g.clone()); } else { break; ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 } } } @@ -264,32 +244,18 @@ where .count() .min(this_recall_k); -<<<<<<< HEAD - // recall is the number of correct results in the top n, divided by k (not n), or 0 if there are no groundtruth results for this query - let recall = if this_recall_k > 0 { - (r as f64) / (this_recall_k as f64) - } else { - 0.0 - }; -======= let recall = (r as f64) / (this_recall_k as f64); ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 recall_values.push(recall); } // Compute the average recall let total: f64 = recall_values.iter().sum(); -<<<<<<< HEAD - let div = recall_values.len(); - let average = (total) / (div as f64); -======= let average = if recall_values.is_empty() { 0.0 } else { total / (recall_values.len() as f64) }; ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 Ok(RecallMetrics { recall_k, diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index 39804cb1b..a289b7571 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -465,16 +465,7 @@ where Arc::new(datafiles::load_dataset(datafiles::BinFile(&topk.queries))?); // compute the maximum value of k used in any search -<<<<<<< HEAD - let max_k = topk - .runs - .iter() - .map(|run| run.recall_k) - .max() - .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; -======= let max_k = topk.max_k(); ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 let groundtruth = datafiles::load_groundtruth(datafiles::BinFile(&topk.groundtruth), Some(max_k))?; diff --git a/diskann-benchmark/src/utils/datafiles.rs b/diskann-benchmark/src/utils/datafiles.rs index bf3cb3f54..abfe06a7d 100644 --- a/diskann-benchmark/src/utils/datafiles.rs +++ b/diskann-benchmark/src/utils/datafiles.rs @@ -116,15 +116,9 @@ pub(crate) fn load_groundtruth(path: BinFile<'_>, k: Option) -> anyhow::R file.read_exact(groundtruth_slice)?; if let Some(expected_k) = k { -<<<<<<< HEAD - if groundtruth.ncols() != expected_k { - return Err(anyhow::anyhow!( - "Each row of groundtruth must have length {} (got {})", -======= if groundtruth.ncols() < expected_k { return Err(anyhow::anyhow!( "Each row of groundtruth must have at least {} neighbors (got {})", ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 expected_k, groundtruth.ncols() )); From 54ee01bbbe919de015cd6c1c48cb68a865e2f04e Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Tue, 2 Jun 2026 20:01:59 +0000 Subject: [PATCH 03/10] fix conflict --- diskann-benchmark/src/backend/index/benchmarks.rs | 9 --------- 1 file changed, 9 deletions(-) diff --git a/diskann-benchmark/src/backend/index/benchmarks.rs b/diskann-benchmark/src/backend/index/benchmarks.rs index a289b7571..e0684923a 100644 --- a/diskann-benchmark/src/backend/index/benchmarks.rs +++ b/diskann-benchmark/src/backend/index/benchmarks.rs @@ -691,16 +691,7 @@ where let managed = Managed::new(max_points, consolidate_threshold, managed_stream); // compute the maximum value of k used in any search -<<<<<<< HEAD - let max_k = topk - .runs - .iter() - .map(|run| run.recall_k) - .max() - .ok_or_else(|| anyhow::anyhow!("No runs provided in Topk phase"))?; -======= let max_k = topk.max_k(); ->>>>>>> 4f70a82133bf43e6bece7572e611cb4dedf2c475 let layered = bigann::WithData::new(managed, data, queries, move |path| { Ok(Box::new(datafiles::load_groundtruth( From acd1f8c03f89e13a988f5af55f8904556ee140af Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Thu, 11 Jun 2026 20:52:27 +0000 Subject: [PATCH 04/10] remove k_value from Knn within diskann crate --- diskann/src/graph/misc.rs | 36 ++++++++++++-- diskann/src/graph/search/diverse_search.rs | 4 +- diskann/src/graph/search/knn_search.rs | 55 ++++----------------- diskann/src/graph/test/cases/grid_insert.rs | 5 +- diskann/src/graph/test/cases/grid_search.rs | 7 +-- diskann/src/graph/test/cases/inline.rs | 12 ++--- diskann/src/graph/test/cases/multihop.rs | 19 ++++--- 7 files changed, 66 insertions(+), 72 deletions(-) diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs index 067bd2d21..712f3047d 100644 --- a/diskann/src/graph/misc.rs +++ b/diskann/src/graph/misc.rs @@ -3,6 +3,10 @@ * Licensed under the MIT license. */ +use std::num::NonZeroUsize; +use thiserror::Error; +use crate::{ANNError, ANNErrorKind}; + // enum used to return the status of the vector that `consolidate_vector` // was called on: Deleted if the vector was already deleted, and Complete // if the vector was not deleted (and thus is now consolidated) @@ -31,6 +35,24 @@ pub enum InplaceDeleteMethod { OneHop, } +/// Error type for [`DiverseSearchParams`] parameter validation. +#[cfg(feature = "experimental_diversity_search")] +#[derive(Debug, Error)] +pub enum DiverseSearchError { + #[error("original k_value cannot be zero")] + OriginalKZero, + #[error("diverse k_value cannot be zero")] + DiverseKZero, +} + +#[cfg(feature = "experimental_diversity_search")] +impl From for ANNError { + #[track_caller] + fn from(err: DiverseSearchError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + // Parameters for diverse search #[cfg(feature = "experimental_diversity_search")] #[derive(Clone, Debug)] @@ -39,7 +61,8 @@ where P: crate::neighbor::AttributeValueProvider, { pub diverse_attribute_id: usize, - pub diverse_results_k: usize, + pub diverse_results_k: NonZeroUsize, + pub original_k_value: NonZeroUsize, pub attribute_provider: std::sync::Arc

, } @@ -51,13 +74,18 @@ where pub fn new( diverse_attribute_id: usize, diverse_results_k: usize, + original_k_value: usize, attribute_provider: std::sync::Arc

, - ) -> Self { - Self { + ) -> Result { + let diverse_results_k = NonZeroUsize::new(diverse_results_k).ok_or(DiverseSearchError::DiverseKZero)?; + let original_k_value = NonZeroUsize::new(original_k_value).ok_or(DiverseSearchError::OriginalKZero)?; + + Ok(Self { diverse_attribute_id, diverse_results_k, + original_k_value, attribute_provider, - } + }) } } diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index f87315711..2bb92c115 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -72,8 +72,8 @@ where let attribute_provider = self.diverse_params.attribute_provider.clone(); let diverse_queue = DiverseNeighborQueue::new( self.inner.l_value().get(), - self.inner.k_value(), - self.diverse_params.diverse_results_k, + self.diverse_params.original_k_value, + self.diverse_params.diverse_results_k.get(), attribute_provider, ); diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index fae00ab84..0c929f2c0 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -26,12 +26,8 @@ use crate::{ /// Error type for [`Knn`] parameter validation. #[derive(Debug, Error)] pub enum KnnSearchError { - #[error("l_value ({l_value}) cannot be less than k_value ({k_value})")] - LLessThanK { l_value: usize, k_value: usize }, #[error("beam width cannot be zero")] BeamWidthZero, - #[error("k_value cannot be zero")] - KZero, #[error("l_value cannot be zero")] LZero, } @@ -60,7 +56,6 @@ impl From for ANNError { /// /// # Parameters /// -/// - `k_value`: Number of nearest neighbors to return /// - `l_value`: Search list size (larger values improve recall at cost of latency) /// - `beam_width`: Optional parallel exploration width /// @@ -69,13 +64,11 @@ impl From for ANNError { /// ```ignore /// use diskann::graph::{search::Knn, Search}; /// -/// let params = Knn::new(10, 100, None)?; +/// let params = Knn::new(100, None)?; /// let stats = index.search(params, &strategy, &context, &query, &mut output).await?; /// ``` #[derive(Debug, Clone, Copy)] pub struct Knn { - /// Number of results to return (k in k-NN). - k_value: NonZeroUsize, /// Search list size - controls accuracy vs speed tradeoff. l_value: NonZeroUsize, /// Beam width for parallel graph exploration (defaults to 1). @@ -89,21 +82,10 @@ impl Knn { /// /// # Errors /// - /// Returns an error if `k_value` is zero, `l_value` is zero, - /// `l_value < k_value`, or if `beam_width` is `Some(0)`. - pub fn new( - k_value: usize, - l_value: usize, - beam_width: Option, - ) -> Result { - let k_value = NonZeroUsize::new(k_value).ok_or(KnnSearchError::KZero)?; + /// Returns an error if `l_value` is zero, + /// or if `beam_width` is `Some(0)`. + pub fn new(l_value: usize, beam_width: Option) -> Result { let l_value = NonZeroUsize::new(l_value).ok_or(KnnSearchError::LZero)?; - if k_value > l_value { - return Err(KnnSearchError::LLessThanK { - l_value: l_value.get(), - k_value: k_value.get(), - }); - } const ONE: NonZeroUsize = NonZeroUsize::new(1).unwrap(); let beam_width = match beam_width { @@ -112,21 +94,14 @@ impl Knn { }; Ok(Self { - k_value, l_value, beam_width, }) } /// Create parameters with default beam width. - pub fn new_default(k_value: usize, l_value: usize) -> Result { - Self::new(k_value, l_value, None) - } - - /// Returns the number of results to return (k in k-NN). - #[inline] - pub fn k_value(&self) -> NonZeroUsize { - self.k_value + pub fn new_default(l_value: usize) -> Result { + Self::new(l_value, None) } /// Returns the search list size. @@ -299,25 +274,15 @@ mod tests { #[test] fn test_knn_search_validation() { // Valid - assert!(Knn::new(10, 100, None).is_ok()); - assert!(Knn::new(10, 100, Some(4)).is_ok()); - assert!(Knn::new(10, 10, None).is_ok()); // k == l is valid - - // Invalid: k = 0 - assert!(matches!(Knn::new(0, 100, None), Err(KnnSearchError::KZero))); + assert!(Knn::new(100, None).is_ok()); + assert!(Knn::new(100, Some(4)).is_ok()); // Invalid: l = 0 - assert!(matches!(Knn::new(10, 0, None), Err(KnnSearchError::LZero))); - - // Invalid: l < k - assert!(matches!( - Knn::new(100, 10, None), - Err(KnnSearchError::LLessThanK { .. }) - )); + assert!(matches!(Knn::new(0, None), Err(KnnSearchError::LZero))); // Invalid: zero beam_width assert!(matches!( - Knn::new(10, 100, Some(0)), + Knn::new(100, Some(0)), Err(KnnSearchError::BeamWidthZero) )); } diff --git a/diskann/src/graph/test/cases/grid_insert.rs b/diskann/src/graph/test/cases/grid_insert.rs index 0c0e24724..ed26f970a 100644 --- a/diskann/src/graph/test/cases/grid_insert.rs +++ b/diskann/src/graph/test/cases/grid_insert.rs @@ -201,10 +201,11 @@ fn run_searches( let mut results = Vec::new(); for (query, desc) in queries { - let params = Knn::new(10, 10, None).unwrap(); + let k_value = 10; + let params = Knn::new(10, None).unwrap(); let search_ctx = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; + let mut neighbors = vec![Neighbor::::default(); k_value]; let graph::index::SearchStats { cmps, hops, diff --git a/diskann/src/graph/test/cases/grid_search.rs b/diskann/src/graph/test/cases/grid_search.rs index 7afbb2214..dab9001cb 100644 --- a/diskann/src/graph/test/cases/grid_search.rs +++ b/diskann/src/graph/test/cases/grid_search.rs @@ -127,10 +127,11 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. let index = setup_grid_search(grid, size); - let params = Knn::new(10, 10, Some(beam_width)).unwrap(); + let k_value = 10; + let params = Knn::new(10, Some(beam_width)).unwrap(); let context = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; + let mut neighbors = vec![Neighbor::::default(); k_value]; let graph::index::SearchStats { cmps, hops, @@ -147,7 +148,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { .unwrap(); assert!( - result_count.into_usize() <= params.k_value().get(), + result_count.into_usize() <= k_value, "grid search should not return more than the requested number of neighbors", ); diff --git a/diskann/src/graph/test/cases/inline.rs b/diskann/src/graph/test/cases/inline.rs index 40fddbc96..f4bf105b7 100644 --- a/diskann/src/graph/test/cases/inline.rs +++ b/diskann/src/graph/test/cases/inline.rs @@ -348,7 +348,7 @@ fn run_inline_on_grid( adaptive_l: Option, ) -> InlineFilterBaseline { let rt = current_thread_runtime(); - let inline = InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), filter, adaptive_l); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), filter, adaptive_l); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -391,7 +391,7 @@ fn inline_search_returns_only_final_level_matches() { let filter = LevelLabelProvider::new(); let k = 8; let l = 32; - let inline = InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), &filter, None); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, None); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -448,7 +448,7 @@ fn inline_search_three_level_no_adaptive_l_with_l1_finds_no_matches() { let filter = LevelLabelProvider::new(); let k = 1; let l = 1; - let inline = InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), &filter, None); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, None); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -501,7 +501,7 @@ fn inline_search_three_level_adaptive_l_with_l1_finds_matches() { let l = 1; let adaptive_l = AdaptiveL::new(1, 16.0).unwrap(); let inline = - InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), &filter, Some(adaptive_l)); + InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, Some(adaptive_l)); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -633,7 +633,7 @@ fn inline_search_reaches_matches_through_non_matching_nodes() { let k = 5; let l = 20; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let inline = InlineFilterSearch::new(search_params, &filter, None); let mut ids = vec![0u32; k]; @@ -725,7 +725,7 @@ fn inline_callback_filtering_grid() { let k = 20; let l = 40; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let inline = InlineFilterSearch::new(search_params, &filter, None); let mut ids = vec![0u32; k]; diff --git a/diskann/src/graph/test/cases/multihop.rs b/diskann/src/graph/test/cases/multihop.rs index 71cabd2e1..42a17be17 100644 --- a/diskann/src/graph/test/cases/multihop.rs +++ b/diskann/src/graph/test/cases/multihop.rs @@ -206,7 +206,6 @@ pub(super) fn build_1d_provider( fn run_internal( provider: &test_provider::Provider, query: &[f32], - k: usize, l: usize, max_degree: usize, filter: &dyn QueryLabelProvider, @@ -219,7 +218,7 @@ fn run_internal( let stats = crate::graph::search::multihop_search::multihop_search_internal( max_degree, - &Knn::new_default(k, l).unwrap(), + &Knn::new_default(l).unwrap(), &mut accessor, &mut scratch, &mut NoopSearchRecord::new(), @@ -264,7 +263,7 @@ fn accept_all_finds_all_nodes() { 3, ); - let (stats, results) = run_internal(&provider, &[1.5], 3, 10, 3, &AcceptAll); + let (stats, results) = run_internal(&provider, &[1.5], 10, 3, &AcceptAll); let ids: Vec = results.iter().map(|n| n.id).collect(); assert!(ids.contains(&0), "node 0 should be found"); @@ -305,7 +304,7 @@ fn reject_triggers_two_hop_expansion() { ); let filter = EvenFilter; - let (stats, results) = run_internal(&provider, &[2.0], 5, 20, 4, &filter); + let (stats, results) = run_internal(&provider, &[2.0], 20, 4, &filter); let ids: Vec = results.iter().map(|n| n.id).collect(); @@ -356,7 +355,7 @@ fn reject_all_yields_only_start() { 2, ); - let (_stats, results) = run_internal(&provider, &[0.5], 5, 10, 2, &RejectAll); + let (_stats, results) = run_internal(&provider, &[0.5], 10, 2, &RejectAll); // Only the start point should be in the best set — all one-hop neighbors // were rejected. Two-hop expansion goes through rejected nodes but RejectAll's @@ -392,7 +391,7 @@ fn terminate_stops_search_on_target() { ); let filter = TerminateOnTarget::new(2); - let (_stats, _results) = run_internal(&provider, &[0.0], 4, 10, 2, &filter); + let (_stats, _results) = run_internal(&provider, &[0.0], 10, 2, &filter); let hits = filter.hits(); assert!(hits.contains(&2), "target node 2 should have been visited"); @@ -430,7 +429,7 @@ fn block_and_adjust_modifies_results() { ); let filter = BlockAndAdjust::new(1, 2, 0.5); - let (_stats, results) = run_internal(&provider, &[0.0], 5, 10, 3, &filter); + let (_stats, results) = run_internal(&provider, &[0.0], 10, 3, &filter); let ids: Vec = results.iter().map(|n| n.id).collect(); @@ -556,7 +555,7 @@ fn two_hop_reaches_through_non_matching() { let k = 5; let l = 20; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let multihop = MultihopSearch::new(search_params, &filter); let mut ids = vec![0u32; k]; @@ -625,7 +624,7 @@ fn even_filtering_grid() { let k = 20; let l = 40; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let multihop = MultihopSearch::new(search_params, &filter); let mut ids = vec![0u32; k]; @@ -691,7 +690,7 @@ fn callback_filtering_grid() { let k = 20; let l = 40; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let multihop = MultihopSearch::new(search_params, &filter); let mut ids = vec![0u32; k]; From 02023483c46f1a0a8207e75116e96d4468a31dd0 Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 12 Jun 2026 15:21:45 +0000 Subject: [PATCH 05/10] integrate changes into benchmark --- .../src/search/graph/inline.rs | 14 ++-- .../src/search/graph/knn.rs | 65 +++++++++++++++---- .../src/search/graph/mod.rs | 2 +- .../src/search/graph/multihop.rs | 15 +++-- diskann-benchmark/src/backend/index/result.rs | 2 +- .../src/backend/index/search/knn.rs | 19 +++--- .../src/search/provider/disk_provider.rs | 3 +- diskann/src/graph/misc.rs | 17 +++-- diskann/src/graph/test/cases/inline.rs | 3 +- 9 files changed, 94 insertions(+), 46 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/inline.rs b/diskann-benchmark-core/src/search/graph/inline.rs index f07bbc200..1874f3f85 100644 --- a/diskann-benchmark-core/src/search/graph/inline.rs +++ b/diskann-benchmark-core/src/search/graph/inline.rs @@ -12,7 +12,7 @@ use diskann::{ }; use diskann_utils::{future::AsyncFriendly, views::Matrix}; -use crate::search::{self, Search, graph::Strategy}; +use crate::search::{self, Search, graph::KnnWrapper, graph::Strategy}; /// A built-in helper for benchmarking filtered K-nearest neighbors search /// using the inline search method. @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy}; /// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same /// aggregator as [`search::graph::knn::KNN`]). /// -/// The provided implementation of [`Search`] accepts [`graph::search::Knn`] +/// The provided implementation of [`Search`] accepts [`KnnWrapper`] /// and returns [`search::graph::knn::Metrics`] as additional output. #[derive(Debug)] pub struct InlineFilterSearch @@ -93,7 +93,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::Knn; + type Parameters = KnnWrapper; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -115,7 +115,7 @@ where { let context = DP::Context::default(); let inline_search = graph::search::InlineFilterSearch::new( - *parameters, + parameters.knn, &*self.labels[index], self.adaptive_l.clone(), ); @@ -192,7 +192,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( inline.clone(), - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -220,11 +220,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), setup.clone(), ), search::Run::new( - graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 15).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index f430ed099..d7981f13e 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -5,10 +5,11 @@ //! A built-in helper for benchmarking K-nearest neighbors. -use std::sync::Arc; +use std::{num::NonZeroUsize, sync::Arc}; +use thiserror::Error; use diskann::{ - ANNResult, + ANNError, ANNResult, graph::{self, glue}, provider, }; @@ -30,7 +31,7 @@ use crate::{ /// the latter. Result aggregation for [`search::search_all`] is provided /// by the [`Aggregator`] type. /// -/// The provided implementation of [`Search`] accepts [`graph::search::Knn`] +/// The provided implementation of [`Search`] accepts [`KnnWrapper`] /// and returns [`Metrics`] as additional output. #[derive(Debug)] pub struct KNN @@ -93,7 +94,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::Knn; + type Parameters = KnnWrapper; type Output = Metrics; fn num_queries(&self) -> usize { @@ -114,7 +115,7 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let knn_search = *parameters; + let knn_search = parameters.knn; let stats = self .index .search( @@ -133,6 +134,48 @@ where } } +#[derive(Debug, Error)] +pub enum KnnWrapperError { + #[error("k_value cannot be zero")] + KZero, + #[error("l_value must be at least k_value")] + LLessThanK, + #[error("invalid KNN parameters")] + InvalidKnnParameters, +} + +impl From for ANNError { + #[track_caller] + fn from(err: KnnWrapperError) -> Self { + ANNError::opaque(err) + } +} + +/// A wrapper for the [`graph::search::Knn`] struct that also includes the `k` value. +#[derive(Debug, Copy, Clone)] +pub struct KnnWrapper { + k_value: NonZeroUsize, + pub knn: graph::search::Knn, +} + +impl KnnWrapper { + /// Construct a new [`KnnWrapper`]. + pub fn new(k_value: usize, l_value: usize) -> Result { + let k_value = NonZeroUsize::new(k_value).ok_or(KnnWrapperError::KZero)?; + if l_value < k_value.get() { + return Err(KnnWrapperError::LLessThanK); + } + + let knn = graph::search::Knn::new(l_value, None) + .map_err(|_| KnnWrapperError::InvalidKnnParameters)?; + Ok(Self { k_value, knn }) + } + + pub fn k_value(&self) -> NonZeroUsize { + self.k_value + } +} + /// An [`search::Aggregate`]d summary of multiple [`KNN`] search runs /// returned by the provided [`Aggregator`]. /// @@ -144,7 +187,7 @@ pub struct Summary { pub setup: search::Setup, /// The [`Search::Parameters`] used for the batch of runs. - pub parameters: graph::search::Knn, + pub parameters: KnnWrapper, /// The end-to-end latency for each repetition in the batch. pub end_to_end_latencies: Vec, @@ -212,7 +255,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl search::Aggregate for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -220,7 +263,7 @@ where fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result

{ // Compute the recall using just the first result. @@ -317,7 +360,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( knn.clone(), - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -341,11 +384,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), setup.clone(), ), search::Run::new( - graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 15).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark-core/src/search/graph/mod.rs b/diskann-benchmark-core/src/search/graph/mod.rs index 8063f0875..01bda088c 100644 --- a/diskann-benchmark-core/src/search/graph/mod.rs +++ b/diskann-benchmark-core/src/search/graph/mod.rs @@ -11,7 +11,7 @@ pub mod range; pub mod strategy; pub use inline::InlineFilterSearch; -pub use knn::KNN; +pub use knn::{KNN, KnnWrapper}; pub use multihop::MultiHop; pub use range::Range; diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 0bde74c1f..d602f88e5 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -12,7 +12,7 @@ use diskann::{ }; use diskann_utils::{future::AsyncFriendly, views::Matrix}; -use crate::search::{self, Search, graph::Strategy}; +use crate::search::{self, Search, graph::KnnWrapper, graph::Strategy}; /// A built-in helper for benchmarking filtered K-nearest neighbors search /// using the multi-hop search method. @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy}; /// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same /// aggregator as [`search::graph::knn::KNN`]). /// -/// The provided implementation of [`Search`] accepts [`graph::search::Knn`] +/// The provided implementation of [`Search`] accepts [`KnnWrapper`] /// and returns [`search::graph::knn::Metrics`] as additional output. #[derive(Debug)] pub struct MultiHop @@ -90,7 +90,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::Knn; + type Parameters = KnnWrapper; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -111,7 +111,8 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let multihop_search = graph::search::MultihopSearch::new(*parameters, &*self.labels[index]); + let multihop_search = + graph::search::MultihopSearch::new(parameters.knn, &*self.labels[index]); let stats = self .index .search( @@ -182,7 +183,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( multihop.clone(), - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -210,11 +211,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), setup.clone(), ), search::Run::new( - graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 15).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index a68650d14..75d9f54f8 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -120,7 +120,7 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), search_n: parameters.k_value().get(), - search_l: parameters.l_value().get(), + search_l: parameters.knn.l_value().get(), qps, search_latencies: end_to_end_latencies, mean_latencies, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 2138ee156..c28f5b73f 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -5,8 +5,8 @@ use std::{num::NonZeroUsize, sync::Arc}; -use diskann_benchmark_core::recall::GroundTruthMode; use diskann_benchmark_core::{self as benchmark_core, search as core_search}; +use diskann_benchmark_core::{recall::GroundTruthMode, search::graph::KnnWrapper}; use crate::{backend::index::result::SearchResults, inputs::graph_index::GraphSearch}; @@ -50,8 +50,7 @@ pub(crate) fn run( .search_l .iter() .map(|search_l| { - let search_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + let search_params = KnnWrapper::new(run.search_n, *search_l).unwrap(); core_search::Run::new(search_params, setup.clone()) }) @@ -64,7 +63,7 @@ pub(crate) fn run( Ok(all) } -type Run = core_search::Run; +type Run = core_search::Run; pub(crate) trait Knn { fn search_all( &self, @@ -84,13 +83,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::KNN: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, + Parameters = KnnWrapper, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -115,13 +114,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::MultiHop: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, + Parameters = KnnWrapper, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -146,13 +145,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::InlineFilterSearch: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, + Parameters = KnnWrapper, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 5e49cd4dd..567541c82 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -899,7 +899,6 @@ where let strategy = self.search_strategy(vector_filter); let timer = Instant::now(); - let k = k_value; let l = search_list_size as usize; let stats = if is_flat_search { self.runtime.block_on(self.flat_search( @@ -910,7 +909,7 @@ where &mut result_output_buffer, ))? } else { - let knn_search = Knn::new(k, l, beam_width)?; + let knn_search = Knn::new(l, beam_width)?; self.runtime.block_on(self.index.search( knn_search, &strategy, diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs index 712f3047d..821a02022 100644 --- a/diskann/src/graph/misc.rs +++ b/diskann/src/graph/misc.rs @@ -2,10 +2,15 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ +#[cfg(feature = "experimental_diversity_search")] +mod imports { + pub(super) use crate::{ANNError, ANNErrorKind}; + pub(super) use std::num::NonZeroUsize; + pub(super) use thiserror::Error; +} -use std::num::NonZeroUsize; -use thiserror::Error; -use crate::{ANNError, ANNErrorKind}; +#[cfg(feature = "experimental_diversity_search")] +use imports::*; // enum used to return the status of the vector that `consolidate_vector` // was called on: Deleted if the vector was already deleted, and Complete @@ -77,8 +82,10 @@ where original_k_value: usize, attribute_provider: std::sync::Arc

, ) -> Result { - let diverse_results_k = NonZeroUsize::new(diverse_results_k).ok_or(DiverseSearchError::DiverseKZero)?; - let original_k_value = NonZeroUsize::new(original_k_value).ok_or(DiverseSearchError::OriginalKZero)?; + let diverse_results_k = + NonZeroUsize::new(diverse_results_k).ok_or(DiverseSearchError::DiverseKZero)?; + let original_k_value = + NonZeroUsize::new(original_k_value).ok_or(DiverseSearchError::OriginalKZero)?; Ok(Self { diverse_attribute_id, diff --git a/diskann/src/graph/test/cases/inline.rs b/diskann/src/graph/test/cases/inline.rs index f4bf105b7..d142e2160 100644 --- a/diskann/src/graph/test/cases/inline.rs +++ b/diskann/src/graph/test/cases/inline.rs @@ -500,8 +500,7 @@ fn inline_search_three_level_adaptive_l_with_l1_finds_matches() { let k = 1; let l = 1; let adaptive_l = AdaptiveL::new(1, 16.0).unwrap(); - let inline = - InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, Some(adaptive_l)); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, Some(adaptive_l)); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; From 0191feeadfe93966952c04f967d33cf6103e6996 Mon Sep 17 00:00:00 2001 From: Magdalen Dobson Manohar <58752279+magdalendobson@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:36:57 -0400 Subject: [PATCH 06/10] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark-core/src/search/graph/knn.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index d7981f13e..df957a60e 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -138,8 +138,8 @@ where pub enum KnnWrapperError { #[error("k_value cannot be zero")] KZero, - #[error("l_value must be at least k_value")] - LLessThanK, + #[error("l_value ({l_value}) must be at least k_value ({k_value})")] + LLessThanK { l_value: usize, k_value: usize }, #[error("invalid KNN parameters")] InvalidKnnParameters, } From 2f28f50292d182d29d0ce7648e26f9c34c0da2a6 Mon Sep 17 00:00:00 2001 From: Magdalen Dobson Manohar <58752279+magdalendobson@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:37:13 -0400 Subject: [PATCH 07/10] Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> --- diskann-benchmark-core/src/search/graph/knn.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index df957a60e..512706ffa 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -163,7 +163,10 @@ impl KnnWrapper { pub fn new(k_value: usize, l_value: usize) -> Result { let k_value = NonZeroUsize::new(k_value).ok_or(KnnWrapperError::KZero)?; if l_value < k_value.get() { - return Err(KnnWrapperError::LLessThanK); + return Err(KnnWrapperError::LLessThanK { + l_value, + k_value: k_value.get(), + }); } let knn = graph::search::Knn::new(l_value, None) From 7ee7938db985f3e07119866bcb3bb8e365bf7fee Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 12 Jun 2026 17:45:22 +0000 Subject: [PATCH 08/10] move changes to bftree and disk --- diskann-bftree/src/provider.rs | 44 ++++++----- .../src/search/provider/disk_provider.rs | 76 +++++++++---------- 2 files changed, 64 insertions(+), 56 deletions(-) diff --git a/diskann-bftree/src/provider.rs b/diskann-bftree/src/provider.rs index 3adaae0f0..5946d4460 100644 --- a/diskann-bftree/src/provider.rs +++ b/diskann-bftree/src/provider.rs @@ -2033,9 +2033,10 @@ mod tests { } let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2048,8 +2049,9 @@ mod tests { .unwrap(); assert_eq!( - res.result_count, 5, - "there are 15 points and we're asking for 5, we expect 5" + res.result_count, k as u32, + "there are 15 points and we're asking for {}, we expect {}", + 5, k ); assert_eq!(neighbors[0].id, 3); } @@ -2084,9 +2086,10 @@ mod tests { .unwrap(); let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2099,8 +2102,9 @@ mod tests { .unwrap(); assert_eq!( - res.result_count, 5, - "there are 15 points and we're asking for 5, we expect 5" + res.result_count, k as u32, + "there are 15 points and we're asking for {}, we expect {}", + 5, k ); let neighbor_ids: Vec = neighbors.iter().map(|n| n.id).collect(); for expected in 1u32..=5 { @@ -2134,9 +2138,10 @@ mod tests { .unwrap(); let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2148,7 +2153,7 @@ mod tests { .await .unwrap(); - assert_eq!(res.result_count, 5); + assert_eq!(res.result_count, k as u32); let neighbor_ids: Vec = neighbors.iter().map(|n| n.id).collect(); assert!(!neighbor_ids.contains(&2u32)); assert!(!neighbor_ids.contains(&4u32)); @@ -2203,9 +2208,10 @@ mod tests { } let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2218,8 +2224,9 @@ mod tests { .unwrap(); assert_eq!( - res.result_count, 5, - "there are 15 points and we're asking for 5, we expect 5" + res.result_count, k as u32, + "there are 15 points and we're asking for {}, we expect {}", + 5, k ); assert_eq!(neighbors[0].id, 3); } @@ -2259,9 +2266,10 @@ mod tests { .unwrap(); let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2273,7 +2281,7 @@ mod tests { .await .unwrap(); - assert_eq!(res.result_count, 5); + assert_eq!(res.result_count, k as u32); let neighbor_ids: Vec = neighbors.iter().map(|n| n.id).collect(); assert!(!neighbor_ids.contains(&2u32)); assert!(!neighbor_ids.contains(&4u32)); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 567541c82..ddf458735 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -909,6 +909,12 @@ where &mut result_output_buffer, ))? } else { + if search_list_size < k_value as u32 { + return Err(ANNError::message( + diskann::ANNErrorKind::IndexError, + "search list size must be at least as large as the number of results requested", + )); + } let knn_search = Knn::new(l, beam_width)?; self.runtime.block_on(self.index.search( knn_search, @@ -958,10 +964,7 @@ fn ensure_vertex_loaded>( mod disk_provider_tests { use crate::test_utils::{GraphDataF32VectorU32Data, GraphDataF32VectorUnitData}; use diskann::{ - graph::{ - search::{record::VisitedSearchRecord, Knn}, - KnnSearchError, - }, + graph::search::{record::VisitedSearchRecord, Knn}, utils::IntoUsize, ANNErrorKind, }; @@ -1409,15 +1412,8 @@ mod disk_provider_tests { "index_path is not correct" ); - // Test error case: l < k - let res = Knn::new_default(20, 10); - assert!(res.is_err()); - assert_eq!( - >::into(res.unwrap_err()).kind(), - ANNErrorKind::IndexError - ); // Test error case: beam_width = 0 - let res = Knn::new(10, 10, Some(0)); + let res = Knn::new(10, Some(0)); assert!(res.is_err()); let search_engine = @@ -1491,9 +1487,10 @@ mod disk_provider_tests { ); let query_vector: [f32; 128] = [1f32; 128]; - let mut indices = vec![0u32; 10]; - let mut distances = vec![0f32; 10]; - let mut associated_data = vec![(); 10]; + let k = 10; + let mut indices = vec![0u32; k]; + let mut distances = vec![0f32; k]; + let mut associated_data = vec![(); k]; let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices, @@ -1502,7 +1499,7 @@ mod disk_provider_tests { ); let strategy = search_engine.search_strategy(&|_| true); let mut search_record = VisitedSearchRecord::new(0); - let search_params = Knn::new(10, 10, Some(4)).unwrap(); + let search_params = Knn::new(10, Some(4)).unwrap(); let recorded_search = diskann::graph::search::RecordedKnn::new(search_params, &mut search_record); search_engine @@ -1528,11 +1525,10 @@ mod disk_provider_tests { assert_eq!(ids, &EXPECTED_NODES); - let return_list_size = 10; let search_list_size = 10; let result = search_engine.search( &query_vector, - return_list_size, + k as u32, search_list_size, Some(4), None, @@ -1541,8 +1537,8 @@ mod disk_provider_tests { assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); assert_eq!( - search_result.results.len() as u32, - return_list_size, + search_result.results.len(), + k, "Expected result count to match" ); assert_eq!( @@ -1613,9 +1609,10 @@ mod disk_provider_tests { // Wrap in Arc once to avoid cloning the HashMap later let attribute_provider = std::sync::Arc::new(attribute_provider); - let mut indices = vec![0u32; 10]; - let mut distances = vec![0f32; 10]; - let mut associated_data = vec![(); 10]; + let original_k = 10; + let mut indices = vec![0u32; original_k]; + let mut distances = vec![0f32; original_k]; + let mut associated_data = vec![(); original_k]; let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices, @@ -1628,10 +1625,12 @@ mod disk_provider_tests { let diverse_params = DiverseSearchParams::new( 0, // diverse_attribute_id 3, // diverse_results_k + original_k, attribute_provider.clone(), - ); + ) + .unwrap(); - let search_params = Knn::new(10, 20, None).unwrap(); + let search_params = Knn::new(20, None).unwrap(); let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params); let stats = search_engine @@ -1651,27 +1650,27 @@ mod disk_provider_tests { "Expected to get some results during diversity search" ); - let return_list_size = 10; let search_list_size = 20; let diverse_results_k = 1; let diverse_params = DiverseSearchParams::new( 0, // diverse_attribute_id diverse_results_k, + original_k, attribute_provider.clone(), - ); + ) + .unwrap(); // Test diverse search using the search API - let mut indices2 = vec![0u32; return_list_size as usize]; - let mut distances2 = vec![0f32; return_list_size as usize]; - let mut associated_data2 = vec![(); return_list_size as usize]; + let mut indices2 = vec![0u32; original_k]; + let mut distances2 = vec![0f32; original_k]; + let mut associated_data2 = vec![(); original_k]; let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new( &mut indices2, &mut distances2, &mut associated_data2, ); let strategy2 = search_engine.search_strategy(&|_| true); - let search_params2 = - Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap(); + let search_params2 = Knn::new(search_list_size as usize, None).unwrap(); let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params); let stats = search_engine @@ -1691,9 +1690,9 @@ mod disk_provider_tests { "Expected diversity search to return results" ); assert!( - stats.result_count <= return_list_size, + stats.result_count <= original_k as u32, "Expected result count to be <= {}", - return_list_size + original_k ); // Verify that we got some results @@ -1943,9 +1942,10 @@ mod disk_provider_tests { ); let query_vector: [f32; 128] = [1f32; 128]; - let mut indices = vec![0u32; 10]; - let mut distances = vec![0f32; 10]; - let mut associated_data = vec![(); 10]; + let k = 10; + let mut indices = vec![0u32; k]; + let mut distances = vec![0f32; k]; + let mut associated_data = vec![(); k]; let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices, @@ -1956,7 +1956,7 @@ mod disk_provider_tests { let strategy = search_engine.search_strategy(&|_| true); let mut search_record = VisitedSearchRecord::new(0); - let search_params = Knn::new(10, 10, Some(4)).unwrap(); + let search_params = Knn::new(10, Some(4)).unwrap(); let recorded_search = diskann::graph::search::RecordedKnn::new(search_params, &mut search_record); search_engine From 10940adfc595b5fc2445ddf24ba11223906f13bb Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 12 Jun 2026 18:28:14 +0000 Subject: [PATCH 09/10] address changes in more crates --- diskann-garnet/src/lib.rs | 2 -- diskann-providers/src/index/diskann_async.rs | 28 ++++++++++---------- diskann-providers/src/index/wrapped_async.rs | 2 +- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index f60e1eaea..accbeacf9 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -597,7 +597,6 @@ pub unsafe extern "C" fn search_vector( ); let params = match search::Knn::new( - output_distances_len, search_exploration_factor as usize, None, ) { @@ -662,7 +661,6 @@ pub unsafe extern "C" fn search_element( ); let params = match search::Knn::new( - output_distances_len, search_exploration_factor as usize, None, ) { diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index be512cfd0..791fa4361 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -344,7 +344,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let graph_search = - graph::search::Knn::new_default(parameters.search_k, parameters.search_l).unwrap(); + graph::search::Knn::new_default(parameters.search_l).unwrap(); index .search( graph_search, @@ -1433,7 +1433,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -1451,7 +1451,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Quantized Search index .search( @@ -1692,7 +1692,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let graph_search = - graph::search::Knn::new_default(top_k, search_l).unwrap(); + graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -1711,7 +1711,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let graph_search = - graph::search::Knn::new_default(top_k, search_l).unwrap(); + graph::search::Knn::new_default(search_l).unwrap(); // Quantized Search index .search( @@ -1798,7 +1798,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, top_k).unwrap(); + let graph_search = graph::search::Knn::new_default(top_k).unwrap(); // Quantized Search index .search( @@ -1912,7 +1912,7 @@ pub(crate) mod tests { // Full Precision Search. let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); index .search(graph_search, &FullPrecision, ctx, query, &mut output) .await @@ -1924,7 +1924,7 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); index .search(graph_search, &strategy, ctx, query, &mut output) @@ -2028,7 +2028,7 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); index .search(graph_search, &strategy, ctx, query, &mut output) @@ -2114,7 +2114,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -2693,7 +2693,7 @@ pub(crate) mod tests { let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -2960,11 +2960,11 @@ pub(crate) mod tests { let diverse_params = diskann::graph::DiverseSearchParams::new( 0, // diverse_attribute_id diverse_results_k, + return_list_size, attribute_provider.clone(), - ); + ).unwrap(); let search_params = diskann::graph::search::Knn::new( - return_list_size, search_list_size, None, // beam_width ) @@ -3121,7 +3121,7 @@ pub(crate) mod tests { let mut ids = vec![0; top_k]; let mut distances = vec![0.0; top_k]; let ctx = DefaultContext; - let search_params = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let search_params = graph::search::Knn::new_default(search_l).unwrap(); for i in 0..query_count { let query_vector = &queries[i * VECTORS_DIMENSION..(i + 1) * VECTORS_DIMENSION]; diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 29e989aef..a8f3c016c 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -773,7 +773,7 @@ mod tests { let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let query = train_data.row(0); - let kind = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let kind = graph::search::Knn::new_default(search_l).unwrap(); let stats = loaded .search(kind, &FullPrecision, &DefaultContext, query, &mut output) .unwrap(); From e0ca6f65a88bb29990a82583541351271137637a Mon Sep 17 00:00:00 2001 From: Magdalen Manohar Date: Fri, 12 Jun 2026 21:40:57 +0000 Subject: [PATCH 10/10] fmt --- diskann-garnet/src/lib.rs | 10 ++-------- diskann-providers/src/index/diskann_async.rs | 12 +++++------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index accbeacf9..37658954e 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -596,10 +596,7 @@ pub unsafe extern "C" fn search_vector( output_distances_len, ); - let params = match search::Knn::new( - search_exploration_factor as usize, - None, - ) { + let params = match search::Knn::new(search_exploration_factor as usize, None) { Ok(params) => params, Err(_) => return -1, }; @@ -660,10 +657,7 @@ pub unsafe extern "C" fn search_element( output_distances_len, ); - let params = match search::Knn::new( - search_exploration_factor as usize, - None, - ) { + let params = match search::Knn::new(search_exploration_factor as usize, None) { Ok(params) => params, Err(_) => return -1, }; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index 791fa4361..53e4b7988 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -343,8 +343,7 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = - graph::search::Knn::new_default(parameters.search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(parameters.search_l).unwrap(); index .search( graph_search, @@ -1691,8 +1690,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = - graph::search::Knn::new_default(search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -1710,8 +1708,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = - graph::search::Knn::new_default(search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Quantized Search index .search( @@ -2962,7 +2959,8 @@ pub(crate) mod tests { diverse_results_k, return_list_size, attribute_provider.clone(), - ).unwrap(); + ) + .unwrap(); let search_params = diskann::graph::search::Knn::new( search_list_size,