diff --git a/diskann-benchmark-core/src/search/graph/knn.rs b/diskann-benchmark-core/src/search/graph/knn.rs index 8c96b41a7..f1c074d04 100644 --- a/diskann-benchmark-core/src/search/graph/knn.rs +++ b/diskann-benchmark-core/src/search/graph/knn.rs @@ -23,7 +23,7 @@ use crate::{ }; /// A built-in helper for benchmarking the K-nearest neighbors method -/// [`graph::DiskANNIndex::search`]. +/// [`graph::DiskANNIndex::search`] with optional post-processing support. /// /// This is intended to be used in conjunction with [`search::search`] or /// [`search::search_all`] and provides some basic additional metrics for @@ -32,21 +32,31 @@ use crate::{ /// /// The provided implementation of [`Search`] accepts [`graph::search::Knn`] /// and returns [`Metrics`] as additional output. +/// +/// # Type Parameters +/// +/// - `DP`: The data provider type +/// - `T`: The query element type +/// - `S`: The search strategy type +/// - `PP`: Post-processor selector. Defaults to [`Defaulted`], which uses the +/// strategy's default post-processor. Use [`KNN::with_postprocessor`] to +/// supply an explicit post-processor. #[derive(Debug)] -pub struct KNN +pub struct KNN where DP: provider::DataProvider, { index: Arc>, queries: Arc>, strategy: Strategy, + post_processor: PP, } -impl KNN +impl KNN where DP: provider::DataProvider, { - /// Construct a new [`KNN`] searcher. + /// Construct a new [`KNN`] searcher using the strategy's default post-processor. /// /// If `strategy` is one of the container variants of [`Strategy`], its length /// must match the number of rows in `queries`. If this is the case, then the @@ -68,10 +78,98 @@ where index, queries, strategy, + post_processor: Defaulted, + })) + } +} + +impl KNN> +where + DP: provider::DataProvider, +{ + /// Construct a new [`KNN`] searcher with an explicit post-processor. + /// + /// # Errors + /// + /// Returns an error if the number of elements in `strategy` is not compatible with + /// the number of rows in `queries`. + pub fn with_postprocessor( + index: Arc>, + queries: Arc>, + strategy: Strategy, + post_processor: PP, + ) -> anyhow::Result> { + strategy.length_compatible(queries.nrows())?; + + Ok(Arc::new(Self { + index, + queries, + strategy, + post_processor: Forwarded(post_processor), })) } } +impl KNN +where + DP: provider::DataProvider, +{ + /// Access the index. + pub fn index(&self) -> &Arc> { + &self.index + } +} + +/// Resolves a post-processor for [`KNN`] given a search strategy. +/// +/// This trait lets [`KNN`] support both "use the strategy's default post-processor" +/// ([`Defaulted`]) and "use this explicit post-processor" ([`Forwarded`]) without +/// duplicating the search loop. +pub trait AsPostProcessor<'a, S, DP, T> +where + DP: provider::DataProvider, + S: glue::SearchStrategy<'a, DP, T>, +{ + /// The concrete post-processor used for a single search. + type Processor: glue::SearchPostProcess + Send + Sync; + + /// Construct the post-processor to use for a single search. + fn as_post_processor(&'a self, strategy: &'a S) -> Self::Processor; +} + +/// Marker indicating that [`KNN`] should use the strategy's default post-processor. +#[derive(Debug, Clone, Copy)] +pub struct Defaulted; + +impl<'a, S, DP, T> AsPostProcessor<'a, S, DP, T> for Defaulted +where + DP: provider::DataProvider, + S: glue::DefaultPostProcessor<'a, DP, T, DP::ExternalId>, +{ + type Processor = S::Processor; + + fn as_post_processor(&'a self, strategy: &'a S) -> Self::Processor { + strategy.default_post_processor() + } +} + +/// Wraps an explicit post-processor for use with [`KNN::with_postprocessor`]. +#[derive(Debug, Clone, Copy)] +pub struct Forwarded(PP); + +impl<'a, S, DP, T, PP> AsPostProcessor<'a, S, DP, T> for Forwarded +where + DP: provider::DataProvider, + S: glue::SearchStrategy<'a, DP, T>, + PP: glue::SearchPostProcess + Clone + AsyncFriendly, +{ + type Processor = PP; + + fn as_post_processor(&'a self, _strategy: &'a S) -> Self::Processor { + self.0.clone() + } +} + /// Additional metrics collected during [`KNN`] search. /// /// # Note @@ -86,10 +184,11 @@ pub struct Metrics { pub hops: u32, } -impl Search for KNN +impl Search for KNN where DP: provider::DataProvider, - S: for<'a> glue::DefaultSearchStrategy<'a, DP, &'a [T], DP::ExternalId> + Clone + AsyncFriendly, + S: for<'a> glue::SearchStrategy<'a, DP, &'a [T]> + Clone + AsyncFriendly, + PP: for<'a> AsPostProcessor<'a, S, DP, &'a [T]> + AsyncFriendly, graph::search::Knn: for<'a> graph::Search<'a, DP, S, &'a [T], Output = graph::index::SearchStats>, T: AsyncFriendly + Clone, @@ -117,11 +216,15 @@ where { let context = DP::Context::default(); let knn_search = *parameters; + let strategy = self.strategy.get(index)?; + let processor = self.post_processor.as_post_processor(strategy); + let stats = self .index - .search( + .search_with( knn_search, - self.strategy.get(index)?, + strategy, + processor, &context, self.queries.row(index), buffer, diff --git a/diskann-benchmark/example/async-determinant-diversity.json b/diskann-benchmark/example/async-determinant-diversity.json new file mode 100644 index 000000000..bdab4cf0b --- /dev/null +++ b/diskann-benchmark/example/async-determinant-diversity.json @@ -0,0 +1,48 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "graph-index-build", + "content": { + "source": { + "index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "max_degree": 32, + "l_build": 50, + "alpha": 1.2, + "backedge_ratio": 1.0, + "num_threads": 1, + "start_point_strategy": "medoid", + "num_insert_attempts": 1, + "saturate_inserts": false + }, + "search_phase": { + "search-type": "topk-determinant-diversity", + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "reps": 5, + "num_threads": [ + 1 + ], + "power": 2.0, + "eta": 0.01, + "runs": [ + { + "search_n": 20, + "search_l": [ + 20, + 30, + 40 + ], + "recall_k": 10 + } + ] + } + } + } + ] +} diff --git a/diskann-benchmark/example/disk-index-determinant-diversity.json b/diskann-benchmark/example/disk-index-determinant-diversity.json new file mode 100644 index 000000000..2962c1d97 --- /dev/null +++ b/diskann-benchmark/example/disk-index-determinant-diversity.json @@ -0,0 +1,42 @@ +{ + "search_directories": [ + "test_data/disk_index_search" + ], + "jobs": [ + { + "type": "disk-index", + "content": { + "source": { + "disk-index-source": "Build", + "data_type": "float32", + "data": "disk_index_siftsmall_learn_256pts_data.fbin", + "distance": "squared_l2", + "dim": 128, + "max_degree": 32, + "l_build": 50, + "num_threads": 1, + "build_ram_limit_gb": 2.0, + "num_pq_chunks": 128, + "quantization_type": "FP", + "save_path": "siftsmall_index_full_det_div" + }, + "search_phase": { + "queries": "disk_index_sample_query_10pts.fbin", + "groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin", + "search_list": [10, 20, 40], + "beam_width": 4, + "recall_at": 10, + "num_threads": 1, + "is_flat_search": false, + "distance": "squared_l2", + "vector_filters_file": null, + "post_processor": { + "type": "determinant-diversity", + "power": 2.0, + "eta": 1.0 + } + } + } + } + ] +} diff --git a/diskann-benchmark/src/disk_index/search.rs b/diskann-benchmark/src/disk_index/search.rs index 77acaa351..db7bccbdc 100644 --- a/diskann-benchmark/src/disk_index/search.rs +++ b/diskann-benchmark/src/disk_index/search.rs @@ -14,7 +14,8 @@ use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds}; use diskann_disk::{ data_model::{AdHoc, CachingStrategy}, search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + disk_provider::{DiskIndexSearcher, SearchPostProcessorKind}, + disk_vertex_provider_factory::DiskVertexProviderFactory, }, storage::disk_index_reader::DiskIndexReader, utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics}, @@ -32,7 +33,10 @@ use serde::{Deserialize, Serialize}; use crate::{ disk_index::json_spancollector::JsonSpanCollector, - inputs::disk::{DiskIndexLoad, DiskSearchPhase}, + inputs::{ + disk::{DiskIndexLoad, DiskSearchPhase}, + post_processor::TopkPostProcessor, + }, utils::{datafiles, SimilarityMeasure}, }; @@ -264,6 +268,12 @@ where zipped.for_each_in_pool( pool.as_ref(), |(((((q, vf), id_chunk), dist_chunk), stats), rc)| { + let post_processor = search_params.post_processor.as_ref().map_or( + SearchPostProcessorKind::None, + |TopkPostProcessor::DeterminantDiversity(params)| { + SearchPostProcessorKind::DeterminantDiversity(*params) + }, + ); let vector_filter = if search_params.vector_filters_file.is_none() { None } else { @@ -277,20 +287,21 @@ where l, Some(search_params.beam_width), vector_filter, + post_processor, search_params.is_flat_search, ) { Ok(search_result) => { *stats = search_result.stats.query_statistics; - *rc = search_result.results.len() as u32; - let actual_results = search_result - .results - .len() - .min(search_params.recall_at as usize); - for (i, result_item) in search_result - .results - .iter() - .take(actual_results) - .enumerate() + let base_count = (search_result.stats.result_count as usize) + .min(search_params.recall_at as usize) + .min(search_result.results.len()); + + *rc = base_count as u32; + id_chunk.fill(0); + dist_chunk.fill(0.0); + + for (i, result_item) in + search_result.results.iter().take(base_count).enumerate() { id_chunk[i] = result_item.vertex_id; dist_chunk[i] = result_item.distance; diff --git a/diskann-benchmark/src/index/benchmarks.rs b/diskann-benchmark/src/index/benchmarks.rs index f229557dd..0a66576a5 100644 --- a/diskann-benchmark/src/index/benchmarks.rs +++ b/diskann-benchmark/src/index/benchmarks.rs @@ -78,7 +78,8 @@ pub(crate) fn register_benchmarks(registry: &mut Registry) -> anyhow::Result<()> .search(plugins::Range) .search(plugins::TopkBetaFilter) .search(plugins::TopkMultihopFilter) - .search(plugins::TopkInlineFilter), + .search(plugins::TopkInlineFilter) + .search(plugins::DeterminantDiversity), )?; registry.register( @@ -442,6 +443,47 @@ impl Strategy { // Topk // //------// +impl search::Plugin, SearchPhase, Strategy> + for plugins::DeterminantDiversity +{ + fn is_match(&self, phase: &SearchPhase) -> bool { + plugins::DeterminantDiversity::is_match(phase) + } + + fn kind(&self) -> &'static str { + plugins::DeterminantDiversity::as_str() + } + + fn run( + &self, + index: Arc>>, + phase: &SearchPhase, + _strategy: &Strategy, + ) -> anyhow::Result { + let (phase, params) = plugins::DeterminantDiversity::get(phase)?; + + let queries = Arc::new(datafiles::load_dataset::(datafiles::BinFile( + &phase.queries, + ))?); + let groundtruth = datafiles::load_groundtruth( + datafiles::BinFile(&phase.groundtruth), + Some(phase.max_k()), + )?; + + let knn = benchmark_core::search::graph::KNN::with_postprocessor( + index, + queries, + benchmark_core::search::graph::Strategy::broadcast(common::FullPrecision), + inmem::DeterminantDiversity::new(params), + )?; + + let steps = search::knn::SearchSteps::new(phase.reps, &phase.num_threads, &phase.runs); + let results = search::knn::run(&knn, &groundtruth, steps)?; + + Ok(AggregatedSearchResults::Topk(results)) + } +} + impl search::Plugin> for plugins::Topk where DP: DataProvider + QueryType, @@ -454,11 +496,11 @@ where + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::Topk::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::Topk::as_str() } fn run( @@ -507,11 +549,11 @@ where + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::Range::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::Range::as_str() } fn run( @@ -557,11 +599,11 @@ where + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::TopkBetaFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::TopkBetaFilter::as_str() } fn run( @@ -622,11 +664,11 @@ where + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::TopkMultihopFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::TopkMultihopFilter::as_str() } fn run( @@ -680,11 +722,11 @@ where + AsyncFriendly, { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + plugins::TopkInlineFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + plugins::TopkInlineFilter::as_str() } fn run( diff --git a/diskann-benchmark/src/index/inmem/spherical.rs b/diskann-benchmark/src/index/inmem/spherical.rs index 37944e4ef..7a9994648 100644 --- a/diskann-benchmark/src/index/inmem/spherical.rs +++ b/diskann-benchmark/src/index/inmem/spherical.rs @@ -91,7 +91,7 @@ mod imp { }, inputs::{ exhaustive, - graph_index::{SearchPhase, SphericalQuantBuild}, + graph_index::{SearchPhase, SearchPhaseKind, SphericalQuantBuild}, }, utils::{ self, datafiles, @@ -366,11 +366,11 @@ mod imp { for search::plugins::Topk { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::Topk::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::Topk.as_str() } fn run( @@ -409,11 +409,11 @@ mod imp { for search::plugins::Range { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::Range::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::Range.as_str() } fn run( @@ -451,11 +451,11 @@ mod imp { for search::plugins::TopkBetaFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkBetaFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::TopkBetaFilter.as_str() } fn run( @@ -505,11 +505,11 @@ mod imp { for search::plugins::TopkMultihopFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkMultihopFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + SearchPhaseKind::TopkMultihopFilter.as_str() } fn run( @@ -555,11 +555,11 @@ mod imp { for search::plugins::TopkInlineFilter { fn is_match(&self, phase: &SearchPhase) -> bool { - Self::kind() == phase.kind() + search::plugins::TopkInlineFilter::is_match(phase) } fn kind(&self) -> &'static str { - Self::kind().as_str() + search::plugins::TopkInlineFilter::as_str() } fn run( diff --git a/diskann-benchmark/src/index/search/knn.rs b/diskann-benchmark/src/index/search/knn.rs index 8dd1c06b5..bc117b175 100644 --- a/diskann-benchmark/src/index/search/knn.rs +++ b/diskann-benchmark/src/index/search/knn.rs @@ -79,10 +79,10 @@ pub(crate) trait Knn { // Impls // /////////// -impl Knn for Arc> +impl Knn for Arc> where DP: diskann::provider::DataProvider, - core_search::graph::KNN: core_search::Search< + core_search::graph::KNN: core_search::Search< Id = DP::InternalId, Parameters = diskann::graph::search::Knn, Output = core_search::graph::knn::Metrics, diff --git a/diskann-benchmark/src/index/search/plugins.rs b/diskann-benchmark/src/index/search/plugins.rs index 7daf7c630..de050004f 100644 --- a/diskann-benchmark/src/index/search/plugins.rs +++ b/diskann-benchmark/src/index/search/plugins.rs @@ -36,8 +36,12 @@ use std::sync::Arc; use diskann::{graph::DiskANNIndex, provider::DataProvider}; use diskann_benchmark_runner::utils::fmt::{Delimit, Quote}; +use diskann_providers::model::graph::provider::DeterminantDiversityParams; -use crate::{index::result::AggregatedSearchResults, inputs::graph_index::SearchPhaseKind}; +use crate::{ + index::result::AggregatedSearchResults, + inputs::graph_index::{SearchPhase, TopkDeterminantDiversityPhase}, +}; /// A dyn-compatible search plugin for `DP`. /// @@ -143,9 +147,34 @@ where pub(crate) struct Topk; impl Topk { - /// Returns [`SearchPhaseKind::Topk`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Topk + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk" + } +} + +/// A search plugin for determinant-diversity top-k post-processing. +#[derive(Debug, Clone, Copy)] +pub(crate) struct DeterminantDiversity; + +impl DeterminantDiversity { + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_determinant_diversity().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk-determinant-diversity" + } + + pub(crate) fn get( + phase: &SearchPhase, + ) -> anyhow::Result<(&TopkDeterminantDiversityPhase, DeterminantDiversityParams)> { + let phase = phase.as_topk_determinant_diversity()?; + let params = DeterminantDiversityParams::new(phase.power, phase.eta)?; + Ok((phase, params)) } } @@ -154,9 +183,12 @@ impl Topk { pub(crate) struct Range; impl Range { - /// Returns [`SearchPhaseKind::Range`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::Range + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_range().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "range" } } @@ -165,9 +197,12 @@ impl Range { pub(crate) struct TopkBetaFilter; impl TopkBetaFilter { - /// Returns [`SearchPhaseKind::TopkBetaFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkBetaFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_beta_filter().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk + beta filter" } } @@ -176,9 +211,12 @@ impl TopkBetaFilter { pub(crate) struct TopkMultihopFilter; impl TopkMultihopFilter { - /// Returns [`SearchPhaseKind::TopkMultihopFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkMultihopFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_multihop_filter().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk + multihop filter" } } @@ -187,8 +225,11 @@ impl TopkMultihopFilter { pub(crate) struct TopkInlineFilter; impl TopkInlineFilter { - /// Returns [`SearchPhaseKind::TopkInlineFilter`]. - pub(crate) fn kind() -> SearchPhaseKind { - SearchPhaseKind::TopkInlineFilter + pub(crate) fn is_match(phase: &SearchPhase) -> bool { + phase.as_topk_inline_filter().is_ok() + } + + pub(crate) const fn as_str() -> &'static str { + "topk + inline filter" } } diff --git a/diskann-benchmark/src/inputs/disk.rs b/diskann-benchmark/src/inputs/disk.rs index 473d7982b..f6b4cacac 100644 --- a/diskann-benchmark/src/inputs/disk.rs +++ b/diskann-benchmark/src/inputs/disk.rs @@ -13,7 +13,7 @@ use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, ge use serde::{Deserialize, Serialize}; use crate::{ - inputs::{as_input, Example}, + inputs::{as_input, post_processor::TopkPostProcessor, Example}, utils::SimilarityMeasure, }; @@ -76,6 +76,7 @@ pub(crate) struct DiskSearchPhase { pub(crate) vector_filters_file: Option, pub(crate) num_nodes_to_cache: Option, pub(crate) search_io_limit: Option, + pub(crate) post_processor: Option, } ///////// @@ -210,6 +211,12 @@ impl DiskSearchPhase { anyhow::bail!("search_io_limit must be positive if specified"); } } + + if let Some(pp) = self.post_processor.as_mut() { + pp.validate(checker) + .context("invalid disk search post processor")?; + } + Ok(()) } } @@ -248,6 +255,7 @@ impl Example for DiskIndexOperation { vector_filters_file: None, num_nodes_to_cache: None, search_io_limit: None, + post_processor: None, }; Self { @@ -373,6 +381,10 @@ impl DiskSearchPhase { Some(lim) => write_field!(f, "Search IO Limit", format!("{lim}"))?, None => write_field!(f, "Search IO Limit", "none (defaults to `usize::MAX`)")?, } + match &self.post_processor { + Some(pp) => write_field!(f, "Post Processor", pp)?, + None => write_field!(f, "Post Processor", "none")?, + } Ok(()) } } diff --git a/diskann-benchmark/src/inputs/graph_index.rs b/diskann-benchmark/src/inputs/graph_index.rs index 94872c4f3..0506c8083 100644 --- a/diskann-benchmark/src/inputs/graph_index.rs +++ b/diskann-benchmark/src/inputs/graph_index.rs @@ -15,7 +15,7 @@ use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Chec use diskann_providers::{ model::{ configuration::IndexConfiguration, - graph::provider::async_::inmem::DefaultProviderParameters, + graph::provider::{async_::inmem::DefaultProviderParameters, DeterminantDiversityParams}, }, utils::load_metadata_from_file, }; @@ -354,6 +354,53 @@ impl Example for MultiInsert { } } +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct TopkDeterminantDiversityPhase { + pub(crate) queries: InputFile, + pub(crate) groundtruth: InputFile, + pub(crate) reps: NonZeroUsize, + pub(crate) num_threads: Vec, + pub(crate) runs: Vec, + pub(crate) power: f32, + pub(crate) eta: f32, +} + +impl TopkDeterminantDiversityPhase { + pub(crate) fn max_k(&self) -> usize { + self.runs.iter().map(|run| run.recall_k).max().unwrap_or(0) + } + + pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> { + DeterminantDiversityParams::new(self.power, self.eta) + .map_err(|e| anyhow::anyhow!("invalid determinant-diversity params: {e}"))?; + self.queries.resolve(checker)?; + self.groundtruth.resolve(checker)?; + for (i, run) in self.runs.iter_mut().enumerate() { + run.validate(checker) + .with_context(|| format!("search run {}", i))?; + } + Ok(()) + } +} + +impl Example for TopkDeterminantDiversityPhase { + fn example() -> Self { + Self { + queries: InputFile::new("path/to/queries"), + groundtruth: InputFile::new("path/to/groundtruth"), + reps: NonZeroUsize::new(1).unwrap(), + num_threads: vec![NonZeroUsize::new(1).unwrap()], + runs: vec![GraphSearch { + search_n: 10, + search_l: vec![10, 20, 30, 40], + recall_k: 10, + }], + power: 1.0, + eta: 0.5, + } + } +} + #[derive(Debug, Deserialize, Serialize)] #[serde(tag = "search-type", rename_all = "kebab-case")] pub(crate) enum SearchPhase { @@ -362,6 +409,7 @@ pub(crate) enum SearchPhase { TopkBetaFilter(BetaSearchPhase), TopkMultihopFilter(MultihopFilterSearchPhase), TopkInlineFilter(InlineFilterSearchPhase), + TopkDeterminantDiversity(TopkDeterminantDiversityPhase), } #[derive(Debug, Error)] @@ -389,6 +437,7 @@ impl SearchPhase { Self::TopkBetaFilter(_) => SearchPhaseKind::TopkBetaFilter, Self::TopkMultihopFilter(_) => SearchPhaseKind::TopkMultihopFilter, Self::TopkInlineFilter(_) => SearchPhaseKind::TopkInlineFilter, + Self::TopkDeterminantDiversity(_) => SearchPhaseKind::TopkDeterminantDiversity, } } @@ -445,6 +494,18 @@ impl SearchPhase { )), } } + + pub(crate) fn as_topk_determinant_diversity( + &self, + ) -> Result<&TopkDeterminantDiversityPhase, WrongSearchPhaseKind> { + match self { + Self::TopkDeterminantDiversity(phase) => Ok(phase), + _ => Err(WrongSearchPhaseKind::new( + SearchPhaseKind::TopkDeterminantDiversity, + self.kind(), + )), + } + } } impl SearchPhase { @@ -455,6 +516,7 @@ impl SearchPhase { SearchPhase::TopkBetaFilter(phase) => phase.validate(checker), SearchPhase::TopkMultihopFilter(phase) => phase.validate(checker), SearchPhase::TopkInlineFilter(phase) => phase.validate(checker), + SearchPhase::TopkDeterminantDiversity(phase) => phase.validate(checker), } } } @@ -466,6 +528,7 @@ pub(crate) enum SearchPhaseKind { TopkBetaFilter, TopkMultihopFilter, TopkInlineFilter, + TopkDeterminantDiversity, } impl SearchPhaseKind { @@ -476,6 +539,7 @@ impl SearchPhaseKind { Self::TopkBetaFilter => "topk-beta-filter", Self::TopkMultihopFilter => "topk-multihop-filter", Self::TopkInlineFilter => "topk-inline-filter", + Self::TopkDeterminantDiversity => "topk-determinant-diversity", } } } @@ -911,7 +975,9 @@ impl IndexPQOperation { } pub(crate) fn validate(&mut self, checker: &mut Checker) -> anyhow::Result<()> { - self.index_operation.validate(checker) + self.index_operation.validate(checker)?; + + Ok(()) } } @@ -995,7 +1061,9 @@ impl IndexSQOperation { )); } - self.index_operation.validate(checker) + self.index_operation.validate(checker)?; + + Ok(()) } } @@ -1349,6 +1417,7 @@ impl DynamicIndexRun { self.build.validate(checker)?; self.runbook_params.validate(checker)?; self.search_phase.validate(checker)?; + Ok(()) } diff --git a/diskann-benchmark/src/inputs/mod.rs b/diskann-benchmark/src/inputs/mod.rs index ed49f145e..a5bc76eaa 100644 --- a/diskann-benchmark/src/inputs/mod.rs +++ b/diskann-benchmark/src/inputs/mod.rs @@ -8,6 +8,7 @@ pub(crate) mod exhaustive; pub(crate) mod filters; pub(crate) mod graph_index; pub(crate) mod multi_vector; +pub(crate) mod post_processor; pub(crate) mod save_and_load; #[cfg(feature = "bftree")] diff --git a/diskann-benchmark/src/inputs/post_processor.rs b/diskann-benchmark/src/inputs/post_processor.rs new file mode 100644 index 000000000..e8b9a92a3 --- /dev/null +++ b/diskann-benchmark/src/inputs/post_processor.rs @@ -0,0 +1,70 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +use std::fmt; + +use diskann_benchmark_runner::Checker; +use diskann_providers::model::graph::provider::DeterminantDiversityParams; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "kebab-case")] +enum RawTopkPostProcessor { + DeterminantDiversity { power: f32, eta: f32 }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(try_from = "RawTopkPostProcessor", into = "RawTopkPostProcessor")] +pub(crate) enum TopkPostProcessor { + DeterminantDiversity(DeterminantDiversityParams), +} + +impl TryFrom for TopkPostProcessor { + type Error = String; + + fn try_from(raw: RawTopkPostProcessor) -> Result { + match raw { + RawTopkPostProcessor::DeterminantDiversity { power, eta } => { + let params = + DeterminantDiversityParams::new(power, eta).map_err(|e| e.to_string())?; + Ok(Self::DeterminantDiversity(params)) + } + } + } +} + +impl From for RawTopkPostProcessor { + fn from(value: TopkPostProcessor) -> Self { + match value { + TopkPostProcessor::DeterminantDiversity(params) => { + RawTopkPostProcessor::DeterminantDiversity { + power: params.power(), + eta: params.eta(), + } + } + } + } +} + +impl TopkPostProcessor { + pub(crate) fn validate(&mut self, _checker: &mut Checker) -> anyhow::Result<()> { + Ok(()) + } +} + +impl fmt::Display for TopkPostProcessor { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + TopkPostProcessor::DeterminantDiversity(params) => { + write!( + f, + "determinant-diversity (power={}, eta={})", + params.power(), + params.eta() + ) + } + } + } +} diff --git a/diskann-disk/src/build/builder/core.rs b/diskann-disk/src/build/builder/core.rs index ed69058de..2fb24c07a 100644 --- a/diskann-disk/src/build/builder/core.rs +++ b/diskann-disk/src/build/builder/core.rs @@ -1089,6 +1089,7 @@ pub(crate) mod disk_index_builder_tests { &mut indices, &mut distances, &mut associated_data, + None, &|_| true, false, ); diff --git a/diskann-disk/src/search/provider/disk_provider.rs b/diskann-disk/src/search/provider/disk_provider.rs index f2f24d5e1..8e95dfa4b 100644 --- a/diskann-disk/src/search/provider/disk_provider.rs +++ b/diskann-disk/src/search/provider/disk_provider.rs @@ -29,10 +29,16 @@ use diskann::{ }; use diskann_providers::storage::StorageReadProvider; use diskann_providers::{ - model::compute_pq_distance, + model::{ + compute_pq_distance, + graph::provider::{determinant_diversity, DeterminantDiversityParams}, + }, storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file, LoadWith}, }; -use diskann_utils::object_pool::{ObjectPool, PoolOption, TryAsPooled}; +use diskann_utils::{ + object_pool::{ObjectPool, PoolOption, TryAsPooled}, + views::Matrix, +}; use crate::search::pq::{quantizer_preprocess, PQData, PQScratch}; use diskann_vector::{distance::Metric, DistanceFunction}; @@ -264,12 +270,41 @@ pub struct RerankAndFilter<'a> { filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), } +#[derive(Clone, Copy)] +pub struct DeterminantDiversityAndFilter<'a> { + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + params: DeterminantDiversityParams, +} + +#[derive(Clone, Copy)] +pub enum SearchPostProcessorKind { + /// No post-processing; search results are returned as-is. + None, + RerankAndFilter, + DeterminantDiversity(DeterminantDiversityParams), +} + +#[derive(Clone, Copy)] +pub enum DiskSearchPostProcessor<'a> { + RerankAndFilter(RerankAndFilter<'a>), + DeterminantDiversity(DeterminantDiversityAndFilter<'a>), +} + impl<'a> RerankAndFilter<'a> { - fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { + pub fn new(filter: &'a (dyn Fn(&u32) -> bool + Send + Sync)) -> Self { Self { filter } } } +impl<'a> DeterminantDiversityAndFilter<'a> { + pub fn new( + filter: &'a (dyn Fn(&u32) -> bool + Send + Sync), + params: DeterminantDiversityParams, + ) -> Self { + Self { filter, params } + } +} + impl SearchPostProcess< DiskAccessor<'_, Data, VP>, @@ -330,6 +365,120 @@ where } } +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + &[Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DeterminantDiversityAndFilter<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + let provider = accessor.provider; + let query_f32 = Data::VectorDataType::as_f32(query).map_err(Into::into)?; + + let candidate_ids: Vec = candidates + .map(|candidate| candidate.id) + .filter(|id| (self.filter)(id)) + .collect(); + + if candidate_ids.is_empty() { + return Ok(0); + } + + ensure_vertex_loaded(&mut accessor.scratch.vertex_provider, &candidate_ids)?; + + let mut candidate_vectors = Matrix::new(0.0f32, candidate_ids.len(), query_f32.len()); + let mut candidate_distances = Vec::with_capacity(candidate_ids.len()); + let mut associated_data = Vec::with_capacity(candidate_ids.len()); + + for (row_idx, id) in candidate_ids.iter().enumerate() { + let vector = accessor.scratch.vertex_provider.get_vector(id)?; + let distance = provider + .distance_comparer + .evaluate_similarity(query, vector); + let vector_f32 = Data::VectorDataType::as_f32(vector).map_err(Into::into)?; + let data = accessor.scratch.vertex_provider.get_associated_data(id)?; + + candidate_vectors + .row_mut(row_idx) + .copy_from_slice(&vector_f32); + candidate_distances.push(distance); + associated_data.push(*data); + } + + let reranked = determinant_diversity( + candidate_vectors.as_mut_view(), + &candidate_distances, + &query_f32, + usize::MAX, + &self.params, + )?; + + Ok(output.extend(reranked.into_iter().map(|idx| { + let id = candidate_ids[idx]; + let distance = candidate_distances[idx]; + ((id, associated_data[idx]), distance) + }))) + } +} + +impl + SearchPostProcess< + DiskAccessor<'_, Data, VP>, + &[Data::VectorDataType], + ( + as DataProvider>::InternalId, + Data::AssociatedDataType, + ), + > for DiskSearchPostProcessor<'_> +where + Data: GraphDataType, + VP: VertexProvider, +{ + type Error = ANNError; + async fn post_process( + &self, + accessor: &mut DiskAccessor<'_, Data, VP>, + query: &[Data::VectorDataType], + candidates: I, + output: &mut B, + ) -> Result + where + I: Iterator> + Send, + B: search_output_buffer::SearchOutputBuffer<(u32, Data::AssociatedDataType)> + + Send + + ?Sized, + { + match self { + DiskSearchPostProcessor::RerankAndFilter(pp) => { + pp.post_process(accessor, query, candidates, output).await + } + DiskSearchPostProcessor::DeterminantDiversity(pp) => { + pp.post_process(accessor, query, candidates, output).await + } + } + } +} + impl<'this, Data, ProviderFactory> SearchStrategy<'this, DiskProvider, &'this [Data::VectorDataType]> for DiskSearchStrategy<'this, Data, ProviderFactory> @@ -827,6 +976,7 @@ where /// Perform a search on the disk index. /// return the list of nearest neighbors and associated data. + #[allow(clippy::too_many_arguments)] pub fn search( &self, query: &[Data::VectorDataType], @@ -834,6 +984,7 @@ where search_list_size: u32, beam_width: Option, vector_filter: Option>, + post_processor: SearchPostProcessorKind, is_flat_search: bool, ) -> ANNResult> { let mut query_stats = QueryStatistics::default(); @@ -842,6 +993,21 @@ where let mut associated_data = vec![Data::AssociatedDataType::default(); return_list_size as usize]; + let vector_filter = vector_filter.unwrap_or(default_vector_filter::()); + let post_processor = match post_processor { + SearchPostProcessorKind::None => None, + SearchPostProcessorKind::RerankAndFilter => { + Some(DiskSearchPostProcessor::RerankAndFilter( + RerankAndFilter::new(vector_filter.as_ref()), + )) + } + SearchPostProcessorKind::DeterminantDiversity(params) => { + Some(DiskSearchPostProcessor::DeterminantDiversity( + DeterminantDiversityAndFilter::new(vector_filter.as_ref(), params), + )) + } + }; + let stats = self.search_internal( query, return_list_size as usize, @@ -851,7 +1017,8 @@ where &mut indices, &mut distances, &mut associated_data, - &vector_filter.unwrap_or(default_vector_filter::()), + post_processor, + vector_filter.as_ref(), is_flat_search, )?; @@ -888,6 +1055,7 @@ where indices: &mut [u32], distances: &mut [f32], associated_data: &mut [Data::AssociatedDataType], + post_processor: Option>, vector_filter: &(dyn Fn(&Data::VectorIdType) -> bool + Send + Sync), is_flat_search: bool, ) -> ANNResult { @@ -909,10 +1077,18 @@ where l, &mut result_output_buffer, ))? + } else if let Some(processor) = post_processor { + self.runtime.block_on(self.index.search_with( + Knn::new(k, l, beam_width)?, + &strategy, + processor, + &DefaultContext, + query, + &mut result_output_buffer, + ))? } else { - let knn_search = Knn::new(k, l, beam_width)?; self.runtime.block_on(self.index.search( - knn_search, + Knn::new(k, l, beam_width)?, &strategy, &DefaultContext, query, @@ -1309,6 +1485,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &(|_| true), false, ); @@ -1357,7 +1534,15 @@ mod disk_provider_tests { .for_each_in_pool(pool.as_ref(), |(i, query)| { let result = params .index_search_engine - .search(query, params.k as u32, params.l as u32, beam_width, None, false) + .search( + query, + params.k as u32, + params.l as u32, + beam_width, + None, + SearchPostProcessorKind::None, + false, + ) .unwrap(); let indices: Vec = result.results.iter().map(|item| item.vertex_id).collect(); let associated_data: Vec = @@ -1467,6 +1652,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &|_| true, false, ); @@ -1537,6 +1723,7 @@ mod disk_provider_tests { search_list_size, Some(4), None, + SearchPostProcessorKind::None, false, ); assert!(result.is_ok(), "Expected search to succeed"); @@ -1553,6 +1740,129 @@ mod disk_provider_tests { ); } + #[test] + fn test_disk_search_determinant_diversity() { + let storage_provider = Arc::new(VirtualStorageProvider::new_overlay(test_data_root())); + let search_engine = create_disk_index_searcher::( + CreateDiskIndexSearcherParams { + max_thread_num: 1, + pq_pivot_file_path: TEST_PQ_PIVOT, + pq_compressed_file_path: TEST_PQ_COMPRESSED, + index_path: TEST_INDEX, + index_path_prefix: TEST_INDEX_PREFIX, + ..Default::default() + }, + &storage_provider, + ); + + let query_vector: [f32; 128] = [1f32; 128]; + let return_list_size = 10u32; + let search_list_size = 20u32; + + // Baseline: no post-processor. Det-div selects from the same L=20 candidate pool, + // so all det-div IDs must be a subset of the baseline candidates. + let baseline = search_engine + .search( + &query_vector, + search_list_size, + search_list_size, + Some(4), + None, + SearchPostProcessorKind::None, + false, + ) + .unwrap(); + let baseline_ids: std::collections::HashSet = + baseline.results.iter().map(|r| r.vertex_id).collect(); + let baseline_top1 = baseline + .results + .first() + .expect("baseline returned no results"); + + // Run with determinant-diversity post-processor (default-ish params). + let params = DeterminantDiversityParams::new(2.0, 0.01).unwrap(); + let result = search_engine + .search( + &query_vector, + return_list_size, + search_list_size, + Some(4), + None, + SearchPostProcessorKind::DeterminantDiversity(params), + false, + ) + .unwrap(); + let det_div_ids: Vec = result.results.iter().map(|r| r.vertex_id).collect(); + + assert_eq!( + det_div_ids.len(), + return_list_size as usize, + "det-div should return k results when the candidate pool is large enough" + ); + for id in &det_div_ids { + assert!( + baseline_ids.contains(id), + "det-div selected id {} that is not in the search candidate pool", + id + ); + } + + let mut unique = std::collections::HashSet::new(); + for id in &det_div_ids { + assert!(unique.insert(*id), "det-div produced duplicate id {}", id); + } + + // Greedy det-div with power > 0 and eta > 0 selects the highest-similarity + // candidate first. + assert_eq!( + result.results[0].vertex_id, baseline_top1.vertex_id, + "det-div top-1 should be the nearest neighbor (highest similarity)" + ); + + // Pure greedy orthogonalization (eta == 0) should also produce a valid subset. + let pure_params = DeterminantDiversityParams::new(2.0, 0.0).unwrap(); + let pure_result = search_engine + .search( + &query_vector, + return_list_size, + search_list_size, + Some(4), + None, + SearchPostProcessorKind::DeterminantDiversity(pure_params), + false, + ) + .unwrap(); + let pure_ids: Vec = pure_result.results.iter().map(|r| r.vertex_id).collect(); + for id in &pure_ids { + assert!( + baseline_ids.contains(id), + "det-div(eta=0) selected id {} that is not in the search candidate pool", + id + ); + } + + // The vector_filter is honored by det-div: filter out the baseline top-1 and + // verify it is excluded from the det-div results. + let excluded = baseline_top1.vertex_id; + let filter: VectorFilter = Box::new(move |id| *id != excluded); + let filtered = search_engine + .search( + &query_vector, + return_list_size, + search_list_size, + Some(4), + Some(filter), + SearchPostProcessorKind::DeterminantDiversity(params), + false, + ) + .unwrap(); + let filtered_ids: Vec = filtered.results.iter().map(|r| r.vertex_id).collect(); + assert!( + !filtered_ids.contains(&excluded), + "det-div results must respect the vector filter" + ); + } + #[cfg(feature = "experimental_diversity_search")] #[test] fn test_disk_search_diversity_search() { @@ -1875,6 +2185,7 @@ mod disk_provider_tests { &mut indices, &mut distances, &mut associated_data, + None::>, &vector_filter, is_flat_search, ); @@ -1897,6 +2208,7 @@ mod disk_provider_tests { 10, None, // beam_width Some(Box::new(vector_filter)), + SearchPostProcessorKind::None, is_flat_search, ); diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs index 07d47d2f6..15c2d6af1 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/full_precision.rs @@ -23,6 +23,7 @@ use diskann::{ }; use diskann_utils::future::AsyncFriendly; +use diskann_utils::views::Matrix; use diskann_vector::{DistanceFunction, PreprocessedDistanceFunction, distance::Metric}; use crate::model::graph::provider::async_::{ @@ -34,6 +35,7 @@ use crate::model::graph::provider::async_::{ inmem::DefaultProvider, postprocess::{AsDeletionCheck, DeletionCheck, RemoveDeletedIdsAndCopy}, }; +use crate::model::graph::provider::{DeterminantDiversityParams, determinant_diversity}; /// A type alias for the DefaultProvider with full-precision as the primary vector store. pub type FullPrecisionProvider = @@ -401,6 +403,71 @@ where } } +/// A [`SearchPostProcess`]or that reranks a full-precision candidate stream using the +/// Determinant-Diversity algorithm, reordering results to promote geometric diversity +/// while preserving relevance to the query. +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversity { + params: DeterminantDiversityParams, +} + +impl DeterminantDiversity { + /// Construct a new [`DeterminantDiversity`] post-processor with the given parameters. + pub const fn new(params: DeterminantDiversityParams) -> Self { + Self { params } + } +} + +impl<'a, A> glue::SearchPostProcess for DeterminantDiversity +where + A: HasId + GetFullPrecision + Send + Sync, +{ + type Error = ANNError; + + fn post_process( + &self, + accessor: &mut A, + query: &'a [f32], + candidates: I, + output: &mut B, + ) -> impl Future> + Send + where + I: Iterator> + Send, + B: SearchOutputBuffer + Send + ?Sized, + { + let candidates: Vec> = candidates.collect(); + let candidate_count = candidates.len(); + let store: &FullPrecisionStore = accessor.as_full_precision(); + let mut vectors = Matrix::new(0.0f32, candidate_count, query.len()); + let mut ids = Vec::with_capacity(candidate_count); + let mut distances = Vec::with_capacity(candidate_count); + + for (i, candidate) in candidates.into_iter().enumerate() { + // SAFETY: We accept potential unsynchronized concurrent mutation, matching the + // pattern used by `Rerank` above. + let vector = unsafe { store.get_vector_sync(candidate.id.into_usize()) }; + ids.push(candidate.id); + distances.push(candidate.distance); + vectors.row_mut(i).copy_from_slice(vector); + } + + let indices = match determinant_diversity( + vectors.as_mut_view(), + &distances, + query, + candidate_count, + &self.params, + ) { + Ok(indices) => indices, + Err(error) => return std::future::ready(Err(error.into())), + }; + + let reranked = indices.into_iter().map(|idx| (ids[idx], distances[idx])); + + std::future::ready(Ok(output.extend(reranked))) + } +} + //////////////// // Strategies // //////////////// diff --git a/diskann-providers/src/model/graph/provider/async_/inmem/mod.rs b/diskann-providers/src/model/graph/provider/async_/inmem/mod.rs index b7845f13c..79e755c64 100644 --- a/diskann-providers/src/model/graph/provider/async_/inmem/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/inmem/mod.rs @@ -19,10 +19,11 @@ pub use product::DefaultQuant; pub mod spherical; mod full_precision; +pub(super) use full_precision::Rerank; pub use full_precision::{ - CreateFullPrecision, FullAccessor, FullPrecisionProvider, FullPrecisionStore, + CreateFullPrecision, DeterminantDiversity, FullAccessor, FullPrecisionProvider, + FullPrecisionStore, GetFullPrecision, }; -pub(super) use full_precision::{GetFullPrecision, Rerank}; #[cfg(test)] pub mod product; diff --git a/diskann-providers/src/model/graph/provider/async_/mod.rs b/diskann-providers/src/model/graph/provider/async_/mod.rs index 2cd108974..6dd707003 100644 --- a/diskann-providers/src/model/graph/provider/async_/mod.rs +++ b/diskann-providers/src/model/graph/provider/async_/mod.rs @@ -7,11 +7,9 @@ pub mod experimental; pub mod common; pub use common::{PrefetchCacheLineLevel, StartPoints, VectorGuard}; -pub(crate) mod postprocess; - pub mod distances; - pub mod memory_vector_provider; +pub(crate) mod postprocess; pub use memory_vector_provider::MemoryVectorProviderAsync; pub mod memory_quant_vector_provider; diff --git a/diskann-providers/src/model/graph/provider/determinant_diversity.rs b/diskann-providers/src/model/graph/provider/determinant_diversity.rs new file mode 100644 index 000000000..5e35c8708 --- /dev/null +++ b/diskann-providers/src/model/graph/provider/determinant_diversity.rs @@ -0,0 +1,921 @@ +/* + * Copyright (c) Microsoft Corporation. + * Licensed under the MIT license. + */ + +//! Determinant-Diversity post-processing for search results. +//! +//! This module implements the Determinant-Diversity algorithm for diversity-promoting +//! reranking of approximate nearest neighbor search results. The algorithm takes +//! relevance-ranked candidates and reorders them to maximize geometric diversity +//! while maintaining relevance to the original query. +//! +//! # Algorithm Overview +//! +//! Determinant-Diversity selects a diverse subset from an initial set of candidates +//! by iteratively choosing points that maximize the determinant of the distance matrix. +//! This creates a diverse set that is both relevant to the query and geometrically spread out. +//! +//! Concretely, each candidate vector v_i is scaled by a relevance weight +//! alpha_i = similarity(d_i)^power / sqrt(eta) derived from its distance d_i +//! to the query (see `distance_to_similarity`). Letting X be the matrix of +//! scaled rows x_i = alpha_i * v_i, we approximately maximize +//! det(X_S * X_S^T + eta * I) over subsets S of size k via greedy pivoted +//! Gram-Schmidt: at each step we pick the row with the largest residual norm +//! and deflate the rest against it. See [`greedy_orthogonal_select`] for the +//! full derivation. +//! +//! # Parameters +//! +//! - **power**: Relevance weighting exponent (must be > 0.0). Controls the emphasis on +//! maintaining relevance scores from the initial search. Higher values prefer relevance +//! over diversity. +//! +//! - **eta**: Numerical stability parameter (must be >= 0.0). Used for ridge regularization: +//! - `eta = 0`: Exact determinant computation (can be numerically unstable for some inputs) +//! - `eta > 0`: Ridge-regularized computation for improved numerical stability +//! +//! # Variants +//! +//! The public entry point is [`determinant_diversity`]. +//! It applies either the unregularized (`eta == 0`) or ridge-regularized (`eta > 0`) +//! formulation internally. +//! +//! # Time Complexity +//! +//! O(n * k * dim), where n is number of candidates, k is requested output size, +//! and dim is vector dimensionality. +//! +//! # References +//! +//! The algorithm is based on diversity-promoting ranking methods for nearest neighbor search, +//! as used in approximate nearest neighbor indices like DiskANN. + +use std::fmt; + +use diskann_utils::views::MutMatrixView; +use diskann_vector::{MathematicalValue, PureDistanceFunction, distance::InnerProduct}; + +/// Parameters for Determinant-Diversity post-processor with validation. +/// +/// Determinant-Diversity is a diversity-promoting reranking algorithm that takes +/// relevance-ranked neighbors and reorders them to maximize geometric diversity +/// while maintaining relevance. +/// +/// # Parameters +/// +/// - `power`: Relevance weighting exponent. Controls the emphasis on maintaining +/// relevance scores from the original search. Must be > 0.0. +/// +/// - `eta`: Numerical stability parameter for ridge-regularization. Controls the +/// trade-off between exact determinant computation (eta=0) and numerical robustness +/// (eta>0). Must be >= 0.0. +/// +/// # Errors +/// +/// Construction fails if: +/// - `power` is non-finite or `<= 0.0` (invalid power weighting) +/// - `eta` is non-finite or `< 0.0` (negative stability parameter) +#[derive(Debug, Clone, Copy)] +pub struct DeterminantDiversityParams { + /// Relevance weighting exponent. Must be > 0.0. + power: f32, + /// Numerical stability parameter. Must be >= 0.0. + eta: f32, +} + +impl DeterminantDiversityParams { + /// Create and validate new Determinant-Diversity parameters. + /// + /// # Errors + /// + /// Returns an error if validation fails: + /// - `power` is non-finite or `<= 0.0`: invalid relevance weighting + /// - `eta` is non-finite or `< 0.0`: invalid numerical stability parameter + pub fn new(power: f32, eta: f32) -> Result { + if !power.is_finite() || power <= 0.0 { + return Err(DeterminantDiversityError::InvalidPower(power)); + } + if !eta.is_finite() || eta < 0.0 { + return Err(DeterminantDiversityError::InvalidEta(eta)); + } + Ok(Self { power, eta }) + } + + /// Get power parameter. + #[inline] + pub fn power(&self) -> f32 { + self.power + } + + /// Get eta parameter. + #[inline] + pub fn eta(&self) -> f32 { + self.eta + } +} + +impl fmt::Display for DeterminantDiversityParams { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "DeterminantDiversity(power={}, eta={})", + self.power, self.eta + ) + } +} + +/// Error produced when constructing [`DeterminantDiversityParams`] or running +/// [`determinant_diversity`]. +#[derive(Debug, Clone, thiserror::Error)] +pub enum DeterminantDiversityError { + #[error("determinant-diversity power must be > 0.0, got: {0}")] + InvalidPower(f32), + #[error("determinant-diversity eta must be >= 0.0, got: {0}")] + InvalidEta(f32), + #[error( + "determinant-diversity candidate matrix has {candidate} columns but query dimension is {query}" + )] + QueryDimensionMismatch { + /// Number of dimensions in the query. + query: usize, + /// Number of columns in the candidate matrix. + candidate: usize, + }, + #[error("determinant-diversity received {distances} distances for {candidates} candidate rows")] + DistanceCountMismatch { + /// Number of supplied distances. + distances: usize, + /// Number of candidate rows. + candidates: usize, + }, +} + +impl From for diskann::ANNError { + #[track_caller] + fn from(err: DeterminantDiversityError) -> Self { + use diskann::ANNErrorKind; + let kind = match err { + DeterminantDiversityError::InvalidPower(_) + | DeterminantDiversityError::InvalidEta(_) => ANNErrorKind::IndexConfigError, + DeterminantDiversityError::QueryDimensionMismatch { .. } + | DeterminantDiversityError::DistanceCountMismatch { .. } => { + ANNErrorKind::DimensionMismatchError + } + }; + diskann::ANNError::new(kind, err) + } +} + +#[derive(Clone, Copy)] +struct DistanceRange { + min: f32, + max: f32, +} + +/// Rerank `candidates` to promote geometric diversity while preserving relevance. +/// +/// Returns the indices (into the rows of `candidates`) of the selected vectors, +/// in selection order, with at most `k` entries. +/// +/// # Arguments +/// +/// - `candidates`: row-major matrix whose `i`-th row is the full-precision vector +/// of the `i`-th candidate. Its column count must equal `query.len()`. +/// - `distances`: candidate-to-query distances, parallel to the rows of +/// `candidates`. Its length must equal `candidates.nrows()`. +/// - `query`: the query vector. +/// - `k`: maximum number of results to return; clamped to `candidates.nrows()`. +/// - `params`: relevance/regularization parameters. +/// +/// # Errors +/// +/// Returns [`DeterminantDiversityError::QueryDimensionMismatch`] if the candidate +/// matrix column count does not equal `query.len()`, or +/// [`DeterminantDiversityError::DistanceCountMismatch`] if `distances.len()` does +/// not equal `candidates.nrows()`. These structural invariants are validated up +/// front so they are enforced consistently even when an input is empty. +/// +/// An empty candidate set, a `k` of zero, or zero-dimensional vectors yield an +/// empty result. +pub fn determinant_diversity( + candidates: MutMatrixView<'_, f32>, + distances: &[f32], + query: &[f32], + k: usize, + params: &DeterminantDiversityParams, +) -> Result, DeterminantDiversityError> { + // Validate structural invariants first so they are enforced consistently, + // regardless of whether any individual input happens to be empty. + if candidates.ncols() != query.len() { + return Err(DeterminantDiversityError::QueryDimensionMismatch { + query: query.len(), + candidate: candidates.ncols(), + }); + } + + if distances.len() != candidates.nrows() { + return Err(DeterminantDiversityError::DistanceCountMismatch { + distances: distances.len(), + candidates: candidates.nrows(), + }); + } + + let k = k.min(candidates.nrows()); + if k == 0 || candidates.ncols() == 0 { + return Ok(Vec::new()); + } + + let distance_range = { + let mut min_distance = f32::INFINITY; + let mut max_distance = f32::NEG_INFINITY; + + for distance in distances { + min_distance = min_distance.min(*distance); + max_distance = max_distance.max(*distance); + } + + DistanceRange { + min: min_distance, + max: max_distance, + } + }; + + // For eta=0, the inv_sqrt_eta factor is 1.0 (greedy orthogonalization without regularization). + // For eta>0, the factor scales residuals for ridge-regularized determinant computation. + let inv_sqrt_eta = if params.eta() > 0.0 { + 1.0 / params.eta().sqrt() + } else { + 1.0 + }; + + Ok(greedy_orthogonal_select( + candidates, + distances, + k, + params.power(), + inv_sqrt_eta, + distance_range, + )) +} + +/// Core greedy selection algorithm for Determinant-Diversity. +/// +/// # Mathematical formulation +/// +/// Let the input candidate set be represented by matrix rows v_i (for i = 1..n) +/// and a parallel distance slice d_i, where d_i is the candidate distance to the +/// query and v_i is the full-precision vector in R^dim. Define the per-candidate scale +/// +/// ```text +/// alpha_i = similarity(d_i)^power * (1 / sqrt(eta)) +/// ``` +/// +/// where similarity(.) in [0, 1] is the normalized "lower-distance-is-better" +/// score from `distance_to_similarity`, and `1 / sqrt(eta)` is `inv_sqrt_eta` +/// (it equals 1 in the unregularized eta == 0 branch -- see the caller). The +/// scaled vectors are +/// +/// ```text +/// x_i = alpha_i * v_i. +/// ``` +/// +/// Define the (regularized) Gram matrix of any subset S = { i_1, ..., i_m } as +/// +/// ```text +/// G_S = X_S * X_S^T + eta * I, +/// ``` +/// +/// where X_S stacks the rows x_i for i in S. The goal is to pick S of size k +/// that approximately maximizes det(G_S), i.e. selects vectors whose scaled +/// rows span the largest volume -- geometrically diverse, while alpha_i keeps +/// relevance. We solve this greedily, which is equivalent to *column-pivoted +/// modified Gram-Schmidt / QR* on the rows x_i. +/// +/// # Algorithm (pivoted QR view) +/// +/// Maintain a residual vector r_i for each candidate. Initially r_i = x_i and +/// ||r_i||^2 = . At each step: +/// +/// 1. **Pivot.** Pick the available candidate i* with the largest residual +/// norm: i* = argmax over available i of ||r_i||^2. This is the direction +/// that contributes the most to the running volume / determinant expansion +/// (since det(G_S) = product of ||r_{i_j*}||^2 along the selection path). +/// +/// 2. **Project & deflate.** For every remaining candidate i, project r_i +/// onto the chosen pivot direction r* = r_{i*} and remove that component: +/// +/// ```text +/// pi_i = / ||r*||^2 +/// r_i := r_i - pi_i * r* +/// ``` +/// +/// 3. **Norm update (Pythagoras).** Because the new r_i is orthogonal to r* +/// by construction, +/// +/// ```text +/// ||r_i_new||^2 = ||r_i||^2 - pi_i^2 * ||r*||^2. +/// ``` +/// +/// We update the cached squared norm in place using this identity (clamped +/// at 0 for numerical safety) instead of recomputing the dot product. +/// +/// Repeat until k pivots are selected. The returned order is the order in +/// which pivots were chosen, which is the diversity-promoting reranking. +/// +/// # Parameters +/// +/// - `inv_sqrt_eta`: scalar 1 / sqrt(eta) baked into the residuals so that +/// the residual norms reflect the regularized Gram matrix X X^T + eta * I. +/// Use 1.0 for the unregularized (eta == 0) variant. +/// - `distances`: candidate distances parallel to matrix rows. +/// - `power`: relevance exponent applied to the per-candidate similarity. +/// - `distance_range`: min/max distances among the candidates, used to +/// normalize distances into similarities in [0, 1]. +/// +/// # Complexity +/// +/// O(n * k * dim) -- for each of k pivots we touch all n residual rows of +/// length `dim`. Memory is O(n * dim) for the contiguous residual matrix. +fn greedy_orthogonal_select( + mut candidates: MutMatrixView<'_, f32>, + distances: &[f32], + k: usize, + power: f32, + inv_sqrt_eta: f32, + distance_range: DistanceRange, +) -> Vec { + let n = candidates.nrows(); + let k = k.min(n); + if k == 0 { + return Vec::new(); + } + + // Cached squared norms ||r_i||^2 for each row. Updated in place via the + // Pythagorean identity in step 3 above. + let mut norms_sq = Vec::with_capacity(n); + + // Step 0: scale rows in-place to initialize residuals r_i = alpha_i * v_i, + // then compute their squared norms. + // alpha_i = similarity(d_i)^power * inv_sqrt_eta. + for (i, distance_to_query) in distances.iter().enumerate() { + let scale = + distance_to_similarity(*distance_to_query, distance_range).powf(power) * inv_sqrt_eta; + for value in candidates.row_mut(i) { + *value *= scale; + } + let norm_sq = dot_product(candidates.row(i), candidates.row(i)); + norms_sq.push(norm_sq); + } + + let mut available = vec![true; n]; + let mut selected = Vec::with_capacity(k); + // Scratch buffer: projection coefficient pi_i for each row against the + // current pivot. Sized n once and overwritten each iteration. + let mut projections = vec![0.0f32; n]; + + for _ in 0..k { + // --- Step 1: Pivot --- + // Pick the available candidate with the largest residual norm. + // partial_cmp can return None for NaN; treat NaN as Equal so the + // iterator's max picks the first non-NaN candidate it has seen. + let best_idx = available + .iter() + .enumerate() + .filter(|&(_, &avail)| avail) + .max_by(|(i, _), (j, _)| { + norms_sq[*i] + .partial_cmp(&norms_sq[*j]) + .unwrap_or(std::cmp::Ordering::Equal) + }) + .map(|(i, _)| i); + + let Some(selected_index) = best_idx else { + break; + }; + + selected.push(selected_index); + available[selected_index] = false; + + // No more deflation needed once the last pivot has been chosen. + if selected.len() == k { + break; + } + + let best_norm_sq = norms_sq[selected_index]; + // If the pivot has zero (or numerically negative) residual norm, the + // remaining rows already lie in the span of previously selected + // pivots; skip deflation to avoid dividing by zero. + if best_norm_sq <= 0.0 { + continue; + } + + // 1 / ||r*||^2, factored out of the projection formula below. + let inv_norm_sq = 1.0 / best_norm_sq; + // Snapshot the pivot row r* before mutably iterating over the other + // rows of `residuals` (they share the same backing storage). + let r_star_copy: Vec = candidates.row(selected_index).to_vec(); + + // --- Step 2a: Compute projection coefficients pi_i = / ||r*||^2. + for i in 0..n { + if !available[i] { + projections[i] = 0.0; + } else { + projections[i] = dot_product(candidates.row(i), &r_star_copy) * inv_norm_sq; + } + } + + // --- Step 2b: Deflate r_i <- r_i - pi_i * r*, and + // --- Step 3: update ||r_i||^2 <- ||r_i||^2 - pi_i^2 * ||r*||^2. + for i in 0..n { + if !available[i] { + continue; + } + + let projection = projections[i]; + for (residual, &star) in candidates.row_mut(i).iter_mut().zip(r_star_copy.iter()) { + *residual -= projection * star; + } + + // Pythagorean update; clamp at 0 to absorb floating-point drift. + norms_sq[i] = (norms_sq[i] - projection * projection * best_norm_sq).max(0.0); + } + } + + selected +} + +/// Maps a raw distance into a similarity score in `(0, 1]` using the candidate +/// set's distance range. +/// +/// DiskANN distance semantics are *lower is better*, so we invert and rescale +/// against the observed [min, max] range: +/// +/// ```text +/// similarity(d) = max((d_max - d) / (d_max - d_min), 0) + EPSILON. +/// ``` +/// +/// - The numerator flips the order so that the *closest* candidate gets the +/// highest similarity (~1) and the *farthest* gets ~0. +/// - The denominator is clamped to `f32::EPSILON` so a degenerate range +/// (d_max == d_min) produces a finite, equal score for all candidates +/// instead of a divide-by-zero. +/// - The trailing `+ EPSILON` ensures the result is strictly positive, so +/// that `similarity(d).powf(power)` never produces a hard zero scale (which +/// would erase a candidate from the QR pivoting and bias selection). +#[inline(always)] +fn distance_to_similarity(distance: f32, distance_range: DistanceRange) -> f32 { + let span = (distance_range.max - distance_range.min).max(f32::EPSILON); + + // Distances are lower-is-better in DiskANN distance semantics. + ((distance_range.max - distance) / span).max(0.0) + f32::EPSILON +} + +#[inline] +fn dot_product(a: &[f32], b: &[f32]) -> f32 { + >>::evaluate(a, b) + .into_inner() +} + +#[cfg(test)] +mod tests { + use super::*; + use diskann_quantization::num::Positive; + use diskann_utils::views::Matrix; + + #[test] + fn test_valid_params() { + assert!(DeterminantDiversityParams::new(1.0, 0.0).is_ok()); + assert!(DeterminantDiversityParams::new(0.5, 1.5).is_ok()); + assert!(DeterminantDiversityParams::new(2.0, 0.1).is_ok()); + } + + #[test] + fn test_invalid_power() { + assert!(DeterminantDiversityParams::new(0.0, 1.0).is_err()); + assert!(DeterminantDiversityParams::new(-1.0, 1.0).is_err()); + } + + #[test] + fn test_invalid_eta() { + assert!(DeterminantDiversityParams::new(1.0, -0.1).is_err()); + } + + #[test] + fn test_invalid_non_finite_values() { + assert!(DeterminantDiversityParams::new(f32::NAN, 0.1).is_err()); + assert!(DeterminantDiversityParams::new(f32::INFINITY, 0.1).is_err()); + assert!(DeterminantDiversityParams::new(1.0, f32::NAN).is_err()); + assert!(DeterminantDiversityParams::new(1.0, f32::INFINITY).is_err()); + } + + #[test] + fn test_display() { + let params = DeterminantDiversityParams::new(1.5, 0.5).unwrap(); + assert_eq!( + params.to_string(), + "DeterminantDiversity(power=1.5, eta=0.5)" + ); + } + + fn run_with_ids( + candidates: Vec<(u32, f32, Vec)>, + query: &[f32], + k: usize, + eta: f32, + power: Positive, + ) -> Vec<(u32, f32)> { + if candidates.is_empty() { + return Vec::new(); + } + + let dim = candidates[0].2.len(); + let mut matrix = Matrix::new(0.0f32, candidates.len(), dim); + let mut ids = Vec::with_capacity(candidates.len()); + let mut distances = Vec::with_capacity(candidates.len()); + + for (i, (id, distance, vector)) in candidates.into_iter().enumerate() { + ids.push(id); + distances.push(distance); + matrix.row_mut(i).copy_from_slice(&vector); + } + + let params = DeterminantDiversityParams::new(power.into_inner(), eta).unwrap(); + determinant_diversity(matrix.as_mut_view(), &distances, query, k, ¶ms) + .expect("valid determinant-diversity inputs") + .into_iter() + .map(|idx| (ids[idx], distances[idx])) + .collect() + } + + /// Test helper: wrap a positive f32 power value. + fn p(value: f32) -> Positive { + Positive::new(value).unwrap() + } + + #[test] + fn test_empty_candidates() { + let result = run_with_ids(Vec::new(), &[1.0, 2.0], 5, 0.5, p(1.0)); + assert_eq!(result.len(), 0); + } + + #[test] + fn test_empty_query_is_dimension_mismatch() { + // A zero-length query against non-empty candidates is a structural + // mismatch (candidate columns != query dimension), not a valid request + // that trivially returns nothing. + let mut matrix = Matrix::new(0.0f32, 1, 2); + matrix.row_mut(0).copy_from_slice(&[1.0, 2.0]); + let params = DeterminantDiversityParams::new(1.0, 0.5).unwrap(); + + let result = determinant_diversity(matrix.as_mut_view(), &[0.5], &[], 5, ¶ms); + assert!(matches!( + result, + Err(DeterminantDiversityError::QueryDimensionMismatch { + query: 0, + candidate: 2, + }) + )); + } + + #[test] + fn test_mismatched_dimensions_errors() { + // Candidate vectors are 2-D, but the query is 3-D, so + // `determinant_diversity` should report a dimension mismatch. + let mut matrix = Matrix::new(0.0f32, 1, 2); + matrix.row_mut(0).copy_from_slice(&[1.0, 2.0]); + let params = DeterminantDiversityParams::new(1.0, 0.5).unwrap(); + + let result = + determinant_diversity(matrix.as_mut_view(), &[0.5], &[1.0, 2.0, 3.0], 5, ¶ms); + assert!(matches!( + result, + Err(DeterminantDiversityError::QueryDimensionMismatch { + query: 3, + candidate: 2, + }) + )); + } + + #[test] + fn test_mismatched_distances_errors() { + // Two candidate rows but only one distance is a structural mismatch. + let mut matrix = Matrix::new(0.0f32, 2, 2); + matrix.row_mut(0).copy_from_slice(&[1.0, 0.0]); + matrix.row_mut(1).copy_from_slice(&[0.0, 1.0]); + let params = DeterminantDiversityParams::new(1.0, 0.5).unwrap(); + + let result = determinant_diversity(matrix.as_mut_view(), &[0.5], &[1.0, 1.0], 2, ¶ms); + assert!(matches!( + result, + Err(DeterminantDiversityError::DistanceCountMismatch { + distances: 1, + candidates: 2, + }) + )); + } + + #[test] + fn test_single_candidate() { + let candidates = vec![(0u32, 0.5, vec![1.0, 2.0])]; + let query = &[1.0, 2.0]; + let result = run_with_ids(candidates, query, 5, 0.5, p(1.0)); + assert_eq!(result.len(), 1); + assert_eq!(result[0].0, 0); + } + + #[test] + fn test_k_larger_than_candidates() { + let candidates = vec![(0u32, 0.5, vec![1.0, 0.0]), (1u32, 0.3, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 10, 0.5, p(1.0)); + assert_eq!(result.len(), 2); // Should return min(k, candidates.len()) + } + + #[test] + fn test_with_eta_diversity() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.9, 0.1]), + (2u32, 0.3, vec![0.8, 0.2]), + ]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 2, 1.0, p(1.0)); + + assert_eq!(result.len(), 2); + // Should select based on diversity metric with eta > 0 + assert!(result.iter().all(|(id, _)| *id < 3)); + } + + #[test] + fn test_without_eta_greedy() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.9, 0.1]), + (2u32, 0.3, vec![0.8, 0.2]), + ]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + + assert_eq!(result.len(), 2); + // Should select based on greedy orthogonalization (eta == 0) + assert!(result.iter().all(|(id, _)| *id < 3)); + } + + #[test] + fn test_power_parameter() { + let candidates = vec![(0u32, 0.1, vec![1.0, 0.0]), (1u32, 0.2, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + + // Test with different power values - should still work without panicking + let result1 = run_with_ids(candidates.clone(), query, 2, 0.0, p(1.0)); + let result2 = run_with_ids(candidates, query, 2, 0.0, p(2.0)); + + assert_eq!(result1.len(), 2); + assert_eq!(result2.len(), 2); + } + + #[test] + fn test_distances_preserved() { + let candidates = vec![(0u32, 0.5, vec![1.0, 0.0]), (1u32, 0.3, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + + // Verify that distances are preserved from input + assert!(result.iter().all(|(_, dist)| *dist == 0.5 || *dist == 0.3)); + } + + /// Verify that diversity is actually promoted: when candidates lie along orthogonal + /// directions, a 2-element diverse subset should choose orthogonal pairs over similar ones. + /// + /// Using equal distances ensures pure diversity drives selection without relevance weighting. + #[test] + fn test_diversity_selects_orthogonal_candidates() { + // Three candidates with equal distance: two very similar (nearly parallel) and one orthogonal. + // Equal distances remove relevance weighting, so pure diversity drives selection. + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), // along x + (1u32, 0.1, vec![0.0, 1.0, 0.0]), // along y - orthogonal to 0 + (2u32, 0.1, vec![0.99, 0.01, 0.0]), // nearly parallel to 0 + ]; + let query = &[1.0, 1.0, 1.0]; + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + + // Should select 2 candidates + assert_eq!(result.len(), 2); + // The diverse pair is (0, 1) - orthogonal. Candidate 2 is redundant with 0. + let ids: Vec = result.iter().map(|(id, _)| *id).collect(); + assert!(ids.contains(&0), "Expected candidate 0 to be selected"); + assert!( + ids.contains(&1), + "Expected candidate 1 (orthogonal) to be selected, not redundant candidate 2" + ); + } + + /// Verify eta variant selects the same k results. + #[test] + fn test_diversity_selects_orthogonal_candidates_with_eta() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), + (1u32, 0.1, vec![0.0, 1.0, 0.0]), + (2u32, 0.1, vec![0.99, 0.01, 0.0]), + ]; + let query = &[1.0, 1.0, 1.0]; + let result = run_with_ids(candidates, query, 2, 0.5, p(1.0)); + + assert_eq!(result.len(), 2); + let ids: Vec = result.iter().map(|(id, _)| *id).collect(); + assert!(ids.contains(&0), "Expected candidate 0 to be selected"); + assert!( + ids.contains(&1), + "Expected candidate 1 (orthogonal) to be selected" + ); + } + + /// Verify power=high weights nearby candidates (distance=0.1) more strongly than far ones. + #[test] + fn test_high_power_prefers_closer_candidates() { + // Two orthogonal candidates: one close, one far + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), // close to query + (1u32, 0.9, vec![0.0, 1.0]), // far from query + ]; + let query = &[1.0, 0.0]; + + // With high power, relevance is heavily weighted so the closest candidate dominates + let result = run_with_ids(candidates.clone(), query, 1, 0.0, p(10.0)); + assert_eq!(result.len(), 1); + // Closest candidate should be preferred due to high power weighting + assert_eq!( + result[0].0, 0, + "Closest candidate should be selected with high power" + ); + } + + /// Verify that distance-to-similarity conversion handles equal distances gracefully. + #[test] + fn test_equal_distances() { + let candidates = vec![ + (0u32, 0.5, vec![1.0, 0.0]), + (1u32, 0.5, vec![0.0, 1.0]), // same distance as 0 + ]; + let query = &[1.0, 0.0]; + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + + // Should still return candidates without panicking + assert_eq!(result.len(), 2); + } + + /// Test eta=0 exactly matches greedy orthogonalization path. + #[test] + fn test_eta_zero_is_greedy_path() { + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.2, vec![0.0, 1.0]), + (2u32, 0.3, vec![0.5, 0.5]), + ]; + let query = &[1.0, 1.0]; + // eta=0.0 must invoke greedy path, not ridge-regularized + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + assert_eq!(result.len(), 2); + } + + /// k = 0 should return an empty result without panicking, even when + /// candidates are otherwise valid. + #[test] + fn test_k_zero_returns_empty() { + let candidates = vec![(0u32, 0.1, vec![1.0, 0.0]), (1u32, 0.2, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 0, 0.5, p(1.0)); + assert_eq!(result.len(), 0); + } + + /// Zero-dimensional candidate vectors must be rejected gracefully (the + /// algorithm has no meaningful work to do with empty vectors and the + /// query is treated as effectively empty). + #[test] + fn test_zero_dimensional_candidates() { + let candidates = vec![(0u32, 0.1, Vec::::new()), (1u32, 0.2, Vec::new())]; + // Query is non-empty but candidate vectors have dim 0; we only reach + // the empty-vector early return if the dimension check would have + // matched (so use a 0-length query here to stay on the early path). + let query: &[f32] = &[]; + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + assert_eq!(result.len(), 0); + } + + /// Selecting all n candidates must exit the loop cleanly via the + /// `selected.len() == k` early break (no extra deflation pass). + #[test] + fn test_k_equals_candidates_returns_all() { + let candidates = vec![ + (10u32, 0.1, vec![1.0, 0.0]), + (20u32, 0.2, vec![0.0, 1.0]), + (30u32, 0.3, vec![1.0, 1.0]), + ]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 3, 0.0, p(1.0)); + assert_eq!(result.len(), 3); + + // Result IDs must be a permutation of the input IDs (no duplicates, + // none lost). + let mut ids: Vec = result.iter().map(|(id, _)| *id).collect(); + ids.sort_unstable(); + assert_eq!(ids, vec![10, 20, 30]); + } + + /// When all candidates lie on a single line through the origin, the + /// second-and-later pivots collapse to zero-norm residuals. The pivot + /// loop must exit cleanly and still return up to `k` candidates without + /// dividing by zero. + #[test] + fn test_collinear_candidates_no_division_by_zero() { + // All vectors are positive multiples of (1, 0). After scaling and + // picking the first pivot, every remaining residual is exactly 0. + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0]), + (1u32, 0.1, vec![2.0, 0.0]), + (2u32, 0.1, vec![3.0, 0.0]), + ]; + let query = &[1.0, 0.0]; + let result = run_with_ids(candidates, query, 3, 0.0, p(1.0)); + assert_eq!(result.len(), 3); + + let mut ids: Vec = result.iter().map(|(id, _)| *id).collect(); + ids.sort_unstable(); + assert_eq!(ids, vec![0, 1, 2]); + } + + /// The first selected element should be the most relevant one (highest + /// similarity, i.e. smallest distance) when all candidates have equal + /// vector norms. This pins down the initial pivot choice. + #[test] + fn test_first_pivot_is_most_relevant_at_equal_norms() { + // Three orthogonal unit vectors with strictly increasing distances. + // Largest similarity → smallest distance → id=0 should be picked + // first. + let candidates = vec![ + (0u32, 0.1, vec![1.0, 0.0, 0.0]), + (1u32, 0.5, vec![0.0, 1.0, 0.0]), + (2u32, 0.9, vec![0.0, 0.0, 1.0]), + ]; + let query = &[1.0, 1.0, 1.0]; + let result = run_with_ids(candidates, query, 3, 0.0, p(2.0)); + assert_eq!(result.len(), 3); + assert_eq!(result[0].0, 0, "Most relevant candidate must be first"); + } + + /// Distances returned alongside selected ids must come from the + /// corresponding input candidate (not be reordered or recomputed). + #[test] + fn test_ids_pair_with_their_input_distance() { + let candidates = vec![(7u32, 1.5, vec![1.0, 0.0]), (9u32, 0.25, vec![0.0, 1.0])]; + let query = &[1.0, 1.0]; + let result = run_with_ids(candidates, query, 2, 0.0, p(1.0)); + assert_eq!(result.len(), 2); + + for (id, dist) in &result { + match *id { + 7 => assert_eq!(*dist, 1.5), + 9 => assert_eq!(*dist, 0.25), + other => panic!("unexpected id {other}"), + } + } + } + + /// `distance_to_similarity` must produce a strictly positive, finite + /// score even at the extremes of the observed distance range, so the + /// resulting alpha never becomes a hard zero or NaN. + #[test] + fn test_distance_to_similarity_extremes() { + let range = DistanceRange { min: 0.5, max: 2.0 }; + + let s_min = distance_to_similarity(0.5, range); + let s_max = distance_to_similarity(2.0, range); + let s_below = distance_to_similarity(-1.0, range); + let s_above = distance_to_similarity(10.0, range); + + // All scores are strictly positive (we add EPSILON) and finite. + for s in [s_min, s_max, s_below, s_above] { + assert!(s.is_finite()); + assert!(s > 0.0); + } + // Closer (smaller) distance is at least as similar as farther. + assert!(s_min >= s_max); + // A distance below the observed min still clamps to ~1 + EPSILON. + assert!(s_below >= s_min - f32::EPSILON); + // A distance above the observed max still clamps to ~EPSILON. + assert!(s_above <= s_max + f32::EPSILON); + } + + /// Degenerate range (min == max) must not divide by zero. All + /// similarities should be finite and equal. + #[test] + fn test_distance_to_similarity_degenerate_range() { + let range = DistanceRange { min: 0.7, max: 0.7 }; + let a = distance_to_similarity(0.7, range); + let b = distance_to_similarity(0.7, range); + assert!(a.is_finite() && b.is_finite()); + assert_eq!(a, b); + } +} diff --git a/diskann-providers/src/model/graph/provider/mod.rs b/diskann-providers/src/model/graph/provider/mod.rs index 0e045bfb5..ba3f8978c 100644 --- a/diskann-providers/src/model/graph/provider/mod.rs +++ b/diskann-providers/src/model/graph/provider/mod.rs @@ -6,3 +6,12 @@ pub mod async_; // Layers for the async index. pub mod layers; + +/// Determinant-diversity post-processing algorithm. +/// +/// This module is not async-specific. +/// It provides diversity-promoting reranking for nearest neighbor search results. +mod determinant_diversity; +pub use determinant_diversity::{ + DeterminantDiversityError, DeterminantDiversityParams, determinant_diversity, +}; diff --git a/diskann-tools/src/utils/search_disk_index.rs b/diskann-tools/src/utils/search_disk_index.rs index 8bbdb1c8f..3825d2631 100644 --- a/diskann-tools/src/utils/search_disk_index.rs +++ b/diskann-tools/src/utils/search_disk_index.rs @@ -9,7 +9,8 @@ use diskann::utils::IntoUsize; use diskann_disk::{ data_model::{CachingStrategy, GraphDataType}, search::provider::{ - disk_provider::DiskIndexSearcher, disk_vertex_provider_factory::DiskVertexProviderFactory, + disk_provider::{DiskIndexSearcher, SearchPostProcessorKind}, + disk_vertex_provider_factory::DiskVertexProviderFactory, }, storage::disk_index_reader::DiskIndexReader, utils::{ @@ -259,6 +260,7 @@ where l, Some(parameters.beam_width as usize), Some(vector_filter_function), + SearchPostProcessorKind::None, parameters.is_flat_search, );