Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions diskann-benchmark/example/disk-index-filter.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
{
"search_directories": [
"test_data/disk_index_search"
],
"jobs": [
{
"type": "disk-index",
"content": {
"source": {
"disk-index-source": "Build",
"data_type": "float32",
"data": "disk_index_siftsmall_learn_256pts_data.fbin",
"distance": "squared_l2",
"dim": 128,
"max_degree": 32,
"l_build": 50,
"num_threads": 1,
"build_ram_limit_gb": 2.0,
"num_pq_chunks": 128,
"quantization_type": "FP",
"save_path": "siftsmall_index_filter_graph"
},
"search_phase": {
"queries": "disk_index_sample_query_10pts.fbin",
"groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin",
"search_list": [10, 20, 40],
"beam_width": 4,
"recall_at": 10,
"num_threads": 1,
"is_flat_search": false,
"distance": "squared_l2",
"vector_filters_file": "disk_index_10pts_idx_uint32_range_res_r_100000.bin"
}
}
},
{
"type": "disk-index",
"content": {
"source": {
"disk-index-source": "Build",
"data_type": "float32",
"data": "disk_index_siftsmall_learn_256pts_data.fbin",
"distance": "squared_l2",
"dim": 128,
"max_degree": 32,
"l_build": 50,
"num_threads": 1,
"build_ram_limit_gb": 2.0,
"num_pq_chunks": 128,
"quantization_type": "FP",
"save_path": "siftsmall_index_filter_flat"
},
"search_phase": {
"queries": "disk_index_sample_query_10pts.fbin",
"groundtruth": "disk_index_10pts_idx_uint32_truth_search_filter_res.bin",
"search_list": [10, 20, 40],
"beam_width": 4,
"recall_at": 10,
"num_threads": 1,
"is_flat_search": true,
"distance": "squared_l2",
"vector_filters_file": "disk_index_10pts_idx_uint32_range_res_r_100000.bin"
}
}
}
]
}
39 changes: 17 additions & 22 deletions diskann-benchmark/src/disk_index/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ use diskann::utils::VectorRepr;
use diskann_benchmark_runner::{files::InputFile, utils::MicroSeconds};
use diskann_disk::{
data_model::{AdHoc, CachingStrategy},
search::provider::{
disk_provider::{DiskIndexSearcher, SearchPostProcessorKind},
disk_vertex_provider_factory::DiskVertexProviderFactory,
search::{
provider::{
disk_provider::DiskIndexSearcher,
disk_vertex_provider_factory::DiskVertexProviderFactory,
},
search_mode::SearchMode,
},
storage::disk_index_reader::DiskIndexReader,
utils::{instrumentation::PerfLogger, statistics, AlignedFileReaderFactory, QueryStatistics},
Expand All @@ -33,10 +36,7 @@ use serde::{Deserialize, Serialize};

use crate::{
disk_index::json_spancollector::JsonSpanCollector,
inputs::{
disk::{DiskIndexLoad, DiskSearchPhase},
post_processor::TopkPostProcessor,
},
inputs::disk::{DiskIndexLoad, DiskSearchPhase},
utils::{datafiles, SimilarityMeasure},
};

Expand Down Expand Up @@ -268,27 +268,22 @@ where
zipped.for_each_in_pool(
pool.as_ref(),
|(((((q, vf), id_chunk), dist_chunk), stats), rc)| {
let post_processor = search_params.post_processor.as_ref().map_or(
SearchPostProcessorKind::None,
|TopkPostProcessor::DeterminantDiversity(params)| {
SearchPostProcessorKind::DeterminantDiversity(*params)
},
// Construct the SearchMode from the JSON-driven
// `adaptive_l` is now encapsulated in `DiskSearchMode`, so the
// benchmark only supplies the per-query filter and post-processor.
let has_filter = search_params.vector_filters_file.is_some();
let mode: SearchMode<'_> = search_params.search_mode.search_mode(
has_filter,
vf,
search_params.post_processor.as_ref(),
);
let vector_filter = if search_params.vector_filters_file.is_none() {
None
} else {
Some(Box::new(move |vid: &u32| vf.contains(vid))
as Box<dyn Fn(&u32) -> bool + Send + Sync>)
};

match searcher.search(
q,
search_params.recall_at,
l,
Some(search_params.beam_width),
vector_filter,
post_processor,
search_params.is_flat_search,
mode,
) {
Ok(search_result) => {
*stats = search_result.stats.query_statistics;
Expand Down Expand Up @@ -354,7 +349,7 @@ where
num_threads: search_params.num_threads,
beam_width: search_params.beam_width,
recall_at: search_params.recall_at,
is_flat_search: search_params.is_flat_search,
is_flat_search: search_params.search_mode.is_flat_search,
distance: search_params.distance,
uses_vector_filters: search_params.vector_filters_file.is_some(),
num_nodes_to_cache: search_params.num_nodes_to_cache,
Expand Down
114 changes: 114 additions & 0 deletions diskann-benchmark/src/inputs/disk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,15 @@
use std::{fmt, num::NonZeroUsize, path::Path};

use anyhow::Context;
#[cfg(feature = "disk-index")]

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason we have everything gated by #[cfg(feature = "disk-index")]?

use std::collections::HashSet;

#[cfg(feature = "disk-index")]
use diskann::graph;
use diskann_benchmark_runner::{files::InputFile, utils::datatype::DataType, Checker};
#[cfg(feature = "disk-index")]
use diskann_disk::search::search_mode::SearchMode;
#[cfg(feature = "disk-index")]
use diskann_disk::QuantizationType;
use diskann_providers::storage::{get_compressed_pq_file, get_disk_index_file, get_pq_pivot_file};
use serde::{Deserialize, Serialize};
Expand All @@ -17,6 +24,9 @@ use crate::{
utils::SimilarityMeasure,
};

#[cfg(feature = "disk-index")]
use crate::inputs::graph_index::AdaptiveL;

//////////////
// Registry //
//////////////
Expand Down Expand Up @@ -62,6 +72,80 @@ pub(crate) struct DiskIndexBuild {
pub(crate) save_path: String,
}

#[cfg(feature = "disk-index")]
#[derive(Debug, Serialize, Deserialize, Default)]
pub(crate) struct DiskSearchMode {
pub(crate) is_flat_search: bool,
#[serde(default)]
pub(crate) adaptive_l: Option<AdaptiveL>,
}

#[cfg(feature = "disk-index")]
impl DiskSearchMode {
pub(crate) fn search_mode<'a>(
&'a self,
has_vector_filters: bool,
vector_filter: &'a HashSet<u32>,
post_processor: Option<&TopkPostProcessor>,
) -> SearchMode<'a> {
let adaptive_l = self.adaptive_l.as_ref().map(|adaptive_l| {
graph::search::AdaptiveL::new(adaptive_l.sample_count.into(), adaptive_l.scale_factor)
.expect("validated adaptive L must construct")
});

match (
self.is_flat_search,
has_vector_filters,
post_processor,
adaptive_l,
) {
(true, false, _, _) => SearchMode::flat(),
(true, true, _, _) => {
SearchMode::flat_filtered(move |vid: &u32| vector_filter.contains(vid))
}
(false, false, Some(TopkPostProcessor::DeterminantDiversity(params)), _) => {
SearchMode::diverse_graph(*params)
}
(false, true, Some(TopkPostProcessor::DeterminantDiversity(params)), _) => {
SearchMode::diverse_graph_filtered(
move |vid: &u32| vector_filter.contains(vid),
*params,
)
}
(false, false, None, Some(adaptive_l)) => {
SearchMode::inline_filter(|_| true, Some(adaptive_l))
}
(false, true, None, Some(adaptive_l)) => SearchMode::inline_filter(
move |vid: &u32| vector_filter.contains(vid),
Some(adaptive_l),
),
(false, false, None, None) => SearchMode::graph(),
(false, true, None, None) => {
SearchMode::graph_filtered(move |vid: &u32| vector_filter.contains(vid))
}
}
}

pub(crate) fn validate(&mut self, checker: &mut Checker) -> Result<(), anyhow::Error> {
if let Some(adaptive_l) = self.adaptive_l.as_mut() {
adaptive_l.validate(checker)?;
}
Ok(())
}
}

#[cfg(feature = "disk-index")]
impl fmt::Display for DiskSearchMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let base = if self.is_flat_search { "flat" } else { "graph" };
if self.adaptive_l.is_some() {
write!(f, "{} + adaptive-l", base)
} else {
write!(f, "{}", base)
}
}
}

/// Search phase configuration
#[derive(Debug, Deserialize, Serialize)]
pub(crate) struct DiskSearchPhase {
Expand All @@ -71,6 +155,15 @@ pub(crate) struct DiskSearchPhase {
pub(crate) beam_width: usize,
pub(crate) search_list: Vec<u32>,
pub(crate) recall_at: u32,
#[cfg(feature = "disk-index")]
#[serde(default)]
pub(crate) search_mode: DiskSearchMode,
// Backward compatibility for older benchmark inputs that used
// `is_flat_search` directly at the search-phase level.
#[cfg(feature = "disk-index")]
#[serde(default, skip_serializing)]
pub(crate) is_flat_search: Option<bool>,
#[cfg(not(feature = "disk-index"))]
pub(crate) is_flat_search: bool,
pub(crate) distance: SimilarityMeasure,
pub(crate) vector_filters_file: Option<InputFile>,
Expand Down Expand Up @@ -181,6 +274,16 @@ impl DiskSearchPhase {
vf.resolve(checker).context("invalid vector_filters_file")?;
}

#[cfg(feature = "disk-index")]
if let Some(is_flat_search) = self.is_flat_search {
self.search_mode.is_flat_search = is_flat_search;
}

#[cfg(feature = "disk-index")]
self.search_mode
.validate(checker)
.context("invalid disk search mode")?;

// basic numeric sanity checks
if self.search_list.is_empty() {
anyhow::bail!("search_list must have at least one value");
Expand Down Expand Up @@ -250,6 +353,14 @@ impl Example for DiskIndexOperation {
beam_width: 16,
recall_at: 10,
num_threads: 8,
#[cfg(feature = "disk-index")]
search_mode: DiskSearchMode {
is_flat_search: false,
adaptive_l: None,
},
#[cfg(feature = "disk-index")]
is_flat_search: None,
#[cfg(not(feature = "disk-index"))]
is_flat_search: false,
distance: SimilarityMeasure::SquaredL2,
vector_filters_file: None,
Expand Down Expand Up @@ -367,6 +478,9 @@ impl DiskSearchPhase {
write_field!(f, "Beam Width", self.beam_width)?;
write_field!(f, "Recall@", self.recall_at)?;
write_field!(f, "Threads", self.num_threads)?;
#[cfg(feature = "disk-index")]
write_field!(f, "Search Mode", self.search_mode)?;
#[cfg(not(feature = "disk-index"))]
write_field!(f, "Flat Search", self.is_flat_search)?;
write_field!(f, "Distance", self.distance)?;
match &self.vector_filters_file {
Expand Down
31 changes: 31 additions & 0 deletions diskann-benchmark/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,37 @@ mod tests {
run_integration_test(raw);
}

/// Filtered disk search end-to-end: drives the disk-index backend through
/// `disk-index-filter.json`
#[test]
#[cfg(feature = "disk-index")]
fn disk_index_filter_integration() {
let mut raw = value_from_file(&example_directory().join("disk-index-filter.json"));
prefix_search_directories(&mut raw, &root_directory());

let tempdir = tempfile::tempdir().unwrap();
let input_path = tempdir.path().join("disk-index-filter.json");
save_to_file(&input_path, &raw);
let output_path = tempdir.path().join("output.json");

let command = Commands::Run {
input_file: input_path.to_owned(),
output_file: output_path.to_owned(),
dry_run: false,
allow_debug: true,
};
let cli = Cli::from_commands(command, true);
let mut output = Memory::new();
let result = cli.run(&mut output);
let output_str = String::from_utf8(output.into_inner()).unwrap();
println!("output = {}", output_str);
result.expect("disk-index-filter run failed");

assert!(output_path.exists());
let results: Vec<Value> = load_from_file(&output_path);
assert_eq!(results.len(), num_jobs(&raw));
}

#[test]
fn graph_index_inline_filter_yfcc_integration() {
// First, parse and modify the input file to establish paths relative to the
Expand Down
4 changes: 1 addition & 3 deletions diskann-disk/src/build/builder/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1089,9 +1089,7 @@ pub(crate) mod disk_index_builder_tests {
&mut indices,
&mut distances,
&mut associated_data,
None,
&|_| true,
false,
&crate::search::search_mode::SearchMode::graph(),
);

diskann_providers::test_utils::assert_top_k_exactly_match(
Expand Down
1 change: 1 addition & 0 deletions diskann-disk/src/search/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@

pub mod pq;
pub mod provider;
pub mod search_mode;
pub mod traits;
Loading
Loading