Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions diskann-benchmark-core/src/search/graph/inline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use diskann::{
};
use diskann_utils::{future::AsyncFriendly, views::Matrix};

use crate::search::{self, Search, graph::Strategy};
use crate::search::{self, Search, graph::KnnWrapper, graph::Strategy};

/// A built-in helper for benchmarking filtered K-nearest neighbors search
/// using the inline search method.
Expand All @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy};
/// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same
/// aggregator as [`search::graph::knn::KNN`]).
///
/// The provided implementation of [`Search`] accepts [`graph::search::Knn`]
/// The provided implementation of [`Search`] accepts [`KnnWrapper`]
/// and returns [`search::graph::knn::Metrics`] as additional output.
#[derive(Debug)]
pub struct InlineFilterSearch<DP, T, S>
Expand Down Expand Up @@ -93,7 +93,7 @@ where
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::search::Knn;
type Parameters = KnnWrapper;
type Output = super::knn::Metrics;

fn num_queries(&self) -> usize {
Expand All @@ -115,7 +115,7 @@ where
{
let context = DP::Context::default();
let inline_search = graph::search::InlineFilterSearch::new(
*parameters,
parameters.knn,
&*self.labels[index],
self.adaptive_l.clone(),
);
Expand Down Expand Up @@ -192,7 +192,7 @@ mod tests {
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
inline.clone(),
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 10).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
Expand Down Expand Up @@ -220,11 +220,11 @@ mod tests {
// Try the aggregated strategy.
let parameters = [
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 10).unwrap(),
setup.clone(),
),
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 15).unwrap(),
setup.clone(),
),
];
Expand Down
68 changes: 57 additions & 11 deletions diskann-benchmark-core/src/search/graph/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

//! A built-in helper for benchmarking K-nearest neighbors.

use std::sync::Arc;
use std::{num::NonZeroUsize, sync::Arc};
use thiserror::Error;

use diskann::{
ANNResult,
ANNError, ANNResult,
graph::{self, glue},
provider,
};
Expand All @@ -30,7 +31,7 @@ use crate::{
/// the latter. Result aggregation for [`search::search_all`] is provided
/// by the [`Aggregator`] type.
///
/// The provided implementation of [`Search`] accepts [`graph::search::Knn`]
/// The provided implementation of [`Search`] accepts [`KnnWrapper`]
/// and returns [`Metrics`] as additional output.
#[derive(Debug)]
pub struct KNN<DP, T, S>
Expand Down Expand Up @@ -93,7 +94,7 @@ where
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::search::Knn;
type Parameters = KnnWrapper;
type Output = Metrics;

fn num_queries(&self) -> usize {
Expand All @@ -114,7 +115,7 @@ where
O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
{
let context = DP::Context::default();
let knn_search = *parameters;
let knn_search = parameters.knn;
let stats = self
.index
.search(
Expand All @@ -133,6 +134,51 @@ where
}
}

/// Errors returned by [`KnnWrapper::new`] when its arguments fail validation.
///
/// Converts into [`ANNError`] through the `From` implementation in this module.
#[derive(Debug, Error)]
pub enum KnnWrapperError {
/// The requested `k_value` was zero.
#[error("k_value cannot be zero")]
KZero,
/// The requested `l_value` was smaller than `k_value`.
#[error("l_value ({l_value}) must be at least k_value ({k_value})")]
LLessThanK { l_value: usize, k_value: usize },
/// [`graph::search::Knn::new`] rejected the parameters; the underlying error
/// is not retained by this variant.
#[error("invalid KNN parameters")]
InvalidKnnParameters,
}
Comment thread
Copilot marked this conversation as resolved.

impl From<KnnWrapperError> for ANNError {
#[track_caller]
fn from(err: KnnWrapperError) -> Self {
ANNError::opaque(err)
}
}

/// A wrapper for the [`graph::search::Knn`] struct that also includes the `k` value.
///
/// Instances are built with [`KnnWrapper::new`], which rejects a zero `k_value`
/// and an `l_value` smaller than `k_value`.
#[derive(Debug, Copy, Clone)]
pub struct KnnWrapper {
// The `k` value; private, read through `k_value()`.
k_value: NonZeroUsize,
/// The wrapped search parameters, built from `l_value` in [`KnnWrapper::new`].
// NOTE(review): this field is public, so it can be replaced after construction
// without re-running the `l_value >= k_value` check done in `new` — confirm
// that this is intended.
pub knn: graph::search::Knn,
}

impl KnnWrapper {
/// Construct a new [`KnnWrapper`].
pub fn new(k_value: usize, l_value: usize) -> Result<Self, KnnWrapperError> {
let k_value = NonZeroUsize::new(k_value).ok_or(KnnWrapperError::KZero)?;
if l_value < k_value.get() {
return Err(KnnWrapperError::LLessThanK {
l_value,
k_value: k_value.get(),
});
}
Comment thread
Copilot marked this conversation as resolved.

let knn = graph::search::Knn::new(l_value, None)
.map_err(|_| KnnWrapperError::InvalidKnnParameters)?;
Ok(Self { k_value, knn })
}

pub fn k_value(&self) -> NonZeroUsize {
self.k_value
}
}

/// An [`search::Aggregate`]d summary of multiple [`KNN`] search runs
/// returned by the provided [`Aggregator`].
///
Expand All @@ -144,7 +190,7 @@ pub struct Summary {
pub setup: search::Setup,

/// The [`Search::Parameters`] used for the batch of runs.
pub parameters: graph::search::Knn,
pub parameters: KnnWrapper,

/// The end-to-end latency for each repetition in the batch.
pub end_to_end_latencies: Vec<MicroSeconds>,
Expand Down Expand Up @@ -212,15 +258,15 @@ impl<'a, I> Aggregator<'a, I> {
}
}

impl<I> search::Aggregate<graph::search::Knn, I, Metrics> for Aggregator<'_, I>
impl<I> search::Aggregate<KnnWrapper, I, Metrics> for Aggregator<'_, I>
where
I: crate::recall::RecallCompatible,
{
type Output = Summary;

fn aggregate(
&mut self,
run: search::Run<graph::search::Knn>,
run: search::Run<KnnWrapper>,
mut results: Vec<search::SearchResults<I, Metrics>>,
) -> anyhow::Result<Summary> {
// Compute the recall using just the first result.
Expand Down Expand Up @@ -317,7 +363,7 @@ mod tests {
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
knn.clone(),
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 10).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
Expand All @@ -341,11 +387,11 @@ mod tests {
// Try the aggregated strategy.
let parameters = [
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 10).unwrap(),
setup.clone(),
),
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 15).unwrap(),
setup.clone(),
),
];
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark-core/src/search/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod range;
pub mod strategy;

pub use inline::InlineFilterSearch;
pub use knn::KNN;
pub use knn::{KNN, KnnWrapper};
pub use multihop::MultiHop;
pub use range::Range;

Expand Down
14 changes: 7 additions & 7 deletions diskann-benchmark-core/src/search/graph/multihop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use diskann::{
};
use diskann_utils::{future::AsyncFriendly, views::Matrix};

use crate::search::{self, Search, graph::Strategy};
use crate::search::{self, Search, graph::KnnWrapper, graph::Strategy};

/// A built-in helper for benchmarking filtered K-nearest neighbors search
/// using the multi-hop search method.
Expand All @@ -22,7 +22,7 @@ use crate::search::{self, Search, graph::Strategy};
/// [`search::search_all`] is provided by the [`search::graph::knn::Aggregator`] type (same
/// aggregator as [`search::graph::knn::KNN`]).
///
/// The provided implementation of [`Search`] accepts [`graph::search::Knn`]
/// The provided implementation of [`Search`] accepts [`KnnWrapper`]
/// and returns [`search::graph::knn::Metrics`] as additional output.
#[derive(Debug)]
pub struct MultiHop<DP, T, S>
Expand Down Expand Up @@ -90,7 +90,7 @@ where
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::search::Knn;
type Parameters = KnnWrapper;
type Output = super::knn::Metrics;

fn num_queries(&self) -> usize {
Expand All @@ -112,7 +112,7 @@ where
{
let context = DP::Context::default();
let multihop_search =
graph::search::MultihopFilterSearch::new(*parameters, &*self.labels[index]);
graph::search::MultihopFilterSearch::new(parameters.knn, &*self.labels[index]);
let stats = self
.index
.search(
Expand Down Expand Up @@ -183,7 +183,7 @@ mod tests {
let rt = crate::tokio::runtime(2).unwrap();
let results = search::search(
multihop.clone(),
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 10).unwrap(),
NonZeroUsize::new(2).unwrap(),
&rt,
)
Expand Down Expand Up @@ -211,11 +211,11 @@ mod tests {
// Try the aggregated strategy.
let parameters = [
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 10, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 10).unwrap(),
setup.clone(),
),
search::Run::new(
graph::search::Knn::new(nearest_neighbors, 15, None).unwrap(),
KnnWrapper::new(nearest_neighbors, 15).unwrap(),
setup.clone(),
),
];
Expand Down
2 changes: 1 addition & 1 deletion diskann-benchmark/src/backend/index/result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ impl SearchResults {
Self {
num_tasks: setup.tasks.into(),
search_n: parameters.k_value().get(),
search_l: parameters.l_value().get(),
search_l: parameters.knn.l_value().get(),
qps,
search_latencies: end_to_end_latencies,
mean_latencies,
Expand Down
19 changes: 9 additions & 10 deletions diskann-benchmark/src/backend/index/search/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

use std::{num::NonZeroUsize, sync::Arc};

use diskann_benchmark_core::recall::GroundTruthMode;
use diskann_benchmark_core::{self as benchmark_core, search as core_search};
use diskann_benchmark_core::{recall::GroundTruthMode, search::graph::KnnWrapper};

use crate::{backend::index::result::SearchResults, inputs::graph_index::GraphSearch};

Expand Down Expand Up @@ -50,8 +50,7 @@ pub(crate) fn run<I>(
.search_l
.iter()
.map(|search_l| {
let search_params =
diskann::graph::search::Knn::new(run.search_n, *search_l, None).unwrap();
let search_params = KnnWrapper::new(run.search_n, *search_l).unwrap();

core_search::Run::new(search_params, setup.clone())
})
Expand All @@ -64,7 +63,7 @@ pub(crate) fn run<I>(
Ok(all)
}

type Run = core_search::Run<diskann::graph::search::Knn>;
type Run = core_search::Run<KnnWrapper>;
pub(crate) trait Knn<I> {
fn search_all(
&self,
Expand All @@ -84,13 +83,13 @@ where
DP: diskann::provider::DataProvider,
core_search::graph::KNN<DP, T, S>: core_search::Search<
Id = DP::InternalId,
Parameters = diskann::graph::search::Knn,
Parameters = KnnWrapper,
Output = core_search::graph::knn::Metrics,
>,
{
fn search_all(
&self,
parameters: Vec<core_search::Run<diskann::graph::search::Knn>>,
parameters: Vec<core_search::Run<KnnWrapper>>,
groundtruth: &dyn benchmark_core::recall::Rows<DP::InternalId>,
recall_k: usize,
recall_n: usize,
Expand All @@ -115,13 +114,13 @@ where
DP: diskann::provider::DataProvider,
core_search::graph::MultiHop<DP, T, S>: core_search::Search<
Id = DP::InternalId,
Parameters = diskann::graph::search::Knn,
Parameters = KnnWrapper,
Output = core_search::graph::knn::Metrics,
>,
{
fn search_all(
&self,
parameters: Vec<core_search::Run<diskann::graph::search::Knn>>,
parameters: Vec<core_search::Run<KnnWrapper>>,
groundtruth: &dyn benchmark_core::recall::Rows<DP::InternalId>,
recall_k: usize,
recall_n: usize,
Expand All @@ -146,13 +145,13 @@ where
DP: diskann::provider::DataProvider,
core_search::graph::InlineFilterSearch<DP, T, S>: core_search::Search<
Id = DP::InternalId,
Parameters = diskann::graph::search::Knn,
Parameters = KnnWrapper,
Output = core_search::graph::knn::Metrics,
>,
{
fn search_all(
&self,
parameters: Vec<core_search::Run<diskann::graph::search::Knn>>,
parameters: Vec<core_search::Run<KnnWrapper>>,
groundtruth: &dyn benchmark_core::recall::Rows<DP::InternalId>,
recall_k: usize,
recall_n: usize,
Expand Down
Loading
Loading