diff --git a/diskann-benchmark-core/src/search/graph/inline.rs b/diskann-benchmark-core/src/search/graph/inline.rs index f07bbc200..1874f3f85 100644 --- a/diskann-benchmark-core/src/search/graph/inline.rs +++ b/diskann-benchmark-core/src/search/graph/inline.rs @@ -12,7 +12,7 @@ use diskann::{ }; use diskann_utils::{future::AsyncFriendly, views::Matrix}; -use crate::search::{self, Search, graph::Strategy}; +use crate::search::{self, Search, graph::KnnWrapper, graph::Strategy}; /// A built-in helper for benchmarking filtered K-nearest neighbors search /// using the inline search method. @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy}; /// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same /// aggregator as [`search::graph::knn::KNN`]). /// -/// The provided implementation of [`Search`] accepts [`graph::search::Knn`] +/// The provided implementation of [`Search`] accepts [`KnnWrapper`] /// and returns [`search::graph::knn::Metrics`] as additional output. #[derive(Debug)] pub struct InlineFilterSearch @@ -93,7 +93,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::Knn; + type Parameters = KnnWrapper; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -115,7 +115,7 @@ where { let context = DP::Context::default(); let inline_search = graph::search::InlineFilterSearch::new( - *parameters, + parameters.knn, &*self.labels[index], self.adaptive_l.clone(), ); @@ -192,7 +192,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( inline.clone(), - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -220,11 +220,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), setup.clone(), ), search::Run::new( - graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 15).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index f430ed099..512706ffa 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -5,10 +5,11 @@ //! A built-in helper for benchmarking K-nearest neighbors. -use std::sync::Arc; +use std::{num::NonZeroUsize, sync::Arc}; +use thiserror::Error; use diskann::{ - ANNResult, + ANNError, ANNResult, graph::{self, glue}, provider, }; @@ -30,7 +31,7 @@ use crate::{ /// the latter. Result aggregation for [`search::search_all`] is provided /// by the [`Aggregator`] type. /// -/// The provided implementation of [`Search`] accepts [`graph::search::Knn`] +/// The provided implementation of [`Search`] accepts [`KnnWrapper`] /// and returns [`Metrics`] as additional output. #[derive(Debug)] pub struct KNN @@ -93,7 +94,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::Knn; + type Parameters = KnnWrapper; type Output = Metrics; fn num_queries(&self) -> usize { @@ -114,7 +115,7 @@ where O: graph::SearchOutputBuffer + Send, { let context = DP::Context::default(); - let knn_search = *parameters; + let knn_search = parameters.knn; let stats = self .index .search( @@ -133,6 +134,51 @@ where } } +#[derive(Debug, Error)] +pub enum KnnWrapperError { + #[error("k_value cannot be zero")] + KZero, + #[error("l_value ({l_value}) must be at least k_value ({k_value})")] + LLessThanK { l_value: usize, k_value: usize }, + #[error("invalid KNN parameters")] + InvalidKnnParameters, +} + +impl From for ANNError { + #[track_caller] + fn from(err: KnnWrapperError) -> Self { + ANNError::opaque(err) + } +} + +/// A wrapper for the [`graph::search::Knn`] struct that also includes the `k` value. +#[derive(Debug, Copy, Clone)] +pub struct KnnWrapper { + k_value: NonZeroUsize, + pub knn: graph::search::Knn, +} + +impl KnnWrapper { + /// Construct a new [`KnnWrapper`]. + pub fn new(k_value: usize, l_value: usize) -> Result { + let k_value = NonZeroUsize::new(k_value).ok_or(KnnWrapperError::KZero)?; + if l_value < k_value.get() { + return Err(KnnWrapperError::LLessThanK { + l_value, + k_value: k_value.get(), + }); + } + + let knn = graph::search::Knn::new(l_value, None) + .map_err(|_| KnnWrapperError::InvalidKnnParameters)?; + Ok(Self { k_value, knn }) + } + + pub fn k_value(&self) -> NonZeroUsize { + self.k_value + } +} + /// An [`search::Aggregate`]d summary of multiple [`KNN`] search runs /// returned by the provided [`Aggregator`]. /// @@ -144,7 +190,7 @@ pub struct Summary { pub setup: search::Setup, /// The [`Search::Parameters`] used for the batch of runs. - pub parameters: graph::search::Knn, + pub parameters: KnnWrapper, /// The end-to-end latency for each repetition in the batch. pub end_to_end_latencies: Vec, @@ -212,7 +258,7 @@ impl<'a, I> Aggregator<'a, I> { } } -impl search::Aggregate for Aggregator<'_, I> +impl search::Aggregate for Aggregator<'_, I> where I: crate::recall::RecallCompatible, { @@ -220,7 +266,7 @@ where fn aggregate( &mut self, - run: search::Run, + run: search::Run, mut results: Vec>, ) -> anyhow::Result { // Compute the recall using just the first result. @@ -317,7 +363,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( knn.clone(), - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -341,11 +387,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), setup.clone(), ), search::Run::new( - graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 15).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark-core/src/search/graph/mod.rs b/diskann-benchmark-core/src/search/graph/mod.rs index 8063f0875..01bda088c 100644 --- a/diskann-benchmark-core/src/search/graph/mod.rs +++ b/diskann-benchmark-core/src/search/graph/mod.rs @@ -11,7 +11,7 @@ pub mod range; pub mod strategy; pub use inline::InlineFilterSearch; -pub use knn::KNN; +pub use knn::{KNN, KnnWrapper}; pub use multihop::MultiHop; pub use range::Range; diff --git a/diskann-benchmark-core/src/search/graph/multihop.rs b/diskann-benchmark-core/src/search/graph/multihop.rs index 87bad7d01..50ccd034b 100644 --- a/diskann-benchmark-core/src/search/graph/multihop.rs +++ b/diskann-benchmark-core/src/search/graph/multihop.rs @@ -12,7 +12,7 @@ use diskann::{ }; use diskann_utils::{future::AsyncFriendly, views::Matrix}; -use crate::search::{self, Search, graph::Strategy}; +use crate::search::{self, Search, graph::KnnWrapper, graph::Strategy}; /// A built-in helper for benchmarking filtered K-nearest neighbors search /// using the multi-hop search method. @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy}; /// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same /// aggregator as [`search::graph::knn::KNN`]). /// -/// The provided implementation of [`Search`] accepts [`graph::search::Knn`] +/// The provided implementation of [`Search`] accepts [`KnnWrapper`] /// and returns [`search::graph::knn::Metrics`] as additional output. #[derive(Debug)] pub struct MultiHop @@ -90,7 +90,7 @@ where T: AsyncFriendly + Clone, { type Id = DP::ExternalId; - type Parameters = graph::search::Knn; + type Parameters = KnnWrapper; type Output = super::knn::Metrics; fn num_queries(&self) -> usize { @@ -112,7 +112,7 @@ where { let context = DP::Context::default(); let multihop_search = - graph::search::MultihopFilterSearch::new(*parameters, &*self.labels[index]); + graph::search::MultihopFilterSearch::new(parameters.knn, &*self.labels[index]); let stats = self .index .search( @@ -183,7 +183,7 @@ mod tests { let rt = crate::tokio::runtime(2).unwrap(); let results = search::search( multihop.clone(), - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), NonZeroUsize::new(2).unwrap(), &rt, ) @@ -211,11 +211,11 @@ mod tests { // Try the aggregated strategy. let parameters = [ search::Run::new( - graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 10).unwrap(), setup.clone(), ), search::Run::new( - graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(), + KnnWrapper::new(nearest_neighbors, 15).unwrap(), setup.clone(), ), ]; diff --git a/diskann-benchmark/src/backend/index/result.rs b/diskann-benchmark/src/backend/index/result.rs index a68650d14..75d9f54f8 100644 --- a/diskann-benchmark/src/backend/index/result.rs +++ b/diskann-benchmark/src/backend/index/result.rs @@ -120,7 +120,7 @@ impl SearchResults { Self { num_tasks: setup.tasks.into(), search_n: parameters.k_value().get(), - search_l: parameters.l_value().get(), + search_l: parameters.knn.l_value().get(), qps, search_latencies: end_to_end_latencies, mean_latencies, diff --git a/diskann-benchmark/src/backend/index/search/knn.rs b/diskann-benchmark/src/backend/index/search/knn.rs index 2138ee156..c28f5b73f 100644 --- a/diskann-benchmark/src/backend/index/search/knn.rs +++ b/diskann-benchmark/src/backend/index/search/knn.rs @@ -5,8 +5,8 @@ use std::{num::NonZeroUsize, sync::Arc}; -use diskann_benchmark_core::recall::GroundTruthMode; use diskann_benchmark_core::{self as benchmark_core, search as core_search}; +use diskann_benchmark_core::{recall::GroundTruthMode, search::graph::KnnWrapper}; use crate::{backend::index::result::SearchResults, inputs::graph_index::GraphSearch}; @@ -50,8 +50,7 @@ pub(crate) fn run( .search_l .iter() .map(|search_l| { - let search_params = - diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap(); + let search_params = KnnWrapper::new(run.search_n, *search_l).unwrap(); core_search::Run::new(search_params, setup.clone()) }) @@ -64,7 +63,7 @@ pub(crate) fn run( Ok(all) } -type Run = core_search::Run; +type Run = core_search::Run; pub(crate) trait Knn { fn search_all( &self, @@ -84,13 +83,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::KNN: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, + Parameters = KnnWrapper, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -115,13 +114,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::MultiHop: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, + Parameters = KnnWrapper, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, @@ -146,13 +145,13 @@ where DP: diskann::provider::DataProvider, core_search::graph::InlineFilterSearch: core_search::Search< Id = DP::InternalId, - Parameters = diskann::graph::search::Knn, + Parameters = KnnWrapper, Output = core_search::graph::knn::Metrics, >, { fn search_all( &self, - parameters: Vec>, + parameters: Vec>, groundtruth: &dyn benchmark_core::recall::Rows, recall_k: usize, recall_n: usize, diff --git a/diskann-bftree/src/provider.rs b/diskann-bftree/src/provider.rs index 3adaae0f0..5946d4460 100644 --- a/diskann-bftree/src/provider.rs +++ b/diskann-bftree/src/provider.rs @@ -2033,9 +2033,10 @@ mod tests { } let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2048,8 +2049,9 @@ mod tests { .unwrap(); assert_eq!( - res.result_count, 5, - "there are 15 points and we're asking for 5, we expect 5" + res.result_count, k as u32, + "there are 15 points and we're asking for {}, we expect {}", + 5, k ); assert_eq!(neighbors[0].id, 3); } @@ -2084,9 +2086,10 @@ mod tests { .unwrap(); let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2099,8 +2102,9 @@ mod tests { .unwrap(); assert_eq!( - res.result_count, 5, - "there are 15 points and we're asking for 5, we expect 5" + res.result_count, k as u32, + "there are 15 points and we're asking for {}, we expect {}", + 5, k ); let neighbor_ids: Vec = neighbors.iter().map(|n| n.id).collect(); for expected in 1u32..=5 { @@ -2134,9 +2138,10 @@ mod tests { .unwrap(); let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2148,7 +2153,7 @@ mod tests { .await .unwrap(); - assert_eq!(res.result_count, 5); + assert_eq!(res.result_count, k as u32); let neighbor_ids: Vec = neighbors.iter().map(|n| n.id).collect(); assert!(!neighbor_ids.contains(&2u32)); assert!(!neighbor_ids.contains(&4u32)); @@ -2203,9 +2208,10 @@ mod tests { } let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2218,8 +2224,9 @@ mod tests { .unwrap(); assert_eq!( - res.result_count, 5, - "there are 15 points and we're asking for 5, we expect 5" + res.result_count, k as u32, + "there are 15 points and we're asking for {}, we expect {}", + 5, k ); assert_eq!(neighbors[0].id, 3); } @@ -2259,9 +2266,10 @@ mod tests { .unwrap(); let query = vec![3.0; 5]; - let params = Knn::new(5, 10, None).unwrap(); + let params = Knn::new(10, None).unwrap(); - let mut neighbors = vec![Neighbor::::default(); 5]; + let k = 5; + let mut neighbors = vec![Neighbor::::default(); k]; let res = index .search( params, @@ -2273,7 +2281,7 @@ mod tests { .await .unwrap(); - assert_eq!(res.result_count, 5); + assert_eq!(res.result_count, k as u32); let neighbor_ids: Vec = neighbors.iter().map(|n| n.id).collect(); assert!(!neighbor_ids.contains(&2u32)); assert!(!neighbor_ids.contains(&4u32)); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index 5e49cd4dd..ddf458735 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -899,7 +899,6 @@ where let strategy = self.search_strategy(vector_filter); let timer = Instant::now(); - let k = k_value; let l = search_list_size as usize; let stats = if is_flat_search { self.runtime.block_on(self.flat_search( @@ -910,7 +909,13 @@ where &mut result_output_buffer, ))? } else { - let knn_search = Knn::new(k, l, beam_width)?; + if search_list_size < k_value as u32 { + return Err(ANNError::message( + diskann::ANNErrorKind::IndexError, + "search list size must be at least as large as the number of results requested", + )); + } + let knn_search = Knn::new(l, beam_width)?; self.runtime.block_on(self.index.search( knn_search, &strategy, @@ -959,10 +964,7 @@ fn ensure_vertex_loaded>( mod disk_provider_tests { use crate::test_utils::{GraphDataF32VectorU32Data, GraphDataF32VectorUnitData}; use diskann::{ - graph::{ - search::{record::VisitedSearchRecord, Knn}, - KnnSearchError, - }, + graph::search::{record::VisitedSearchRecord, Knn}, utils::IntoUsize, ANNErrorKind, }; @@ -1410,15 +1412,8 @@ mod disk_provider_tests { "index_path is not correct" ); - // Test error case: l < k - let res = Knn::new_default(20, 10); - assert!(res.is_err()); - assert_eq!( - >::into(res.unwrap_err()).kind(), - ANNErrorKind::IndexError - ); // Test error case: beam_width = 0 - let res = Knn::new(10, 10, Some(0)); + let res = Knn::new(10, Some(0)); assert!(res.is_err()); let search_engine = @@ -1492,9 +1487,10 @@ mod disk_provider_tests { ); let query_vector: [f32; 128] = [1f32; 128]; - let mut indices = vec![0u32; 10]; - let mut distances = vec![0f32; 10]; - let mut associated_data = vec![(); 10]; + let k = 10; + let mut indices = vec![0u32; k]; + let mut distances = vec![0f32; k]; + let mut associated_data = vec![(); k]; let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices, @@ -1503,7 +1499,7 @@ mod disk_provider_tests { ); let strategy = search_engine.search_strategy(&|_| true); let mut search_record = VisitedSearchRecord::new(0); - let search_params = Knn::new(10, 10, Some(4)).unwrap(); + let search_params = Knn::new(10, Some(4)).unwrap(); let recorded_search = diskann::graph::search::RecordedKnn::new(search_params, &mut search_record); search_engine @@ -1529,11 +1525,10 @@ mod disk_provider_tests { assert_eq!(ids, &EXPECTED_NODES); - let return_list_size = 10; let search_list_size = 10; let result = search_engine.search( &query_vector, - return_list_size, + k as u32, search_list_size, Some(4), None, @@ -1542,8 +1537,8 @@ mod disk_provider_tests { assert!(result.is_ok(), "Expected search to succeed"); let search_result = result.unwrap(); assert_eq!( - search_result.results.len() as u32, - return_list_size, + search_result.results.len(), + k, "Expected result count to match" ); assert_eq!( @@ -1614,9 +1609,10 @@ mod disk_provider_tests { // Wrap in Arc once to avoid cloning the HashMap later let attribute_provider = std::sync::Arc::new(attribute_provider); - let mut indices = vec![0u32; 10]; - let mut distances = vec![0f32; 10]; - let mut associated_data = vec![(); 10]; + let original_k = 10; + let mut indices = vec![0u32; original_k]; + let mut distances = vec![0f32; original_k]; + let mut associated_data = vec![(); original_k]; let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices, @@ -1629,10 +1625,12 @@ mod disk_provider_tests { let diverse_params = DiverseSearchParams::new( 0, // diverse_attribute_id 3, // diverse_results_k + original_k, attribute_provider.clone(), - ); + ) + .unwrap(); - let search_params = Knn::new(10, 20, None).unwrap(); + let search_params = Knn::new(20, None).unwrap(); let diverse_search = diskann::graph::search::Diverse::new(search_params, diverse_params); let stats = search_engine @@ -1652,27 +1650,27 @@ mod disk_provider_tests { "Expected to get some results during diversity search" ); - let return_list_size = 10; let search_list_size = 20; let diverse_results_k = 1; let diverse_params = DiverseSearchParams::new( 0, // diverse_attribute_id diverse_results_k, + original_k, attribute_provider.clone(), - ); + ) + .unwrap(); // Test diverse search using the search API - let mut indices2 = vec![0u32; return_list_size as usize]; - let mut distances2 = vec![0f32; return_list_size as usize]; - let mut associated_data2 = vec![(); return_list_size as usize]; + let mut indices2 = vec![0u32; original_k]; + let mut distances2 = vec![0f32; original_k]; + let mut associated_data2 = vec![(); original_k]; let mut result_output_buffer2 = search_output_buffer::IdDistanceAssociatedData::new( &mut indices2, &mut distances2, &mut associated_data2, ); let strategy2 = search_engine.search_strategy(&|_| true); - let search_params2 = - Knn::new(return_list_size as usize, search_list_size as usize, None).unwrap(); + let search_params2 = Knn::new(search_list_size as usize, None).unwrap(); let diverse_search2 = diskann::graph::search::Diverse::new(search_params2, diverse_params); let stats = search_engine @@ -1692,9 +1690,9 @@ mod disk_provider_tests { "Expected diversity search to return results" ); assert!( - stats.result_count <= return_list_size, + stats.result_count <= original_k as u32, "Expected result count to be <= {}", - return_list_size + original_k ); // Verify that we got some results @@ -1944,9 +1942,10 @@ mod disk_provider_tests { ); let query_vector: [f32; 128] = [1f32; 128]; - let mut indices = vec![0u32; 10]; - let mut distances = vec![0f32; 10]; - let mut associated_data = vec![(); 10]; + let k = 10; + let mut indices = vec![0u32; k]; + let mut distances = vec![0f32; k]; + let mut associated_data = vec![(); k]; let mut result_output_buffer = search_output_buffer::IdDistanceAssociatedData::new( &mut indices, @@ -1957,7 +1956,7 @@ mod disk_provider_tests { let strategy = search_engine.search_strategy(&|_| true); let mut search_record = VisitedSearchRecord::new(0); - let search_params = Knn::new(10, 10, Some(4)).unwrap(); + let search_params = Knn::new(10, Some(4)).unwrap(); let recorded_search = diskann::graph::search::RecordedKnn::new(search_params, &mut search_record); search_engine diff --git a/diskann-garnet/src/lib.rs b/diskann-garnet/src/lib.rs index f60e1eaea..37658954e 100644 --- a/diskann-garnet/src/lib.rs +++ b/diskann-garnet/src/lib.rs @@ -596,11 +596,7 @@ pub unsafe extern "C" fn search_vector( output_distances_len, ); - let params = match search::Knn::new( - output_distances_len, - search_exploration_factor as usize, - None, - ) { + let params = match search::Knn::new(search_exploration_factor as usize, None) { Ok(params) => params, Err(_) => return -1, }; @@ -661,11 +657,7 @@ pub unsafe extern "C" fn search_element( output_distances_len, ); - let params = match search::Knn::new( - output_distances_len, - search_exploration_factor as usize, - None, - ) { + let params = match search::Knn::new(search_exploration_factor as usize, None) { Ok(params) => params, Err(_) => return -1, }; diff --git a/diskann-providers/src/index/diskann_async.rs b/diskann-providers/src/index/diskann_async.rs index be512cfd0..53e4b7988 100644 --- a/diskann-providers/src/index/diskann_async.rs +++ b/diskann-providers/src/index/diskann_async.rs @@ -343,8 +343,7 @@ pub(crate) mod tests { let mut distances = vec![0.0; parameters.search_k]; let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = - graph::search::Knn::new_default(parameters.search_k, parameters.search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(parameters.search_l).unwrap(); index .search( graph_search, @@ -1433,7 +1432,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -1451,7 +1450,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Quantized Search index .search( @@ -1691,8 +1690,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = - graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -1710,8 +1708,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = - graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Quantized Search index .search( @@ -1798,7 +1795,7 @@ pub(crate) mod tests { { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, top_k).unwrap(); + let graph_search = graph::search::Knn::new_default(top_k).unwrap(); // Quantized Search index .search( @@ -1912,7 +1909,7 @@ pub(crate) mod tests { // Full Precision Search. let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); index .search(graph_search, &FullPrecision, ctx, query, &mut output) .await @@ -1924,7 +1921,7 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); index .search(graph_search, &strategy, ctx, query, &mut output) @@ -2028,7 +2025,7 @@ pub(crate) mod tests { let strategy = inmem::spherical::Quantized::search( diskann_quantization::spherical::iface::QueryLayout::FourBitTransposed, ); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); index .search(graph_search, &strategy, ctx, query, &mut output) @@ -2114,7 +2111,7 @@ pub(crate) mod tests { let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -2693,7 +2690,7 @@ pub(crate) mod tests { let gt = groundtruth(queries.as_view(), query, |a, b| SquaredL2::evaluate(a, b)); let mut result_output_buffer = search_output_buffer::IdDistance::new(&mut ids, &mut distances); - let graph_search = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let graph_search = graph::search::Knn::new_default(search_l).unwrap(); // Full Precision Search. index .search( @@ -2960,11 +2957,12 @@ pub(crate) mod tests { let diverse_params = diskann::graph::DiverseSearchParams::new( 0, // diverse_attribute_id diverse_results_k, + return_list_size, attribute_provider.clone(), - ); + ) + .unwrap(); let search_params = diskann::graph::search::Knn::new( - return_list_size, search_list_size, None, // beam_width ) @@ -3121,7 +3119,7 @@ pub(crate) mod tests { let mut ids = vec![0; top_k]; let mut distances = vec![0.0; top_k]; let ctx = DefaultContext; - let search_params = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let search_params = graph::search::Knn::new_default(search_l).unwrap(); for i in 0..query_count { let query_vector = &queries[i * VECTORS_DIMENSION..(i + 1) * VECTORS_DIMENSION]; diff --git a/diskann-providers/src/index/wrapped_async.rs b/diskann-providers/src/index/wrapped_async.rs index 29e989aef..a8f3c016c 100644 --- a/diskann-providers/src/index/wrapped_async.rs +++ b/diskann-providers/src/index/wrapped_async.rs @@ -773,7 +773,7 @@ mod tests { let mut output = search_output_buffer::IdDistance::new(&mut ids, &mut distances); let query = train_data.row(0); - let kind = graph::search::Knn::new_default(top_k, search_l).unwrap(); + let kind = graph::search::Knn::new_default(search_l).unwrap(); let stats = loaded .search(kind, &FullPrecision, &DefaultContext, query, &mut output) .unwrap(); diff --git a/diskann/src/graph/misc.rs b/diskann/src/graph/misc.rs index 067bd2d21..821a02022 100644 --- a/diskann/src/graph/misc.rs +++ b/diskann/src/graph/misc.rs @@ -2,6 +2,15 @@ * Copyright (c) Microsoft Corporation. * Licensed under the MIT license. */ +#[cfg(feature = "experimental_diversity_search")] +mod imports { + pub(super) use crate::{ANNError, ANNErrorKind}; + pub(super) use std::num::NonZeroUsize; + pub(super) use thiserror::Error; +} + +#[cfg(feature = "experimental_diversity_search")] +use imports::*; // enum used to return the status of the vector that `consolidate_vector` // was called on: Deleted if the vector was already deleted, and Complete @@ -31,6 +40,24 @@ pub enum InplaceDeleteMethod { OneHop, } +/// Error type for [`DiverseSearchParams`] parameter validation. +#[cfg(feature = "experimental_diversity_search")] +#[derive(Debug, Error)] +pub enum DiverseSearchError { + #[error("original k_value cannot be zero")] + OriginalKZero, + #[error("diverse k_value cannot be zero")] + DiverseKZero, +} + +#[cfg(feature = "experimental_diversity_search")] +impl From for ANNError { + #[track_caller] + fn from(err: DiverseSearchError) -> Self { + Self::new(ANNErrorKind::IndexError, err) + } +} + // Parameters for diverse search #[cfg(feature = "experimental_diversity_search")] #[derive(Clone, Debug)] @@ -39,7 +66,8 @@ where P: crate::neighbor::AttributeValueProvider, { pub diverse_attribute_id: usize, - pub diverse_results_k: usize, + pub diverse_results_k: NonZeroUsize, + pub original_k_value: NonZeroUsize, pub attribute_provider: std::sync::Arc

, } @@ -51,13 +79,20 @@ where pub fn new( diverse_attribute_id: usize, diverse_results_k: usize, + original_k_value: usize, attribute_provider: std::sync::Arc

, - ) -> Self { - Self { + ) -> Result { + let diverse_results_k = + NonZeroUsize::new(diverse_results_k).ok_or(DiverseSearchError::DiverseKZero)?; + let original_k_value = + NonZeroUsize::new(original_k_value).ok_or(DiverseSearchError::OriginalKZero)?; + + Ok(Self { diverse_attribute_id, diverse_results_k, + original_k_value, attribute_provider, - } + }) } } diff --git a/diskann/src/graph/search/diverse_search.rs b/diskann/src/graph/search/diverse_search.rs index f87315711..2bb92c115 100644 --- a/diskann/src/graph/search/diverse_search.rs +++ b/diskann/src/graph/search/diverse_search.rs @@ -72,8 +72,8 @@ where let attribute_provider = self.diverse_params.attribute_provider.clone(); let diverse_queue = DiverseNeighborQueue::new( self.inner.l_value().get(), - self.inner.k_value(), - self.diverse_params.diverse_results_k, + self.diverse_params.original_k_value, + self.diverse_params.diverse_results_k.get(), attribute_provider, ); diff --git a/diskann/src/graph/search/knn_search.rs b/diskann/src/graph/search/knn_search.rs index fae00ab84..0c929f2c0 100644 --- a/diskann/src/graph/search/knn_search.rs +++ b/diskann/src/graph/search/knn_search.rs @@ -26,12 +26,8 @@ use crate::{ /// Error type for [`Knn`] parameter validation. #[derive(Debug, Error)] pub enum KnnSearchError { - #[error("l_value ({l_value}) cannot be less than k_value ({k_value})")] - LLessThanK { l_value: usize, k_value: usize }, #[error("beam width cannot be zero")] BeamWidthZero, - #[error("k_value cannot be zero")] - KZero, #[error("l_value cannot be zero")] LZero, } @@ -60,7 +56,6 @@ impl From for ANNError { /// /// # Parameters /// -/// - `k_value`: Number of nearest neighbors to return /// - `l_value`: Search list size (larger values improve recall at cost of latency) /// - `beam_width`: Optional parallel exploration width /// @@ -69,13 +64,11 @@ impl From for ANNError { /// ```ignore /// use diskann::graph::{search::Knn, Search}; /// -/// let params = Knn::new(10, 100, None)?; +/// let params = Knn::new(100, None)?; /// let stats = index.search(params, &strategy, &context, &query, &mut output).await?; /// ``` #[derive(Debug, Clone, Copy)] pub struct Knn { - /// Number of results to return (k in k-NN). - k_value: NonZeroUsize, /// Search list size - controls accuracy vs speed tradeoff. l_value: NonZeroUsize, /// Beam width for parallel graph exploration (defaults to 1). @@ -89,21 +82,10 @@ impl Knn { /// /// # Errors /// - /// Returns an error if `k_value` is zero, `l_value` is zero, - /// `l_value < k_value`, or if `beam_width` is `Some(0)`. - pub fn new( - k_value: usize, - l_value: usize, - beam_width: Option, - ) -> Result { - let k_value = NonZeroUsize::new(k_value).ok_or(KnnSearchError::KZero)?; + /// Returns an error if `l_value` is zero, + /// or if `beam_width` is `Some(0)`. + pub fn new(l_value: usize, beam_width: Option) -> Result { let l_value = NonZeroUsize::new(l_value).ok_or(KnnSearchError::LZero)?; - if k_value > l_value { - return Err(KnnSearchError::LLessThanK { - l_value: l_value.get(), - k_value: k_value.get(), - }); - } const ONE: NonZeroUsize = NonZeroUsize::new(1).unwrap(); let beam_width = match beam_width { @@ -112,21 +94,14 @@ impl Knn { }; Ok(Self { - k_value, l_value, beam_width, }) } /// Create parameters with default beam width. - pub fn new_default(k_value: usize, l_value: usize) -> Result { - Self::new(k_value, l_value, None) - } - - /// Returns the number of results to return (k in k-NN). - #[inline] - pub fn k_value(&self) -> NonZeroUsize { - self.k_value + pub fn new_default(l_value: usize) -> Result { + Self::new(l_value, None) } /// Returns the search list size. @@ -299,25 +274,15 @@ mod tests { #[test] fn test_knn_search_validation() { // Valid - assert!(Knn::new(10, 100, None).is_ok()); - assert!(Knn::new(10, 100, Some(4)).is_ok()); - assert!(Knn::new(10, 10, None).is_ok()); // k == l is valid - - // Invalid: k = 0 - assert!(matches!(Knn::new(0, 100, None), Err(KnnSearchError::KZero))); + assert!(Knn::new(100, None).is_ok()); + assert!(Knn::new(100, Some(4)).is_ok()); // Invalid: l = 0 - assert!(matches!(Knn::new(10, 0, None), Err(KnnSearchError::LZero))); - - // Invalid: l < k - assert!(matches!( - Knn::new(100, 10, None), - Err(KnnSearchError::LLessThanK { .. }) - )); + assert!(matches!(Knn::new(0, None), Err(KnnSearchError::LZero))); // Invalid: zero beam_width assert!(matches!( - Knn::new(10, 100, Some(0)), + Knn::new(100, Some(0)), Err(KnnSearchError::BeamWidthZero) )); } diff --git a/diskann/src/graph/test/cases/grid_insert.rs b/diskann/src/graph/test/cases/grid_insert.rs index 0c0e24724..ed26f970a 100644 --- a/diskann/src/graph/test/cases/grid_insert.rs +++ b/diskann/src/graph/test/cases/grid_insert.rs @@ -201,10 +201,11 @@ fn run_searches( let mut results = Vec::new(); for (query, desc) in queries { - let params = Knn::new(10, 10, None).unwrap(); + let k_value = 10; + let params = Knn::new(10, None).unwrap(); let search_ctx = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; + let mut neighbors = vec![Neighbor::::default(); k_value]; let graph::index::SearchStats { cmps, hops, diff --git a/diskann/src/graph/test/cases/grid_search.rs b/diskann/src/graph/test/cases/grid_search.rs index 7afbb2214..dab9001cb 100644 --- a/diskann/src/graph/test/cases/grid_search.rs +++ b/diskann/src/graph/test/cases/grid_search.rs @@ -127,10 +127,11 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { // are correct. let index = setup_grid_search(grid, size); - let params = Knn::new(10, 10, Some(beam_width)).unwrap(); + let k_value = 10; + let params = Knn::new(10, Some(beam_width)).unwrap(); let context = test_provider::Context::new(); - let mut neighbors = vec![Neighbor::::default(); params.k_value().get()]; + let mut neighbors = vec![Neighbor::::default(); k_value]; let graph::index::SearchStats { cmps, hops, @@ -147,7 +148,7 @@ fn _grid_search(grid: Grid, size: usize, mut parent: TestPath<'_>) { .unwrap(); assert!( - result_count.into_usize() <= params.k_value().get(), + result_count.into_usize() <= k_value, "grid search should not return more than the requested number of neighbors", ); diff --git a/diskann/src/graph/test/cases/inline.rs b/diskann/src/graph/test/cases/inline.rs index eaaebc357..36f9d7c0f 100644 --- a/diskann/src/graph/test/cases/inline.rs +++ b/diskann/src/graph/test/cases/inline.rs @@ -348,7 +348,7 @@ fn run_inline_on_grid( adaptive_l: Option, ) -> InlineFilterBaseline { let rt = current_thread_runtime(); - let inline = InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), filter, adaptive_l); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), filter, adaptive_l); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -391,7 +391,7 @@ fn inline_search_returns_only_final_level_matches() { let filter = LevelLabelProvider::new(); let k = 8; let l = 32; - let inline = InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), &filter, None); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, None); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -448,7 +448,7 @@ fn inline_search_three_level_no_adaptive_l_with_l1_finds_no_matches() { let filter = LevelLabelProvider::new(); let k = 1; let l = 1; - let inline = InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), &filter, None); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, None); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -500,8 +500,7 @@ fn inline_search_three_level_adaptive_l_with_l1_finds_matches() { let k = 1; let l = 1; let adaptive_l = AdaptiveL::new(1, 16.0).unwrap(); - let inline = - InlineFilterSearch::new(Knn::new_default(k, l).unwrap(), &filter, Some(adaptive_l)); + let inline = InlineFilterSearch::new(Knn::new_default(l).unwrap(), &filter, Some(adaptive_l)); let mut ids = vec![0u32; k]; let mut distances = vec![0.0f32; k]; @@ -633,7 +632,7 @@ fn inline_search_reaches_matches_through_non_matching_nodes() { let k = 5; let l = 20; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let inline = InlineFilterSearch::new(search_params, &filter, None); let mut ids = vec![0u32; k]; @@ -725,7 +724,7 @@ fn inline_callback_filtering_grid() { let k = 20; let l = 40; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let inline = InlineFilterSearch::new(search_params, &filter, None); let mut ids = vec![0u32; k]; diff --git a/diskann/src/graph/test/cases/multihop.rs b/diskann/src/graph/test/cases/multihop.rs index dd7688bd4..0f877ce7d 100644 --- a/diskann/src/graph/test/cases/multihop.rs +++ b/diskann/src/graph/test/cases/multihop.rs @@ -206,7 +206,6 @@ pub(super) fn build_1d_provider( fn run_internal( provider: &test_provider::Provider, query: &[f32], - k: usize, l: usize, max_degree: usize, filter: &dyn QueryLabelProvider, @@ -219,7 +218,7 @@ fn run_internal( let stats = crate::graph::search::multihop_filter_search::multihop_search_internal( max_degree, - &Knn::new_default(k, l).unwrap(), + &Knn::new_default(l).unwrap(), &mut accessor, &mut scratch, &mut NoopSearchRecord::new(), @@ -264,7 +263,7 @@ fn accept_all_finds_all_nodes() { 3, ); - let (stats, results) = run_internal(&provider, &[1.5], 3, 10, 3, &AcceptAll); + let (stats, results) = run_internal(&provider, &[1.5], 10, 3, &AcceptAll); let ids: Vec = results.iter().map(|n| n.id).collect(); assert!(ids.contains(&0), "node 0 should be found"); @@ -305,7 +304,7 @@ fn reject_triggers_two_hop_expansion() { ); let filter = EvenFilter; - let (stats, results) = run_internal(&provider, &[2.0], 5, 20, 4, &filter); + let (stats, results) = run_internal(&provider, &[2.0], 20, 4, &filter); let ids: Vec = results.iter().map(|n| n.id).collect(); @@ -356,7 +355,7 @@ fn reject_all_yields_only_start() { 2, ); - let (_stats, results) = run_internal(&provider, &[0.5], 5, 10, 2, &RejectAll); + let (_stats, results) = run_internal(&provider, &[0.5], 10, 2, &RejectAll); // Only the start point should be in the best set — all one-hop neighbors // were rejected. Two-hop expansion goes through rejected nodes but RejectAll's @@ -392,7 +391,7 @@ fn terminate_stops_search_on_target() { ); let filter = TerminateOnTarget::new(2); - let (_stats, _results) = run_internal(&provider, &[0.0], 4, 10, 2, &filter); + let (_stats, _results) = run_internal(&provider, &[0.0], 10, 2, &filter); let hits = filter.hits(); assert!(hits.contains(&2), "target node 2 should have been visited"); @@ -430,7 +429,7 @@ fn block_and_adjust_modifies_results() { ); let filter = BlockAndAdjust::new(1, 2, 0.5); - let (_stats, results) = run_internal(&provider, &[0.0], 5, 10, 3, &filter); + let (_stats, results) = run_internal(&provider, &[0.0], 10, 3, &filter); let ids: Vec = results.iter().map(|n| n.id).collect(); @@ -556,7 +555,7 @@ fn two_hop_reaches_through_non_matching() { let k = 5; let l = 20; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let multihop = MultihopFilterSearch::new(search_params, &filter); let mut ids = vec![0u32; k]; @@ -625,7 +624,7 @@ fn even_filtering_grid() { let k = 20; let l = 40; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let multihop = MultihopFilterSearch::new(search_params, &filter); let mut ids = vec![0u32; k]; @@ -691,7 +690,7 @@ fn callback_filtering_grid() { let k = 20; let l = 40; - let search_params = Knn::new_default(k, l).unwrap(); + let search_params = Knn::new_default(l).unwrap(); let multihop = MultihopFilterSearch::new(search_params, &filter); let mut ids = vec![0u32; k];